From 53caa7e05e1bef3f60bde1581414fe745c1a9b95 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 24 Apr 2018 17:54:45 +0200
Subject: [PATCH] pystencils: systematic checks for conditions on kernel
 assignments

- SSA form is checked
- loop independence condition is checked
- bug fix in Field.create_generic when using index_shape
---
 cpu/kernelcreation.py     |  11 ++-
 field.py                  |   4 +
 gpucuda/kernelcreation.py |   8 +-
 gpucuda/periodicity.py    |   2 +-
 kernelcreation.py         |  46 +++++-----
 transformations.py        | 178 +++++++++++++++++++++++++-------------
 6 files changed, 160 insertions(+), 89 deletions(-)

diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index c2f7b7a64..6884f3d8c 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -2,7 +2,7 @@ import sympy as sp
 from functools import partial
 from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
 from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
-    type_all_equations, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
+    add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
     split_inner_loop, substitute_array_accesses_with_constants
 from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
 from pystencils.field import Field, FieldType
@@ -15,7 +15,8 @@ AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
 
 
 def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
-                  split_groups=(), iteration_slice=None, ghost_layers=None) -> KernelFunction:
+                  split_groups=(), iteration_slice=None, ghost_layers=None,
+                  skip_independence_check=False) -> KernelFunction:
     """
     Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
@@ -34,6 +35,8 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
                      if None, the number of ghost layers is determined automatically and assumed to be equal for a
                      all dimensions
+        skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
+                                 periodicity kernels that access the field outside the iteration bounds. Use with care!
 
     Returns:
         AST node representing a function, that can be printed as C or CUDA code
@@ -50,7 +53,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         else:
             raise ValueError("Term has to be field access or symbol")
 
-    fields_read, fields_written, assignments = type_all_equations(assignments, type_info)
+    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
 
@@ -108,7 +111,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
         function_name: see documentation of :func:`create_kernel`
         coordinate_names: name of the coordinate fields in the struct data type
     """
-    fields_read, fields_written, assignments = type_all_equations(assignments, type_info)
+    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
     all_fields = fields_read.union(fields_written)
 
     for index_field in index_fields:
diff --git a/field.py b/field.py
index 3d6460776..29112e605 100644
--- a/field.py
+++ b/field.py
@@ -165,6 +165,9 @@ class Field:
                         that should be iterated over, and BUFFER fields that are used to generate
                         communication packing/unpacking kernels
         """
+        if index_shape is not None:
+            assert index_dimensions == 0 or index_dimensions == len(index_shape)
+            index_dimensions = len(index_shape)
         if isinstance(layout, str):
             layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
         shape_symbol = IndexedBase(TypedSymbol(Field.SHAPE_PREFIX + field_name, Field.SHAPE_DTYPE), shape=(1,))
@@ -260,6 +263,7 @@ class Field:
         """Do not use directly. Use static create* methods"""
         self._field_name = field_name
         assert isinstance(field_type, FieldType)
+        assert len(shape) == len(strides)
         self.field_type = field_type
         self._dtype = create_type(dtype)
         self._layout = normalize_layout(layout)
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 460db7e12..bb60dedfa 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -1,7 +1,7 @@
 from functools import partial
 
 from pystencils.gpucuda.indexing import BlockIndexing
-from pystencils.transformations import resolve_field_accesses, type_all_equations, parse_base_pointer_info, \
+from pystencils.transformations import resolve_field_accesses, add_types, parse_base_pointer_info, \
     get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols
 from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
 from pystencils.data_types import TypedSymbol, BasicType, StructType
@@ -10,8 +10,8 @@ from pystencils.gpucuda.cudajit import make_python_function
 
 
 def create_cuda_kernel(assignments, function_name="kernel", type_info=None, indexing_creator=BlockIndexing,
-                       iteration_slice=None, ghost_layers=None):
-    fields_read, fields_written, assignments = type_all_equations(assignments, type_info)
+                       iteration_slice=None, ghost_layers=None, skip_independence_check=False):
+    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
 
@@ -93,7 +93,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
 
 def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel", type_info=None,
                                 coordinate_names=('x', 'y', 'z'), indexing_creator=BlockIndexing):
-    fields_read, fields_written, assignments = type_all_equations(assignments, type_info)
+    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
     all_fields = fields_read.union(fields_written)
     read_only_fields = set([f.name for f in fields_read - fields_written])
 
diff --git a/gpucuda/periodicity.py b/gpucuda/periodicity.py
index 39e3737ff..5657d4618 100644
--- a/gpucuda/periodicity.py
+++ b/gpucuda/periodicity.py
@@ -23,7 +23,7 @@ def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, in
         eq = Assignment(f(i), f[tuple(offset)](i))
         update_eqs.append(eq)
 
-    ast = create_cuda_kernel(update_eqs, iteration_slice=to_slice)
+    ast = create_cuda_kernel(update_eqs, iteration_slice=to_slice, skip_independence_check=True)
     return make_python_function(ast)
 
 
diff --git a/kernelcreation.py b/kernelcreation.py
index be918b0bd..45eebc34c 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -12,27 +12,28 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
                   gpu_indexing='block', gpu_indexing_params=MappingProxyType({})):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
-    :param assignments: either be a plain list of equations or a AssignmentCollection object
-    :param target: 'cpu', 'llvm' or 'gpu'
-    :param data_type: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name
-                     to type
-    :param iteration_slice: rectangular subset to iterate over, if not specified the complete non-ghost layer \
-                            part of the field is iterated over
-    :param ghost_layers: if left to default, the number of necessary ghost layers is determined automatically
-                        a single integer specifies the ghost layer count at all borders, can also be a sequence of
-                        pairs [(x_lower_gl, x_upper_gl), .... ]
-
-    CPU specific Parameters:
-    :param cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
-    :param cpu_vectorize_info: pair of instruction set name ('sse, 'avx', 'avx512') and data type ('float', 'double')
-
-    GPU specific Parameters
-    :param gpu_indexing: either 'block' or 'line' , or custom indexing class (see gpucuda/indexing.py)
-    :param gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class)
-                              e.g. for 'block' one can specify {'block_size': (20, 20, 10) }
-
-    :return: abstract syntax tree object, that can either be printed as source code or can be compiled with
-             through its compile() function
+
+    Args:
+        assignments: either a plain list of equations or an AssignmentCollection object
+        target: 'cpu', 'llvm' or 'gpu'
+        data_type: data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name
+                  to type
+        iteration_slice: rectangular subset to iterate over, if not specified the complete non-ghost layer \
+                         part of the field is iterated over
+        ghost_layers: if left to default, the number of necessary ghost layers is determined automatically
+                     a single integer specifies the ghost layer count at all borders, can also be a sequence of
+                     pairs [(x_lower_gl, x_upper_gl), .... ]
+
+        cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
+        cpu_vectorize_info: pair of instruction set name ('sse', 'avx', 'avx512') and data type ('float', 'double')
+
+        gpu_indexing: either 'block' or 'line' , or custom indexing class (see gpucuda/indexing.py)
+        gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class)
+                           e.g. for 'block' one can specify {'block_size': (20, 20, 10) }
+
+    Returns:
+        abstract syntax tree object, that can either be printed as source code with `show_code` or can be compiled
+        through its `compile()` member
     """
 
     # ----  Normalizing parameters
@@ -124,8 +125,9 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
         subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
         target: 'cpu' or 'gpu'
         kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
+
     Returns:
-        AST
+        AST, see `create_kernel`
     """
     assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
     assert staggered_field.index_dimensions == 1, 'Staggered field must have exactly one index dimension'
diff --git a/transformations.py b/transformations.py
index e5158d556..2347fb7e0 100644
--- a/transformations.py
+++ b/transformations.py
@@ -1,5 +1,5 @@
 import warnings
-from collections import defaultdict, OrderedDict
+from collections import defaultdict, OrderedDict, namedtuple
 from copy import deepcopy
 from types import MappingProxyType
 import sympy as sp
@@ -139,17 +139,21 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
 
 def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
     r"""
-    Addressing elements in structured arrays are done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
+    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
     where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
     The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
     This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
     Returns a new typed symbol, where the name encodes which coordinates have been resolved.
