Commit 80334597 authored by Martin Bauer's avatar Martin Bauer
Browse files

list lbm is working

- different approach in pystencils: absolute indexing
parent a6e206d7
...@@ -87,6 +87,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke ...@@ -87,6 +87,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
base_buffer_index += var * stride base_buffer_index += var * stride
resolve_buffer_accesses(code, base_buffer_index, read_only_fields) resolve_buffer_accesses(code, base_buffer_index, read_only_fields)
resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info) resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info)
substitute_array_accesses_with_constants(code) substitute_array_accesses_with_constants(code)
move_constants_before_loop(code) move_constants_before_loop(code)
......
from enum import Enum from enum import Enum
from itertools import chain from itertools import chain
from typing import Tuple, Sequence, Optional, List from typing import Tuple, Sequence, Optional, List, Set
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
...@@ -366,6 +366,9 @@ class Field: ...@@ -366,6 +366,9 @@ class Field:
"Got %d, expected %d" % (len(offset), self.spatial_dimensions)) "Got %d, expected %d" % (len(offset), self.spatial_dimensions))
return Field.Access(self, offset) return Field.Access(self, offset)
def absolute_access(self, offset, index):
return Field.Access(self, offset, index, is_absolute_access=True)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions) center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs) return Field.Access(self, center)(*args, **kwargs)
...@@ -409,7 +412,7 @@ class Field: ...@@ -409,7 +412,7 @@ class Field:
obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs) obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
return obj return obj
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None): def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False):
field_name = field.name field_name = field.name
offsets_and_index = chain(offsets, idx) if idx is not None else offsets offsets_and_index = chain(offsets, idx) if idx is not None else offsets
constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index]) constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
...@@ -450,10 +453,16 @@ class Field: ...@@ -450,10 +453,16 @@ class Field:
obj._superscript = superscript obj._superscript = superscript
obj._index = idx obj._index = idx
obj._indirect_addressing_fields = set()
for e in chain(obj._offsets, obj._index):
if isinstance(e, sp.Basic):
obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
obj._is_absolute_access = is_absolute_access
return obj return obj
def __getnewargs__(self): def __getnewargs__(self):
return self.field, self.offsets, self.index return self.field, self.offsets, self.index, self.is_absolute_access
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
...@@ -553,6 +562,19 @@ class Field: ...@@ -553,6 +562,19 @@ class Field:
""" """
return Field.Access(self.field, self.offsets, idx_tuple) return Field.Access(self.field, self.offsets, idx_tuple)
@property
def is_absolute_access(self) -> bool:
"""Indicates if a field access is relative to the loop counters (this is the default) or absolute"""
return self._is_absolute_access
@property
def indirect_addressing_fields(self) -> Set['Field']:
"""Returns a set of fields that the access depends on.
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
"""
return self._indirect_addressing_fields
def _hashable_content(self): def _hashable_content(self):
super_class_contents = list(super(Field.Access, self)._hashable_content()) super_class_contents = list(super(Field.Access, self)._hashable_content())
t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets) t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets)
......
...@@ -3,7 +3,6 @@ from collections import defaultdict, OrderedDict, namedtuple ...@@ -3,7 +3,6 @@ from collections import defaultdict, OrderedDict, namedtuple
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
import itertools
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase from sympy.tensor import IndexedBase
...@@ -392,13 +391,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -392,13 +391,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, Field.Access): if isinstance(expr, Field.Access):
field_access = expr field_access = expr
field = field_access.field
if any(isinstance(off, Field.Access) for off in field_access.offsets): if field_access.indirect_addressing_fields:
new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment) new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
for off in field_access.offsets) for off in field_access.offsets)
field_access = Field.Access(field_access.field, new_offsets, field_access.index) new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
if isinstance(ind, sp.Basic) else ind
field = field_access.field for ind in field_access.index)
field_access = Field.Access(field_access.field, new_offsets,
new_indices, field_access.is_absolute_access)
if field.name in field_to_base_pointer_info: if field.name in field_to_base_pointer_info:
base_pointer_info = field_to_base_pointer_info[field.name] base_pointer_info = field_to_base_pointer_info[field.name]
...@@ -415,7 +417,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -415,7 +417,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if field.name in field_to_fixed_coordinates: if field.name in field_to_fixed_coordinates:
coordinates[e] = field_to_fixed_coordinates[field.name][e] coordinates[e] = field_to_fixed_coordinates[field.name][e]
else: else:
coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e) if not field_access.is_absolute_access:
coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
else:
coordinates[e] = 0
coordinates[e] *= field.dtype.item_size coordinates[e] *= field.dtype.item_size
else: else:
if isinstance(field.dtype, StructType): if isinstance(field.dtype, StructType):
...@@ -719,9 +724,7 @@ class KernelConstraintsCheck: ...@@ -719,9 +724,7 @@ class KernelConstraintsCheck:
self._update_accesses_rhs(rhs) self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access): if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field) self.fields_read.add(rhs.field)
for e in itertools.chain(rhs.offsets, rhs.index): self.fields_read.update(rhs.indirect_addressing_fields)
if isinstance(e, sp.Basic):
self.fields_read.update(access.field for access in e.atoms(Field.Access))
return rhs return rhs
elif isinstance(rhs, TypedSymbol): elif isinstance(rhs, TypedSymbol):
return rhs return rhs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment