From aa2c89d8b9fd27966d476d10f066bfbd276838d3 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 22 Nov 2018 14:23:09 +0100 Subject: [PATCH] More robust parsing of ps.fields() string --- field.py | 104 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 60 insertions(+), 44 deletions(-) diff --git a/field.py b/field.py index 20501752a..03bcb0fec 100644 --- a/field.py +++ b/field.py @@ -3,6 +3,7 @@ from itertools import chain from typing import Tuple, Sequence, Optional, List, Set import numpy as np import sympy as sp +import re from sympy.core.cache import cacheit from pystencils.alignedarray import aligned_empty from pystencils.data_types import create_type, StructType @@ -730,53 +731,68 @@ def compute_strides(shape, layout): return tuple(strides) -def _parse_type_description(type_description): - if not type_description: - return np.float64, None - elif '[' in type_description: - assert type_description[-1] == ']' - type_part, size_part = type_description[:-1].split("[", ) - if not type_part: - type_part = "float64" - if size_part.lower()[-1] == 'd': - size_part = int(size_part[:-1]) - else: - size_part = tuple(int(i) for i in size_part.split(',')) - else: - type_part, size_part = type_description, None - - dtype = np.dtype(type_part).type - return dtype, size_part - - -def _parse_field_description(description): - if '(' not in description: - return description, () - assert description[-1] == ')' - name, index_shape = description[:-1].split('(') - index_shape = tuple(int(i) for i in index_shape.split(',')) - return name, index_shape +# ---------------------------------------- Parsing of string in fields() function -------------------------------------- + +field_description_regex = re.compile(r""" + \s* # ignore leading white spaces + (\w+) # identifier is a sequence of alphanumeric characters, is stored in first group + (?: # optional index specification e.g. (1, 4, 2) + \s* + \( + ([^\)]+) # read everything up to closing bracket + \) + \s* + )? + \s*,?\s* # ignore trailing white spaces and comma +""", re.VERBOSE) + +type_description_regex = re.compile(r""" + \s* + (\w+)? # optional dtype + \s* + \[ + ([^\]]+) + \] + \s* +""", re.VERBOSE | re.IGNORECASE) def _parse_description(description): - description = description.replace(' ', '') + def parse_part1(d): + result = field_description_regex.match(d) + while result: + name, index_str = result[1], result[2] + index = tuple(int(e) for e in index_str.split(",")) if index_str else () + yield name, index + d = d[result.end():] + result = field_description_regex.match(d) + + def parse_part2(d): + result = type_description_regex.match(d) + if result: + data_type_str, size_info = result[1], result[2].strip().lower() + if data_type_str is None: + data_type_str = 'float64' + data_type_str = data_type_str.lower().strip() + + if not data_type_str: + data_type_str = 'float64' + if size_info.endswith('d'): + size_info = int(size_info[:-1]) + else: + size_info = tuple(int(e) for e in size_info.split(",")) + return data_type_str, size_info + else: + raise ValueError("Could not parse field description") + if ':' in description: - name_descr, type_descr = description.split(':') + field_description, field_info = description.split(':') else: - name_descr, type_descr = description, '' - - # correct ',' splits inside brackets - field_names = name_descr.split(',') - cleaned_field_names = [] - prefix = '' - for field_name in field_names: - full_field_name = prefix + field_name - if '(' in full_field_name and ')' not in full_field_name: - prefix += field_name + ',' - else: - prefix = '' - cleaned_field_names.append(full_field_name) + field_description, field_info = description, 'float64[2D]' + + fields_info = [e for e in parse_part1(field_description)] + if not field_info: + raise ValueError("Could not parse field description") - dtype, size = _parse_type_description(type_descr) - fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names) - return fields_info, dtype, size + data_type, size = parse_part2(field_info) + return fields_info, data_type, size -- GitLab