From 4f43b51a0773c4d15602578e14ef5861cbb79ba3 Mon Sep 17 00:00:00 2001
From: Nils Kohl <nils.kohl@fau.de>
Date: Wed, 27 Mar 2019 11:34:59 +0100
Subject: [PATCH] Improved support for arbitrary field classes.

- introduced AbstractField and AbstractAccess

Fixes #28
---
 pystencils/field.py           | 12 +++++++++---
 pystencils/transformations.py | 25 +++++++++++++------------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/pystencils/field.py b/pystencils/field.py
index bd6b033..17570f6 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -13,7 +13,7 @@ from pystencils.sympyextensions import is_integer_sequence
 import pickle
 import hashlib
 
-__all__ = ['Field', 'fields', 'FieldType']
+__all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
 
 
 def fields(description=None, index_dimensions=0, layout=None, **kwargs):
@@ -116,7 +116,13 @@ class FieldType(Enum):
         return field.field_type == FieldType.CUSTOM
 
 
-class Field:
+class AbstractField:
+
+    class AbstractAccess:
+        pass
+
+
+class Field(AbstractField):
     """
     With fields one can formulate stencil-like update rules on structured grids.
     This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
@@ -394,7 +400,7 @@ class Field:
         return self.hashable_contents() == other.hashable_contents()
 
     # noinspection PyAttributeOutsideInit,PyUnresolvedReferences
-    class Access(sp.Symbol):
+    class Access(sp.Symbol, AbstractField.AbstractAccess):
         """Class representing a relative access into a `Field`.
 
         This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 9924b0a..bc325dc 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -9,7 +9,7 @@ from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.assignment import Assignment
-from pystencils.field import Field, FieldType
+from pystencils.field import AbstractField, FieldType, Field
 from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
     cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
 from pystencils.kernelparameters import FieldPointerSymbol
@@ -160,7 +160,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
         :class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
     """
     # find correct ordering by inspecting participating FieldAccesses
-    field_accesses = body.atoms(Field.Access)
+    field_accesses = body.atoms(AbstractField.AbstractAccess)
     field_accesses = {e for e in field_accesses if not e.is_absolute_access}
 
     # exclude accesses to buffers from field_list, because buffers are treated separately
@@ -353,7 +353,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         loop_iterations = [(l.stop - l.start) / l.step for l in loops]
         loop_counters = [l.loop_counter_symbol for l in loops]
 
-    field_accesses = ast_node.atoms(Field.Access)
+    field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
     buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
     loop_counters = [v * len(buffer_accesses) for v in loop_counters]
 
@@ -369,7 +369,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
 def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
 
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
-        if isinstance(expr, Field.Access):
+        if isinstance(expr, AbstractField.AbstractAccess):
             field_access = expr
 
             # Do not apply transformation if field is not a buffer
@@ -433,7 +433,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
     field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
 
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
-        if isinstance(expr, Field.Access):
+        if isinstance(expr, AbstractField.AbstractAccess):
             field_access = expr
             field = field_access.field
 
@@ -654,12 +654,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
 
             if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                 for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
-                    if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array:
+                    if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
+                            new_symbol not in symbols_with_temporary_array:
                         symbols_to_process.append(new_symbol)
             symbols_resolved.add(s)
 
         for symbol in symbol_group:
-            if type(symbol) is not Field.Access:
+            if not isinstance(symbol, AbstractField.AbstractAccess):
                 assert type(symbol) is TypedSymbol
                 new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                 symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
@@ -668,7 +669,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
         for assignment in inner_loop.body.args:
             if assignment.lhs in symbols_resolved:
                 new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
-                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group:
+                if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
                     assert type(assignment.lhs) is TypedSymbol
                     new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
                     new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
@@ -792,7 +793,7 @@ class KernelConstraintsCheck:
 
     def process_expression(self, rhs, type_constants=True):
         self._update_accesses_rhs(rhs)
-        if isinstance(rhs, Field.Access):
+        if isinstance(rhs, AbstractField.AbstractAccess):
             self.fields_read.add(rhs.field)
             self.fields_read.update(rhs.indirect_addressing_fields)
             return rhs
@@ -822,13 +823,13 @@ class KernelConstraintsCheck:
     def _process_lhs(self, lhs):
         assert isinstance(lhs, sp.Symbol)
         self._update_accesses_lhs(lhs)
-        if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol):
+        if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol):
             return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
         else:
             return lhs
 
     def _update_accesses_lhs(self, lhs):
-        if isinstance(lhs, Field.Access):
+        if isinstance(lhs, AbstractField.AbstractAccess):
             fai = self.FieldAndIndex(lhs.field, lhs.index)
             self._field_writes[fai].add(lhs.offsets)
             if len(self._field_writes[fai]) > 1:
@@ -841,7 +842,7 @@ class KernelConstraintsCheck:
             self.scopes.define_symbol(lhs)
 
     def _update_accesses_rhs(self, rhs):
-        if isinstance(rhs, Field.Access) and self.check_independence_condition:
+        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
             writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
             for write_offset in writes:
                 assert len(writes) == 1
-- 
GitLab