diff --git a/transformations.py b/transformations.py index 8b288bb4c9f3a0f67b47fd27362cf55a8f423b41..32df1ac4dabc43dfe0d5e68971b11cb8bb96deb8 100644 --- a/transformations.py +++ b/transformations.py @@ -2,6 +2,8 @@ import warnings from collections import defaultdict, OrderedDict, namedtuple from copy import deepcopy from types import MappingProxyType + +import itertools import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase @@ -390,6 +392,12 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), def visit_sympy_expr(expr, enclosing_block, sympy_assignment): if isinstance(expr, Field.Access): field_access = expr + + if any(isinstance(off, Field.Access) for off in field_access.offsets): + new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment) + for off in field_access.offsets) + field_access = Field.Access(field_access.field, new_offsets, field_access.index) + field = field_access.field if field.name in field_to_base_pointer_info: @@ -711,6 +719,9 @@ class KernelConstraintsCheck: self._update_accesses_rhs(rhs) if isinstance(rhs, Field.Access): self.fields_read.add(rhs.field) + for e in itertools.chain(rhs.offsets, rhs.index): + if isinstance(e, sp.Basic): + self.fields_read.update(access.field for access in e.atoms(Field.Access)) return rhs elif isinstance(rhs, TypedSymbol): return rhs