Commit 80334597 authored by Martin Bauer's avatar Martin Bauer
Browse files

list lbm is working

- different approach in pystencils: absolute indexing
parent a6e206d7
......@@ -87,6 +87,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
base_buffer_index += var * stride
resolve_buffer_accesses(code, base_buffer_index, read_only_fields)
resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info)
substitute_array_accesses_with_constants(code)
move_constants_before_loop(code)
......
from enum import Enum
from itertools import chain
from typing import Tuple, Sequence, Optional, List
from typing import Tuple, Sequence, Optional, List, Set
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
......@@ -366,6 +366,9 @@ class Field:
"Got %d, expected %d" % (len(offset), self.spatial_dimensions))
return Field.Access(self, offset)
def absolute_access(self, offset, index):
return Field.Access(self, offset, index, is_absolute_access=True)
def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs)
......@@ -409,7 +412,7 @@ class Field:
obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
return obj
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None):
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False):
field_name = field.name
offsets_and_index = chain(offsets, idx) if idx is not None else offsets
constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
......@@ -450,10 +453,16 @@ class Field:
obj._superscript = superscript
obj._index = idx
obj._indirect_addressing_fields = set()
for e in chain(obj._offsets, obj._index):
if isinstance(e, sp.Basic):
obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
obj._is_absolute_access = is_absolute_access
return obj
def __getnewargs__(self):
return self.field, self.offsets, self.index
return self.field, self.offsets, self.index, self.is_absolute_access
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
......@@ -553,6 +562,19 @@ class Field:
"""
return Field.Access(self.field, self.offsets, idx_tuple)
@property
def is_absolute_access(self) -> bool:
"""Indicates if a field access is relative to the loop counters (this is the default) or absolute"""
return self._is_absolute_access
@property
def indirect_addressing_fields(self) -> Set['Field']:
"""Returns a set of fields that the access depends on.
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
"""
return self._indirect_addressing_fields
def _hashable_content(self):
super_class_contents = list(super(Field.Access, self)._hashable_content())
t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets)
......
......@@ -3,7 +3,6 @@ from collections import defaultdict, OrderedDict, namedtuple
from copy import deepcopy
from types import MappingProxyType
import itertools
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
......@@ -392,13 +391,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access):
field_access = expr
field = field_access.field
if any(isinstance(off, Field.Access) for off in field_access.offsets):
if field_access.indirect_addressing_fields:
new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
for off in field_access.offsets)
field_access = Field.Access(field_access.field, new_offsets, field_access.index)
field = field_access.field
new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
if isinstance(ind, sp.Basic) else ind
for ind in field_access.index)
field_access = Field.Access(field_access.field, new_offsets,
new_indices, field_access.is_absolute_access)
if field.name in field_to_base_pointer_info:
base_pointer_info = field_to_base_pointer_info[field.name]
......@@ -415,7 +417,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if field.name in field_to_fixed_coordinates:
coordinates[e] = field_to_fixed_coordinates[field.name][e]
else:
if not field_access.is_absolute_access:
coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
else:
coordinates[e] = 0
coordinates[e] *= field.dtype.item_size
else:
if isinstance(field.dtype, StructType):
......@@ -719,9 +724,7 @@ class KernelConstraintsCheck:
self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field)
for e in itertools.chain(rhs.offsets, rhs.index):
if isinstance(e, sp.Basic):
self.fields_read.update(access.field for access in e.atoms(Field.Access))
self.fields_read.update(rhs.indirect_addressing_fields)
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment