From c512c755e782836c409d6f6aeb94ae853fb06041 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 5 Jul 2019 19:59:02 +0200
Subject: [PATCH] Enable usage of templated Field type

---
 pystencils/astnodes.py   | 12 ++++++++++--
 pystencils/data_types.py |  2 +-
 pystencils/field.py      |  4 ++++
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index a924c85..2d3174a 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -1,5 +1,6 @@
 from typing import Any, List, Optional, Sequence, Set, Union
 
+import jinja2
 import sympy as sp
 
 from pystencils.data_types import TypedSymbol, cast_func, create_type
@@ -655,7 +656,12 @@ class DestructuringBindingsForFieldClass(Node):
         FieldShapeSymbol: "shape",
         FieldStrideSymbol: "stride"
     }
-    CLASS_NAME = "Field"
+    CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>")
+
+    @property
+    def fields_accessed(self) -> Set['ResolvedFieldAccess']:
+        """Set of Field instances: fields which are accessed inside this kernel function"""
+        return set(o.field for o in self.atoms(ResolvedFieldAccess))
 
     def __init__(self, body):
         super(DestructuringBindingsForFieldClass, self).__init__()
@@ -676,10 +682,12 @@ class DestructuringBindingsForFieldClass(Node):
 
     @property
     def undefined_symbols(self) -> Set[sp.Symbol]:
+        field_map = {f.name: f for f in self.fields_accessed}
         undefined_field_symbols = self.symbols_defined
         corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
         corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
-        return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \
+        return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
+                for f in corresponding_field_names} | \
             (self.body.undefined_symbols - undefined_field_symbols)
 
     def subs(self, subs_dict) -> None:
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 93ef5c1..3f1a02c 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol):
         obj = super(TypedSymbol, cls).__xnew__(cls, name)
         try:
             obj._dtype = create_type(dtype)
-        except TypeError:
+        except (TypeError, ValueError):
             # on error keep the string
             obj._dtype = dtype
         return obj
diff --git a/pystencils/field.py b/pystencils/field.py
index c172065..0c3af6d 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -306,6 +306,10 @@ class Field(AbstractField):
     def index_dimensions(self) -> int:
         return len(self.shape) - len(self._layout)
 
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
+
     @property
     def layout(self):
         return self._layout
-- 
GitLab