-    :param field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
-    :param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
-    :param previous_ptr: the pointer which is de-referenced
-    :return: tuple with the new pointer symbol and the calculated offset
 
-    Example:
+    Args:
+        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
+        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
+        previous_ptr: the pointer which is de-referenced
+
+    Returns:
+        tuple with the new pointer symbol and the calculated offset
+
+    Examples:
         >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
         >>> x, y = sp.symbols("x y")
         >>> prev_pointer = TypedSymbol("ptr", "double")
@@ -193,7 +197,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
 
     Specification of how many and which intermediate pointers are created for a field access.
     For example [ (0), (2,3,)]  creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
-    zero directly in the field access. These specifications are more sensible defined dependent on the loop ordering.
+    zero directly in the field access. These specifications are defined dependent on the loop ordering.
     This function translates more readable version into the specification above.
 
     Allowed specifications:
@@ -362,13 +366,16 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
     """
     Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing
 
-    :param ast_node: the AST root
-    :param read_only_field_names: set of field names which are considered read-only
-    :param field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created
-                                   for details see :func:`parse_base_pointer_info`
-    :param field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
+    Args:
+        ast_node: the AST root
+        read_only_field_names: set of field names which are considered read-only
+        field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created
+                                    for details see :func:`parse_base_pointer_info`
+        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                     counters to index the field these symbols are used as coordinates
-    :return: transformed AST
+
+    Returns:
+        transformed AST
     """
     field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
     field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
@@ -393,8 +400,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                         if field.name in field_to_fixed_coordinates:
                             coordinates[e] = field_to_fixed_coordinates[field.name][e]
                         else:
-                            ctr_name = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
-                            coordinates[e] = TypedSymbol("%s_%d" % (ctr_name, e), 'int')
+                            coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                         coordinates[e] *= field.dtype.item_size
                     else:
                         if isinstance(field.dtype, StructType):
@@ -418,7 +424,6 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                 last_pointer = new_ptr
 
             coord_dict = create_coordinate_dict(base_pointer_info[0])
-
             _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
             result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                              field_access.offsets, field_access.index)
@@ -652,68 +657,125 @@ def symbol_name_to_variable_name(symbol_name):
     return symbol_name.replace("^", "_")
 
 
-def type_all_equations(eqs, type_for_symbol):
+class KernelConstraintsCheck:
+    """Checks if the input to create_kernel is valid.
+
+    Test the following conditions:
+
+    - SSA Form for pure symbols:
+        -  Every pure symbol may occur only once as left-hand-side of an assignment
+        -  Every pure symbol that is read, may not be written to later
+    - Independence / Parallelization condition:
+        - a field that is written may only be read at exactly the same spatial position
+
+    (Pure symbols are symbols that are not Field.Accesses)
     """
