Commit aa2c89d8 authored by Stephan Seitz's avatar Stephan Seitz Committed by Martin Bauer
Browse files

More robust parsing of ps.fields() string

parent ce39e21f
......@@ -3,6 +3,7 @@ from itertools import chain
from typing import Tuple, Sequence, Optional, List, Set
import numpy as np
import sympy as sp
import re
from sympy.core.cache import cacheit
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import create_type, StructType
......@@ -730,53 +731,68 @@ def compute_strides(shape, layout):
return tuple(strides)
def _parse_type_description(type_description):
if not type_description:
return np.float64, None
elif '[' in type_description:
assert type_description[-1] == ']'
type_part, size_part = type_description[:-1].split("[", )
if not type_part:
type_part = "float64"
if size_part.lower()[-1] == 'd':
size_part = int(size_part[:-1])
else:
size_part = tuple(int(i) for i in size_part.split(','))
else:
type_part, size_part = type_description, None
dtype = np.dtype(type_part).type
return dtype, size_part
def _parse_field_description(description):
if '(' not in description:
return description, ()
assert description[-1] == ')'
name, index_shape = description[:-1].split('(')
index_shape = tuple(int(i) for i in index_shape.split(','))
return name, index_shape
# ---------------------------------------- Parsing of string in fields() function --------------------------------------
field_description_regex = re.compile(r"""
\s* # ignore leading white spaces
(\w+) # identifier is a sequence of alphanumeric characters, is stored in first group
(?: # optional index specification e.g. (1, 4, 2)
\s*
\(
([^\)]+) # read everything up to closing bracket
\)
\s*
)?
\s*,?\s* # ignore trailing white spaces and comma
""", re.VERBOSE)
type_description_regex = re.compile(r"""
\s*
(\w+)? # optional dtype
\s*
\[
([^\]]+)
\]
\s*
""", re.VERBOSE | re.IGNORECASE)
def _parse_description(description):
description = description.replace(' ', '')
def parse_part1(d):
result = field_description_regex.match(d)
while result:
name, index_str = result[1], result[2]
index = tuple(int(e) for e in index_str.split(",")) if index_str else ()
yield name, index
d = d[result.end():]
result = field_description_regex.match(d)
def parse_part2(d):
result = type_description_regex.match(d)
if result:
data_type_str, size_info = result[1], result[2].strip().lower()
if data_type_str is None:
data_type_str = 'float64'
data_type_str = data_type_str.lower().strip()
if not data_type_str:
data_type_str = 'float64'
if size_info.endswith('d'):
size_info = int(size_info[:-1])
else:
size_info = tuple(int(e) for e in size_info.split(","))
return data_type_str, size_info
else:
raise ValueError("Could not parse field description")
if ':' in description:
name_descr, type_descr = description.split(':')
field_description, field_info = description.split(':')
else:
name_descr, type_descr = description, ''
# correct ',' splits inside brackets
field_names = name_descr.split(',')
cleaned_field_names = []
prefix = ''
for field_name in field_names:
full_field_name = prefix + field_name
if '(' in full_field_name and ')' not in full_field_name:
prefix += field_name + ','
else:
prefix = ''
cleaned_field_names.append(full_field_name)
field_description, field_info = description, 'float64[2D]'
fields_info = [e for e in parse_part1(field_description)]
if not field_info:
raise ValueError("Could not parse field description")
dtype, size = _parse_type_description(type_descr)
fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names)
return fields_info, dtype, size
data_type, size = parse_part2(field_info)
return fields_info, data_type, size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment