From 80334597af849cb21856b0e3cf56a82b26650d55 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Sun, 27 May 2018 18:34:17 +0200
Subject: [PATCH] list lbm is working

- different approach in pystencils: absolute indexing
---
 cpu/kernelcreation.py |  1 +
 field.py              | 28 +++++++++++++++++++++++++---
 transformations.py    | 21 ++++++++++++---------
 3 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index b9abd61a8..c34a83ec4 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -87,6 +87,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         base_buffer_index += var * stride
 
     resolve_buffer_accesses(code, base_buffer_index, read_only_fields)
+
     resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info)
     substitute_array_accesses_with_constants(code)
     move_constants_before_loop(code)
diff --git a/field.py b/field.py
index 38bdcdc16..58f13cb94 100644
--- a/field.py
+++ b/field.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from itertools import chain
-from typing import Tuple, Sequence, Optional, List
+from typing import Tuple, Sequence, Optional, List, Set
 import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
@@ -366,6 +366,9 @@ class Field:
                              "Got %d, expected %d" % (len(offset), self.spatial_dimensions))
         return Field.Access(self, offset)
 
+    def absolute_access(self, offset, index):
+        return Field.Access(self, offset, index, is_absolute_access=True)
+
     def __call__(self, *args, **kwargs):
         center = tuple([0] * self.spatial_dimensions)
         return Field.Access(self, center)(*args, **kwargs)
@@ -409,7 +412,7 @@ class Field:
             obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
             return obj
 
-        def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None):
+        def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False):
             field_name = field.name
             offsets_and_index = chain(offsets, idx) if idx is not None else offsets
             constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
@@ -450,10 +453,16 @@ class Field:
             obj._superscript = superscript
             obj._index = idx
 
+            obj._indirect_addressing_fields = set()
+            for e in chain(obj._offsets, obj._index):
+                if isinstance(e, sp.Basic):
+                    obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
+
+            obj._is_absolute_access = is_absolute_access
             return obj
 
         def __getnewargs__(self):
-            return self.field, self.offsets, self.index
+            return self.field, self.offsets, self.index, self.is_absolute_access
 
         # noinspection SpellCheckingInspection
         __xnew__ = staticmethod(__new_stage2__)
@@ -553,6 +562,19 @@ class Field:
             """
             return Field.Access(self.field, self.offsets, idx_tuple)
 
+        @property
+        def is_absolute_access(self) -> bool:
+            """Indicates if a field access is relative to the loop counters (this is the default) or absolute"""
+            return self._is_absolute_access
+
+        @property
+        def indirect_addressing_fields(self) -> Set['Field']:
+            """Returns a set of fields that the access depends on.
+
+             e.g. f[index_field[1, 0]], the outer access to f depends on index_field
+             """
+            return self._indirect_addressing_fields
+
         def _hashable_content(self):
             super_class_contents = list(super(Field.Access, self)._hashable_content())
             t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets)
diff --git a/transformations.py b/transformations.py
index 32df1ac4d..639a48184 100644
--- a/transformations.py
+++ b/transformations.py
@@ -3,7 +3,6 @@ from collections import defaultdict, OrderedDict, namedtuple
 from copy import deepcopy
 from types import MappingProxyType
 
-import itertools
 import sympy as sp
 from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
@@ -392,13 +391,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
         if isinstance(expr, Field.Access):
             field_access = expr
+            field = field_access.field
 
-            if any(isinstance(off, Field.Access) for off in field_access.offsets):
+            if field_access.indirect_addressing_fields:
                 new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                     for off in field_access.offsets)
-                field_access = Field.Access(field_access.field, new_offsets, field_access.index)
-
-            field = field_access.field
+                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
+                                    if isinstance(ind, sp.Basic) else ind
+                                    for ind in field_access.index)
+                field_access = Field.Access(field_access.field, new_offsets,
+                                            new_indices, field_access.is_absolute_access)
 
             if field.name in field_to_base_pointer_info:
                 base_pointer_info = field_to_base_pointer_info[field.name]
@@ -415,7 +417,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                         if field.name in field_to_fixed_coordinates:
                             coordinates[e] = field_to_fixed_coordinates[field.name][e]
                         else:
-                            coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
+                            if not field_access.is_absolute_access:
+                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
+                            else:
+                                coordinates[e] = 0
                         coordinates[e] *= field.dtype.item_size
                     else:
                         if isinstance(field.dtype, StructType):
@@ -719,9 +724,7 @@ class KernelConstraintsCheck:
         self._update_accesses_rhs(rhs)
         if isinstance(rhs, Field.Access):
             self.fields_read.add(rhs.field)
-            for e in itertools.chain(rhs.offsets, rhs.index):
-                if isinstance(e, sp.Basic):
-                    self.fields_read.update(access.field for access in e.atoms(Field.Access))
+            self.fields_read.update(rhs.indirect_addressing_fields)
             return rhs
         elif isinstance(rhs, TypedSymbol):
             return rhs
-- 
GitLab