diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index a924c85162c39f94fdcb38215c9679bfa67bde02..2d3174a1a564bfd00633d36150d8119f7dacaec5 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,5 +1,6 @@ from typing import Any, List, Optional, Sequence, Set, Union +import jinja2 import sympy as sp from pystencils.data_types import TypedSymbol, cast_func, create_type @@ -655,7 +656,12 @@ class DestructuringBindingsForFieldClass(Node): FieldShapeSymbol: "shape", FieldStrideSymbol: "stride" } - CLASS_NAME = "Field" + CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + + @property + def fields_accessed(self) -> Set['ResolvedFieldAccess']: + """Set of Field instances: fields which are accessed inside this kernel function""" + return set(o.field for o in self.atoms(ResolvedFieldAccess)) def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() @@ -676,10 +682,12 @@ class DestructuringBindingsForFieldClass(Node): @property def undefined_symbols(self) -> Set[sp.Symbol]: + field_map = {f.name: f for f in self.fields_accessed} undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \ + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + for f in corresponding_field_names} | \ (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 93ef5c1c798df3b0316fdd4232462e13a9415538..3f1a02c6833756985308282c49148fe0ecaa0153 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol): obj = super(TypedSymbol, cls).__xnew__(cls, name) try: obj._dtype = create_type(dtype) - except TypeError: + except (TypeError, ValueError): # on error keep the string obj._dtype = dtype return obj diff --git a/pystencils/field.py b/pystencils/field.py index c1720654867fe5b86ad321e1a1c6b78ec9b328f6..0c3af6dfa4edec15cb0f1cdff4822083eaacc017 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -306,6 +306,10 @@ class Field(AbstractField): def index_dimensions(self) -> int: return len(self.shape) - len(self._layout) + @property + def ndim(self) -> int: + return len(self.shape) + @property def layout(self): return self._layout