Commit 4f43b51a authored by Nils Kohl's avatar Nils Kohl 🌝 Committed by Martin Bauer
Browse files

Improved support for arbitrary field classes.

- introduced AbstractField and AbstractAccess

Fixes #28
parent eec4dc4b
...@@ -13,7 +13,7 @@ from pystencils.sympyextensions import is_integer_sequence ...@@ -13,7 +13,7 @@ from pystencils.sympyextensions import is_integer_sequence
import pickle import pickle
import hashlib import hashlib
__all__ = ['Field', 'fields', 'FieldType'] __all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
def fields(description=None, index_dimensions=0, layout=None, **kwargs): def fields(description=None, index_dimensions=0, layout=None, **kwargs):
...@@ -116,7 +116,13 @@ class FieldType(Enum): ...@@ -116,7 +116,13 @@ class FieldType(Enum):
return field.field_type == FieldType.CUSTOM return field.field_type == FieldType.CUSTOM
class Field: class AbstractField:
class AbstractAccess:
pass
class Field(AbstractField):
""" """
With fields one can formulate stencil-like update rules on structured grids. With fields one can formulate stencil-like update rules on structured grids.
This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array. This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
...@@ -394,7 +400,7 @@ class Field: ...@@ -394,7 +400,7 @@ class Field:
return self.hashable_contents() == other.hashable_contents() return self.hashable_contents() == other.hashable_contents()
# noinspection PyAttributeOutsideInit,PyUnresolvedReferences # noinspection PyAttributeOutsideInit,PyUnresolvedReferences
class Access(sp.Symbol): class Access(sp.Symbol, AbstractField.AbstractAccess):
"""Class representing a relative access into a `Field`. """Class representing a relative access into a `Field`.
This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up
......
...@@ -9,7 +9,7 @@ from sympy.logic.boolalg import Boolean ...@@ -9,7 +9,7 @@ from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType from pystencils.field import AbstractField, FieldType, Field
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \ from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
...@@ -160,7 +160,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer ...@@ -160,7 +160,7 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
:class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts :class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
""" """
# find correct ordering by inspecting participating FieldAccesses # find correct ordering by inspecting participating FieldAccesses
field_accesses = body.atoms(Field.Access) field_accesses = body.atoms(AbstractField.AbstractAccess)
field_accesses = {e for e in field_accesses if not e.is_absolute_access} field_accesses = {e for e in field_accesses if not e.is_absolute_access}
# exclude accesses to buffers from field_list, because buffers are treated separately # exclude accesses to buffers from field_list, because buffers are treated separately
...@@ -353,7 +353,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -353,7 +353,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_iterations = [(l.stop - l.start) / l.step for l in loops]
loop_counters = [l.loop_counter_symbol for l in loops] loop_counters = [l.loop_counter_symbol for l in loops]
field_accesses = ast_node.atoms(Field.Access) field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters] loop_counters = [v * len(buffer_accesses) for v in loop_counters]
...@@ -369,7 +369,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -369,7 +369,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access): if isinstance(expr, AbstractField.AbstractAccess):
field_access = expr field_access = expr
# Do not apply transformation if field is not a buffer # Do not apply transformation if field is not a buffer
...@@ -433,7 +433,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -433,7 +433,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access): if isinstance(expr, AbstractField.AbstractAccess):
field_access = expr field_access = expr
field = field_access.field field = field_access.field
...@@ -654,12 +654,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -654,12 +654,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if s in assignment_map: # if there is no assignment inside the loop body it is independent already if s in assignment_map: # if there is no assignment inside the loop body it is independent already
for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol): for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array: if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
new_symbol not in symbols_with_temporary_array:
symbols_to_process.append(new_symbol) symbols_to_process.append(new_symbol)
symbols_resolved.add(s) symbols_resolved.add(s)
for symbol in symbol_group: for symbol in symbol_group:
if type(symbol) is not Field.Access: if not isinstance(symbol, AbstractField.AbstractAccess):
assert type(symbol) is TypedSymbol assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
...@@ -668,7 +669,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -668,7 +669,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
for assignment in inner_loop.body.args: for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved: if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items()) new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group: if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
...@@ -792,7 +793,7 @@ class KernelConstraintsCheck: ...@@ -792,7 +793,7 @@ class KernelConstraintsCheck:
def process_expression(self, rhs, type_constants=True): def process_expression(self, rhs, type_constants=True):
self._update_accesses_rhs(rhs) self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access): if isinstance(rhs, AbstractField.AbstractAccess):
self.fields_read.add(rhs.field) self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields) self.fields_read.update(rhs.indirect_addressing_fields)
return rhs return rhs
...@@ -822,13 +823,13 @@ class KernelConstraintsCheck: ...@@ -822,13 +823,13 @@ class KernelConstraintsCheck:
def _process_lhs(self, lhs): def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol) assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs) self._update_accesses_lhs(lhs)
if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol): if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else: else:
return lhs return lhs
def _update_accesses_lhs(self, lhs): def _update_accesses_lhs(self, lhs):
if isinstance(lhs, Field.Access): if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index) fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets) self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1: if len(self._field_writes[fai]) > 1:
...@@ -841,7 +842,7 @@ class KernelConstraintsCheck: ...@@ -841,7 +842,7 @@ class KernelConstraintsCheck:
self.scopes.define_symbol(lhs) self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs): def _update_accesses_rhs(self, rhs):
if isinstance(rhs, Field.Access) and self.check_independence_condition: if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)] writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
for write_offset in writes: for write_offset in writes:
assert len(writes) == 1 assert len(writes) == 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment