From 4f43b51a0773c4d15602578e14ef5861cbb79ba3 Mon Sep 17 00:00:00 2001 From: Nils Kohl <nils.kohl@fau.de> Date: Wed, 27 Mar 2019 11:34:59 +0100 Subject: [PATCH] Improved support for arbitrary field classes. - introduced AbstractField and AbstractAccess Fixes #28 --- pystencils/field.py | 12 +++++++++--- pystencils/transformations.py | 25 +++++++++++++------------ 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/pystencils/field.py b/pystencils/field.py index bd6b033..17570f6 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -13,7 +13,7 @@ from pystencils.sympyextensions import is_integer_sequence import pickle import hashlib -__all__ = ['Field', 'fields', 'FieldType'] +__all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] def fields(description=None, index_dimensions=0, layout=None, **kwargs): @@ -116,7 +116,13 @@ class FieldType(Enum): return field.field_type == FieldType.CUSTOM -class Field: +class AbstractField: + + class AbstractAccess: + pass + + +class Field(AbstractField): """ With fields one can formulate stencil-like update rules on structured grids. This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array. @@ -394,7 +400,7 @@ class Field: return self.hashable_contents() == other.hashable_contents() # noinspection PyAttributeOutsideInit,PyUnresolvedReferences - class Access(sp.Symbol): + class Access(sp.Symbol, AbstractField.AbstractAccess): """Class representing a relative access into a `Field`. This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 9924b0a..bc325dc 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -9,7 +9,7 @@ from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.assignment import Assignment -from pystencils.field import Field, FieldType +from pystencils.field import AbstractField, FieldType, Field from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \ cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type from pystencils.kernelparameters import FieldPointerSymbol @@ -160,7 +160,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer :class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts """ # find correct ordering by inspecting participating FieldAccesses - field_accesses = body.atoms(Field.Access) + field_accesses = body.atoms(AbstractField.AbstractAccess) field_accesses = {e for e in field_accesses if not e.is_absolute_access} # exclude accesses to buffers from field_list, because buffers are treated separately @@ -353,7 +353,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_counters = [l.loop_counter_symbol for l in loops] - field_accesses = ast_node.atoms(Field.Access) + field_accesses = ast_node.atoms(AbstractField.AbstractAccess) buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} loop_counters = [v * len(buffer_accesses) for v in loop_counters] @@ -369,7 +369,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): def visit_sympy_expr(expr, enclosing_block, sympy_assignment): - if isinstance(expr, Field.Access): + if isinstance(expr, AbstractField.AbstractAccess): field_access = expr # Do not apply transformation if field is not a buffer @@ -433,7 +433,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) def visit_sympy_expr(expr, enclosing_block, sympy_assignment): - if isinstance(expr, Field.Access): + if isinstance(expr, AbstractField.AbstractAccess): field_access = expr field = field_access.field @@ -654,12 +654,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): if s in assignment_map: # if there is no assignment inside the loop body it is independent already for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol): - if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array: + if not isinstance(new_symbol, AbstractField.AbstractAccess) and \ + new_symbol not in symbols_with_temporary_array: symbols_to_process.append(new_symbol) symbols_resolved.add(s) for symbol in symbol_group: - if type(symbol) is not Field.Access: + if not isinstance(symbol, AbstractField.AbstractAccess): assert type(symbol) is TypedSymbol new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] @@ -668,7 +669,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): for assignment in inner_loop.body.args: if assignment.lhs in symbols_resolved: new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items()) - if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group: + if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: assert type(assignment.lhs) is TypedSymbol new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] @@ -792,7 +793,7 @@ class KernelConstraintsCheck: def process_expression(self, rhs, type_constants=True): self._update_accesses_rhs(rhs) - if isinstance(rhs, Field.Access): + if isinstance(rhs, AbstractField.AbstractAccess): self.fields_read.add(rhs.field) self.fields_read.update(rhs.indirect_addressing_fields) return rhs @@ -822,13 +823,13 @@ class KernelConstraintsCheck: def _process_lhs(self, lhs): assert isinstance(lhs, sp.Symbol) self._update_accesses_lhs(lhs) - if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol): + if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol): return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) else: return lhs def _update_accesses_lhs(self, lhs): - if isinstance(lhs, Field.Access): + if isinstance(lhs, AbstractField.AbstractAccess): fai = self.FieldAndIndex(lhs.field, lhs.index) self._field_writes[fai].add(lhs.offsets) if len(self._field_writes[fai]) > 1: @@ -841,7 +842,7 @@ class KernelConstraintsCheck: self.scopes.define_symbol(lhs) def _update_accesses_rhs(self, rhs): - if isinstance(rhs, Field.Access) and self.check_independence_condition: + if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition: writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)] for write_offset in writes: assert len(writes) == 1 -- GitLab