-    Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
+    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
+
+    def __init__(self, type_for_symbol, check_independence_condition):
+        self._type_for_symbol = type_for_symbol
+        self._defined_pure_symbols = set()
+        self._accessed_pure_symbols = set()
+
+        self._field_writes = defaultdict(set)
+        self.fields_read = set()
+        self.check_independence_condition = check_independence_condition
+
+    def process_assignment(self, assignment):
+        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
+        new_rhs = self.process_expression(assignment.rhs)
+        new_lhs = self._process_lhs(assignment.lhs)
+        return ast.SympyAssignment(new_lhs, new_rhs)
+
+    def process_expression(self, rhs):
+        self._update_accesses_rhs(rhs)
+        if isinstance(rhs, Field.Access):
+            return rhs
+        elif isinstance(rhs, TypedSymbol):
+            return rhs
+        elif isinstance(rhs, sp.Symbol):
+            return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
+        else:
+            new_args = [self.process_expression(arg) for arg in rhs.args]
+            return rhs.func(*new_args) if new_args else rhs
+
+    @property
+    def fields_written(self):
+        return set(k.field for k, v in self._field_writes.items() if len(v))
+
+    def _process_lhs(self, lhs):
+        assert isinstance(lhs, sp.Symbol)
+        self._update_accesses_lhs(lhs)
+        if not isinstance(lhs, Field.Access) and not isinstance(lhs, TypedSymbol):
+            return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
+        else:
+            return lhs
+
+    def _update_accesses_lhs(self, lhs):
+        if isinstance(lhs, Field.Access):
+            fai = self.FieldAndIndex(lhs.field, lhs.index)
+            self._field_writes[fai].add(lhs.offsets)
+            if len(self._field_writes[fai]) > 1:
+                raise ValueError(f"Field {lhs.field.name} is written at two different locations")
+        elif isinstance(lhs, sp.Symbol):
+            if lhs in self._defined_pure_symbols:
+                raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
+            if lhs in self._accessed_pure_symbols:
+                raise ValueError(f"Symbol {lhs.name} is written, after it has been read")
+            self._defined_pure_symbols.add(lhs)
+
+    def _update_accesses_rhs(self, rhs):
+        if isinstance(rhs, Field.Access):
+            self.fields_read.add(rhs.field)  # record reads even when the independence check is skipped
+            if self.check_independence_condition:
+                # a written field may only be read at the offset it is written at
+                for write_offset in self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]:
+                    if write_offset != rhs.offsets:
+                        raise ValueError(f"Violation of loop independence condition. "
+                                         f"Field {rhs.field} is read at {rhs.offsets} and written at {write_offset}")
+        elif isinstance(rhs, sp.Symbol):
+            self._accessed_pure_symbols.add(rhs)
+
+
+def add_types(eqs, type_for_symbol, check_independence_condition):
+    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
+
     Additionally returns sets of all fields which are read/written
 
-    :param eqs: list of equations
-    :param type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
-    :return: ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
-              list of equations where symbols have been replaced by typed symbols
+    Args:
+        eqs: list of equations
+        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
+        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
+                                      kernels
+
+    Returns:
+        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
+         list of equations where symbols have been replaced by typed symbols
     """
     if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
         type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
 
-    fields_written = set()
-    fields_read = set()
-
-    def process_rhs(term):
-        """Replaces Symbols by:
-            - TypedSymbol if symbol is not a field access
-        """
-        if isinstance(term, Field.Access):
-            fields_read.add(term.field)
-            return term
-        elif isinstance(term, TypedSymbol):
-            return term
-        elif isinstance(term, sp.Symbol):
-            return TypedSymbol(symbol_name_to_variable_name(term.name), type_for_symbol[term.name])
-        else:
-            new_args = [process_rhs(arg) for arg in term.args]
-            return term.func(*new_args) if new_args else term
-
-    def process_lhs(term):
-        """Replaces symbol by TypedSymbol and adds field to fieldsWriten"""
-        if isinstance(term, Field.Access):
-            fields_written.add(term.field)
-            return term
-        elif isinstance(term, TypedSymbol):
-            return term
-        elif isinstance(term, sp.Symbol):
-            return TypedSymbol(term.name, type_for_symbol[term.name])
-        else:
-            assert False, "Expected a symbol as left-hand-side"
+    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
 
     def visit(obj):
         if isinstance(obj, list) or isinstance(obj, tuple):
             return [visit(e) for e in obj]
         if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
-            new_lhs = process_lhs(obj.lhs)
-            new_rhs = process_rhs(obj.rhs)
-            return ast.SympyAssignment(new_lhs, new_rhs)
+            return check.process_assignment(obj)
         elif isinstance(obj, ast.Conditional):
             false_block = None if obj.false_block is None else visit(obj.false_block)
-            return ast.Conditional(process_rhs(obj.condition_expr),
+            return ast.Conditional(check.process_expression(obj.condition_expr),
                                    true_block=visit(obj.true_block), false_block=false_block)
         elif isinstance(obj, ast.Block):
             return ast.Block([visit(e) for e in obj.args])
-        else:
+        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
             return obj
+        else:
+            raise ValueError("Invalid object in kernel " + str(type(obj)))
 
     typed_equations = visit(eqs)
 
-    return fields_read, fields_written, typed_equations
+    return check.fields_read, check.fields_written, typed_equations
 
 
 def insert_casts(node):
-- 
GitLab