Skip to content
Snippets Groups Projects
Commit aa2c89d8 authored by Stephan Seitz's avatar Stephan Seitz Committed by Martin Bauer
Browse files

More robust parsing of ps.fields() string

parent ce39e21f
No related merge requests found
......@@ -3,6 +3,7 @@ from itertools import chain
from typing import Tuple, Sequence, Optional, List, Set
import numpy as np
import sympy as sp
import re
from sympy.core.cache import cacheit
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import create_type, StructType
......@@ -730,53 +731,68 @@ def compute_strides(shape, layout):
return tuple(strides)
def _parse_type_description(type_description):
if not type_description:
return np.float64, None
elif '[' in type_description:
assert type_description[-1] == ']'
type_part, size_part = type_description[:-1].split("[", )
if not type_part:
type_part = "float64"
if size_part.lower()[-1] == 'd':
size_part = int(size_part[:-1])
else:
size_part = tuple(int(i) for i in size_part.split(','))
else:
type_part, size_part = type_description, None
dtype = np.dtype(type_part).type
return dtype, size_part
def _parse_field_description(description):
if '(' not in description:
return description, ()
assert description[-1] == ')'
name, index_shape = description[:-1].split('(')
index_shape = tuple(int(i) for i in index_shape.split(','))
return name, index_shape
# ---------------------------------------- Parsing of string in fields() function --------------------------------------
field_description_regex = re.compile(r"""
\s* # ignore leading white spaces
(\w+) # identifier is a sequence of alphanumeric characters, is stored in first group
(?: # optional index specification e.g. (1, 4, 2)
\s*
\(
([^\)]+) # read everything up to closing bracket
\)
\s*
)?
\s*,?\s* # ignore trailing white spaces and comma
""", re.VERBOSE)
type_description_regex = re.compile(r"""
\s*
(\w+)? # optional dtype
\s*
\[
([^\]]+)
\]
\s*
""", re.VERBOSE | re.IGNORECASE)
def _parse_description(description):
description = description.replace(' ', '')
def parse_part1(d):
result = field_description_regex.match(d)
while result:
name, index_str = result[1], result[2]
index = tuple(int(e) for e in index_str.split(",")) if index_str else ()
yield name, index
d = d[result.end():]
result = field_description_regex.match(d)
def parse_part2(d):
result = type_description_regex.match(d)
if result:
data_type_str, size_info = result[1], result[2].strip().lower()
if data_type_str is None:
data_type_str = 'float64'
data_type_str = data_type_str.lower().strip()
if not data_type_str:
data_type_str = 'float64'
if size_info.endswith('d'):
size_info = int(size_info[:-1])
else:
size_info = tuple(int(e) for e in size_info.split(","))
return data_type_str, size_info
else:
raise ValueError("Could not parse field description")
if ':' in description:
name_descr, type_descr = description.split(':')
field_description, field_info = description.split(':')
else:
name_descr, type_descr = description, ''
# correct ',' splits inside brackets
field_names = name_descr.split(',')
cleaned_field_names = []
prefix = ''
for field_name in field_names:
full_field_name = prefix + field_name
if '(' in full_field_name and ')' not in full_field_name:
prefix += field_name + ','
else:
prefix = ''
cleaned_field_names.append(full_field_name)
field_description, field_info = description, 'float64[2D]'
fields_info = [e for e in parse_part1(field_description)]
if not field_info:
raise ValueError("Could not parse field description")
dtype, size = _parse_type_description(type_descr)
fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names)
return fields_info, dtype, size
data_type, size = parse_part2(field_info)
return fields_info, data_type, size
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment