From aa2c89d8b9fd27966d476d10f066bfbd276838d3 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 22 Nov 2018 14:23:09 +0100
Subject: [PATCH] More robust parsing of ps.fields() string

---
 field.py | 104 ++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 60 insertions(+), 44 deletions(-)

diff --git a/field.py b/field.py
index 20501752a..03bcb0fec 100644
--- a/field.py
+++ b/field.py
@@ -3,6 +3,7 @@ from itertools import chain
 from typing import Tuple, Sequence, Optional, List, Set
 import numpy as np
 import sympy as sp
+import re
 from sympy.core.cache import cacheit
 from pystencils.alignedarray import aligned_empty
 from pystencils.data_types import create_type, StructType
@@ -730,53 +731,68 @@ def compute_strides(shape, layout):
     return tuple(strides)
 
 
-def _parse_type_description(type_description):
-    if not type_description:
-        return np.float64, None
-    elif '[' in type_description:
-        assert type_description[-1] == ']'
-        type_part, size_part = type_description[:-1].split("[", )
-        if not type_part:
-            type_part = "float64"
-        if size_part.lower()[-1] == 'd':
-            size_part = int(size_part[:-1])
-        else:
-            size_part = tuple(int(i) for i in size_part.split(','))
-    else:
-        type_part, size_part = type_description, None
-
-    dtype = np.dtype(type_part).type
-    return dtype, size_part
-
-
-def _parse_field_description(description):
-    if '(' not in description:
-        return description, ()
-    assert description[-1] == ')'
-    name, index_shape = description[:-1].split('(')
-    index_shape = tuple(int(i) for i in index_shape.split(','))
-    return name, index_shape
+# ---------------------------------------- Parsing of string in fields() function --------------------------------------
+
+field_description_regex = re.compile(r"""
+    \s*                 # ignore leading white spaces
+    (\w+)               # identifier is a sequence of alphanumeric characters, is stored in first group
+    (?:                 # optional index specification e.g. (1, 4, 2)
+        \s*
+        \(
+            ([^\)]+)    # read everything up to closing bracket
+        \)
+        \s*
+    )?
+    \s*,?\s*             # ignore trailing white spaces and comma
+""", re.VERBOSE)
+
+type_description_regex = re.compile(r"""
+    \s*
+    (\w+)?       # optional dtype
+    \s*
+    \[
+        ([^\]]+)
+    \]
+    \s*
+""", re.VERBOSE | re.IGNORECASE)
 
 
 def _parse_description(description):
-    description = description.replace(' ', '')
+    def parse_part1(d):
+        result = field_description_regex.match(d)
+        while result:
+            name, index_str = result[1], result[2]
+            index = tuple(int(e) for e in index_str.split(",")) if index_str else ()
+            yield name, index
+            d = d[result.end():]
+            result = field_description_regex.match(d)
+
+    def parse_part2(d):
+        result = type_description_regex.match(d)
+        if result:
+            data_type_str, size_info = result[1], result[2].strip().lower()
+            if data_type_str is None:
+                data_type_str = 'float64'
+            data_type_str = data_type_str.lower().strip()
+
+            if not data_type_str:
+                data_type_str = 'float64'
+            if size_info.endswith('d'):
+                size_info = int(size_info[:-1])
+            else:
+                size_info = tuple(int(e) for e in size_info.split(","))
+            return data_type_str, size_info
+        else:
+            raise ValueError("Could not parse field description")
+
     if ':' in description:
-        name_descr, type_descr = description.split(':')
+        field_description, field_info = description.split(':')
     else:
-        name_descr, type_descr = description, ''
-
-    # correct ',' splits inside brackets
-    field_names = name_descr.split(',')
-    cleaned_field_names = []
-    prefix = ''
-    for field_name in field_names:
-        full_field_name = prefix + field_name
-        if '(' in full_field_name and ')' not in full_field_name:
-            prefix += field_name + ','
-        else:
-            prefix = ''
-            cleaned_field_names.append(full_field_name)
+        field_description, field_info = description, 'float64[2D]'
+
+    fields_info = [e for e in parse_part1(field_description)]
+    if not field_info:
+        raise ValueError("Could not parse field description")
 
-    dtype, size = _parse_type_description(type_descr)
-    fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names)
-    return fields_info, dtype, size
+    data_type, size = parse_part2(field_info)
+    return fields_info, data_type, size
-- 
GitLab