Skip to content
Snippets Groups Projects
Commit c512c755 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Enable usage of templated Field type

parent 8e63c9ff
Branches
Tags
No related merge requests found
from typing import Any, List, Optional, Sequence, Set, Union from typing import Any, List, Optional, Sequence, Set, Union
import jinja2
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.data_types import TypedSymbol, cast_func, create_type
...@@ -655,7 +656,12 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -655,7 +656,12 @@ class DestructuringBindingsForFieldClass(Node):
FieldShapeSymbol: "shape", FieldShapeSymbol: "shape",
FieldStrideSymbol: "stride" FieldStrideSymbol: "stride"
} }
CLASS_NAME = "Field" CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>")
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def __init__(self, body): def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__() super(DestructuringBindingsForFieldClass, self).__init__()
...@@ -676,10 +682,12 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -676,10 +682,12 @@ class DestructuringBindingsForFieldClass(Node):
@property @property
def undefined_symbols(self) -> Set[sp.Symbol]: def undefined_symbols(self) -> Set[sp.Symbol]:
field_map = {f.name: f for f in self.fields_accessed}
undefined_field_symbols = self.symbols_defined undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \ return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols) (self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None: def subs(self, subs_dict) -> None:
......
...@@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol): ...@@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol):
obj = super(TypedSymbol, cls).__xnew__(cls, name) obj = super(TypedSymbol, cls).__xnew__(cls, name)
try: try:
obj._dtype = create_type(dtype) obj._dtype = create_type(dtype)
except TypeError: except (TypeError, ValueError):
# on error keep the string # on error keep the string
obj._dtype = dtype obj._dtype = dtype
return obj return obj
......
...@@ -306,6 +306,10 @@ class Field(AbstractField): ...@@ -306,6 +306,10 @@ class Field(AbstractField):
def index_dimensions(self) -> int: def index_dimensions(self) -> int:
return len(self.shape) - len(self._layout) return len(self.shape) - len(self._layout)
@property
def ndim(self) -> int:
return len(self.shape)
@property @property
def layout(self): def layout(self):
return self._layout return self._layout
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment