From 80334597af849cb21856b0e3cf56a82b26650d55 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Sun, 27 May 2018 18:34:17 +0200 Subject: [PATCH] list lbm is working - different approach in pystencils: absolute indexing --- cpu/kernelcreation.py | 1 + field.py | 28 +++++++++++++++++++++++++--- transformations.py | 21 ++++++++++++--------- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index b9abd61a8..c34a83ec4 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -87,6 +87,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke base_buffer_index += var * stride resolve_buffer_accesses(code, base_buffer_index, read_only_fields) + resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info) substitute_array_accesses_with_constants(code) move_constants_before_loop(code) diff --git a/field.py b/field.py index 38bdcdc16..58f13cb94 100644 --- a/field.py +++ b/field.py @@ -1,6 +1,6 @@ from enum import Enum from itertools import chain -from typing import Tuple, Sequence, Optional, List +from typing import Tuple, Sequence, Optional, List, Set import numpy as np import sympy as sp from sympy.core.cache import cacheit @@ -366,6 +366,9 @@ class Field: "Got %d, expected %d" % (len(offset), self.spatial_dimensions)) return Field.Access(self, offset) + def absolute_access(self, offset, index): + return Field.Access(self, offset, index, is_absolute_access=True) + def __call__(self, *args, **kwargs): center = tuple([0] * self.spatial_dimensions) return Field.Access(self, center)(*args, **kwargs) @@ -409,7 +412,7 @@ class Field: obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs) return obj - def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None): + def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False): field_name = field.name offsets_and_index = chain(offsets, idx) if idx is not None else offsets constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index]) @@ -450,10 +453,16 @@ class Field: obj._superscript = superscript obj._index = idx + obj._indirect_addressing_fields = set() + for e in chain(obj._offsets, obj._index): + if isinstance(e, sp.Basic): + obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access)) + + obj._is_absolute_access = is_absolute_access return obj def __getnewargs__(self): - return self.field, self.offsets, self.index + return self.field, self.offsets, self.index, self.is_absolute_access # noinspection SpellCheckingInspection __xnew__ = staticmethod(__new_stage2__) @@ -553,6 +562,19 @@ class Field: """ return Field.Access(self.field, self.offsets, idx_tuple) + @property + def is_absolute_access(self) -> bool: + """Indicates if a field access is relative to the loop counters (this is the default) or absolute""" + return self._is_absolute_access + + @property + def indirect_addressing_fields(self) -> Set['Field']: + """Returns a set of fields that the access depends on. + + e.g. f[index_field[1, 0]], the outer access to f depends on index_field + """ + return self._indirect_addressing_fields + def _hashable_content(self): super_class_contents = list(super(Field.Access, self)._hashable_content()) t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets) diff --git a/transformations.py b/transformations.py index 32df1ac4d..639a48184 100644 --- a/transformations.py +++ b/transformations.py @@ -3,7 +3,6 @@ from collections import defaultdict, OrderedDict, namedtuple from copy import deepcopy from types import MappingProxyType -import itertools import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase @@ -392,13 +391,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), def visit_sympy_expr(expr, enclosing_block, sympy_assignment): if isinstance(expr, Field.Access): field_access = expr + field = field_access.field - if any(isinstance(off, Field.Access) for off in field_access.offsets): + if field_access.indirect_addressing_fields: new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment) for off in field_access.offsets) - field_access = Field.Access(field_access.field, new_offsets, field_access.index) - - field = field_access.field + new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment) + if isinstance(ind, sp.Basic) else ind + for ind in field_access.index) + field_access = Field.Access(field_access.field, new_offsets, + new_indices, field_access.is_absolute_access) if field.name in field_to_base_pointer_info: base_pointer_info = field_to_base_pointer_info[field.name] @@ -415,7 +417,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), if field.name in field_to_fixed_coordinates: coordinates[e] = field_to_fixed_coordinates[field.name][e] else: - coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e) + if not field_access.is_absolute_access: + coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e) + else: + coordinates[e] = 0 coordinates[e] *= field.dtype.item_size else: if isinstance(field.dtype, StructType): @@ -719,9 +724,7 @@ class KernelConstraintsCheck: self._update_accesses_rhs(rhs) if isinstance(rhs, Field.Access): self.fields_read.add(rhs.field) - for e in itertools.chain(rhs.offsets, rhs.index): - if isinstance(e, sp.Basic): - self.fields_read.update(access.field for access in e.atoms(Field.Access)) + self.fields_read.update(rhs.indirect_addressing_fields) return rhs elif isinstance(rhs, TypedSymbol): return rhs -- GitLab