Commit 80334597 authored by Martin Bauer's avatar Martin Bauer
Browse files

list lbm is working

- different approach in pystencils: absolute indexing
parent a6e206d7
......@@ -87,6 +87,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
base_buffer_index += var * stride
resolve_buffer_accesses(code, base_buffer_index, read_only_fields)
resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info)
from enum import Enum
from itertools import chain
from typing import Tuple, Sequence, Optional, List
from typing import Tuple, Sequence, Optional, List, Set
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
......@@ -366,6 +366,9 @@ class Field:
"Got %d, expected %d" % (len(offset), self.spatial_dimensions))
return Field.Access(self, offset)
def absolute_access(self, offset, index):
return Field.Access(self, offset, index, is_absolute_access=True)
def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs)
......@@ -409,7 +412,7 @@ class Field:
obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
return obj
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None):
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False):
field_name =
offsets_and_index = chain(offsets, idx) if idx is not None else offsets
constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
......@@ -450,10 +453,16 @@ class Field:
obj._superscript = superscript
obj._index = idx
obj._indirect_addressing_fields = set()
for e in chain(obj._offsets, obj._index):
if isinstance(e, sp.Basic):
obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
obj._is_absolute_access = is_absolute_access
return obj
def __getnewargs__(self):
return self.field, self.offsets, self.index
return self.field, self.offsets, self.index, self.is_absolute_access
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
......@@ -553,6 +562,19 @@ class Field:
return Field.Access(self.field, self.offsets, idx_tuple)
def is_absolute_access(self) -> bool:
"""Indicates if a field access is relative to the loop counters (this is the default) or absolute"""
return self._is_absolute_access
def indirect_addressing_fields(self) -> Set['Field']:
"""Returns a set of fields that the access depends on.
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
return self._indirect_addressing_fields
def _hashable_content(self):
super_class_contents = list(super(Field.Access, self)._hashable_content())
t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets)
......@@ -3,7 +3,6 @@ from collections import defaultdict, OrderedDict, namedtuple
from copy import deepcopy
from types import MappingProxyType
import itertools
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
......@@ -392,13 +391,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access):
field_access = expr
field = field_access.field
if any(isinstance(off, Field.Access) for off in field_access.offsets):
if field_access.indirect_addressing_fields:
new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
for off in field_access.offsets)
field_access = Field.Access(field_access.field, new_offsets, field_access.index)
field = field_access.field
new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
if isinstance(ind, sp.Basic) else ind
for ind in field_access.index)
field_access = Field.Access(field_access.field, new_offsets,
new_indices, field_access.is_absolute_access)
if in field_to_base_pointer_info:
base_pointer_info = field_to_base_pointer_info[]
......@@ -415,7 +417,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if in field_to_fixed_coordinates:
coordinates[e] = field_to_fixed_coordinates[][e]
coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
if not field_access.is_absolute_access:
coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
coordinates[e] = 0
coordinates[e] *= field.dtype.item_size
if isinstance(field.dtype, StructType):
......@@ -719,9 +724,7 @@ class KernelConstraintsCheck:
if isinstance(rhs, Field.Access):
for e in itertools.chain(rhs.offsets, rhs.index):
if isinstance(e, sp.Basic):
self.fields_read.update(access.field for access in e.atoms(Field.Access))
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment