From c512c755e782836c409d6f6aeb94ae853fb06041 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 5 Jul 2019 19:59:02 +0200 Subject: [PATCH] Enable usage of templated Field type --- pystencils/astnodes.py | 12 ++++++++++-- pystencils/data_types.py | 2 +- pystencils/field.py | 4 ++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index a924c85..2d3174a 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -1,5 +1,6 @@ from typing import Any, List, Optional, Sequence, Set, Union +import jinja2 import sympy as sp from pystencils.data_types import TypedSymbol, cast_func, create_type @@ -655,7 +656,12 @@ class DestructuringBindingsForFieldClass(Node): FieldShapeSymbol: "shape", FieldStrideSymbol: "stride" } - CLASS_NAME = "Field" + CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + + @property + def fields_accessed(self) -> Set['ResolvedFieldAccess']: + """Set of Field instances: fields which are accessed inside this kernel function""" + return set(o.field for o in self.atoms(ResolvedFieldAccess)) def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() @@ -676,10 +682,12 @@ class DestructuringBindingsForFieldClass(Node): @property def undefined_symbols(self) -> Set[sp.Symbol]: + field_map = {f.name: f for f in self.fields_accessed} undefined_field_symbols = self.symbols_defined corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} - return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \ + return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') + for f in corresponding_field_names} | \ (self.body.undefined_symbols - undefined_field_symbols) def subs(self, subs_dict) -> None: diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 93ef5c1..3f1a02c 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol): obj = super(TypedSymbol, cls).__xnew__(cls, name) try: obj._dtype = create_type(dtype) - except TypeError: + except (TypeError, ValueError): # on error keep the string obj._dtype = dtype return obj diff --git a/pystencils/field.py b/pystencils/field.py index c172065..0c3af6d 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -306,6 +306,10 @@ class Field(AbstractField): def index_dimensions(self) -> int: return len(self.shape) - len(self._layout) + @property + def ndim(self) -> int: + return len(self.shape) + @property def layout(self): return self._layout -- GitLab