From f13701df9b6d85602e3ab0fd92fa2275618c02e5 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 17 Apr 2018 11:03:27 +0200
Subject: [PATCH] WIP: ISL based integer condition optimization

---
 astnodes.py               |  4 +-
 gpucuda/kernelcreation.py | 17 +++++---
 integer_set_analysis.py   | 72 ++++++++++++++++++++++++++++++++
 kernelcreation.py         | 43 +++++++++++++++++++
 transformations.py        | 87 +++++++++++++++++++++++++++++----------
 5 files changed, 193 insertions(+), 30 deletions(-)
 create mode 100644 integer_set_analysis.py

diff --git a/astnodes.py b/astnodes.py
index ebd0c82ec..362eb7a38 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -481,12 +481,12 @@ class SympyAssignment(Node):
             raise ValueError('%s is not in args of %s' % (replacement, self.__class__))
 
     def __repr__(self):
-        return repr(self.lhs) + " = " + repr(self.rhs)
+        return repr(self.lhs) + " ← " + repr(self.rhs)
 
     def _repr_html_(self):
         printed_lhs = sp.latex(self.lhs)
         printed_rhs = sp.latex(self.rhs)
-        return f"${printed_lhs} = {printed_rhs}$"
+        return f"${printed_lhs} \leftarrow {printed_rhs}$"
 
 
 class ResolvedFieldAccess(sp.Indexed):
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index a51801136..d17d23aa0 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -43,17 +43,22 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
                                              -ghost_layers[i][1] if ghost_layers[i][1] > 0 else None))
 
     indexing = indexing_creator(field=list(fields_without_buffers)[0], iteration_slice=iteration_slice)
+    coord_mapping = indexing.coordinates
+
+    cell_idx_assignments = [SympyAssignment(LoopOverCoordinate.get_loop_counter_symbol(i), value)
+                            for i, value in enumerate(coord_mapping)]
+    cell_idx_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i, _ in enumerate(coord_mapping)]
+    assignments = cell_idx_assignments + assignments
 
     block = Block(assignments)
     block = indexing.guard(block, common_shape)
     ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda')
     ast.global_variables.update(indexing.index_variables)
 
-    coord_mapping = indexing.coordinates
     base_pointer_info = [['spatialInner0']]
     base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields}
 
-    coord_mapping = {f.name: coord_mapping for f in all_fields}
+    coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
 
     loop_vars = [num_buffer_accesses * i for i in indexing.coordinates]
     loop_strides = list(fields_without_buffers)[0].shape
@@ -102,11 +107,11 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel
     spatial_coordinates = list(spatial_coordinates)[0]
 
     def get_coordinate_symbol_assignment(name):
-        for index_field in index_fields:
-            assert isinstance(index_field.dtype, StructType), "Index fields have to have a struct data type"
-            data_type = index_field.dtype
+        for ind_f in index_fields:
+            assert isinstance(ind_f.dtype, StructType), "Index fields have to have a struct data type"
+            data_type = ind_f.dtype
             if data_type.has_element(name):
-                rhs = index_field[0](name)
+                rhs = ind_f[0](name)
                 lhs = TypedSymbol(name, BasicType(data_type.get_element_type(name)))
                 return SympyAssignment(lhs, rhs)
         raise ValueError("Index %s not found in any of the passed index fields" % (name,))
diff --git a/integer_set_analysis.py b/integer_set_analysis.py
new file mode 100644
index 000000000..034f0049b
--- /dev/null
+++ b/integer_set_analysis.py
@@ -0,0 +1,72 @@
+"""Transformations using integer sets based on ISL library"""
+
+import sympy as sp
+import islpy as isl
+from typing import Tuple
+
+import pystencils.astnodes as ast
+from pystencils.transformations import parents_of_type
+
+#context = isl.Context()
+
+"""
+- find all Condition nodes
+- check if they depend on integers only
+- create ISL space containing names of all loop symbols (counter and bounds) and all integers in Conditional expression
+- build up pre-condition set by iteration over each enclosing loop add ISL constraints
+- build up ISL space for condition
+- if pre_condition_set.intersect(conditional_set) == pre_condition_set
+    always use True condition
+  elif pre_condition_set.intersect(conditional_set).is_empty():
+    always use False condition
+"""
+
+
+def isl_iteration_set(node: ast.Node):
+    """Builds up an ISL set describing the iteration space by analysing the enclosing loops of the given node. """
+    conditions = []
+    loop_symbols = set()
+    for loop in parents_of_type(node, ast.LoopOverCoordinate):
+        if loop.step != 1:
+            raise NotImplementedError("Loops with strides != 1 are not yet supported.")
+
+        loop_symbols.add(loop.loop_counter_symbol)
+        loop_symbols.update(sp.sympify(loop.start).atoms(sp.Symbol))
+        loop_symbols.update(sp.sympify(loop.stop).atoms(sp.Symbol))
+
+        loop_start_str = str(loop.start).replace('[', '_bracket1_').replace(']', '_bracket2_')
+        loop_stop_str = str(loop.stop).replace('[', '_bracket1_').replace(']', '_bracket2_')
+        ctr_name = loop.loop_counter_name
+        conditions.append(f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}")
+
+    symbol_names = ','.join([s.name for s in loop_symbols])
+    condition_str = ' and '.join(conditions)
+    set_description = f"{{ [{symbol_names}] : {condition_str} }}"
+    return loop_symbols, isl.BasicSet(set_description)
+
+    for loop in parents_of_type(node, ast.LoopOverCoordinate):
+        ctr_name = loop.loop_counter_name
+        lower_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, 1: -loop.start})
+        upper_constraint = isl.Constraint.ineq_from_names(space, {ctr_name: 1, })
+
+
+def simplify_conditionals_new(ast_node):
+    for conditional in ast_node.atoms(ast.Conditional):
+        if conditional.condition_expr == sp.true:
+            conditional.parent.replace(conditional, [conditional.true_block])
+        elif conditional.condition_expr == sp.false:
+            conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
+        else:
+            loop_symbols, iteration_set = isl_iteration_set(conditional)
+            symbols_in_condition = conditional.condition_expr.atoms(sp.Symbol)
+            if symbols_in_condition.issubset(loop_symbols):
+                symbol_names = ','.join([s.name for s in loop_symbols])
+                condition_str = str(conditional.condition_expr)
+                condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")
+
+                intersection = iteration_set.intersect(condition_set)
+                if intersection.is_empty():
+                    conditional.parent.replace(conditional,
+                                               [conditional.false_block] if conditional.false_block else [])
+                elif intersection == iteration_set:
+                    conditional.parent.replace(conditional, [conditional.true_block])
diff --git a/kernelcreation.py b/kernelcreation.py
index 75ad8d775..c29098ade 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -1,6 +1,10 @@
 from types import MappingProxyType
+import sympy as sp
+from pystencils.assignment import Assignment
+from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAssignment
 from pystencils.assignment_collection import AssignmentCollection
 from pystencils.gpucuda.indexing import indexing_creator_from_params
+from pystencils.transformations import remove_conditionals_in_staggered_kernel
 
 
 def create_kernel(equations, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None,
@@ -104,3 +108,42 @@ def create_indexed_kernel(assignments, index_fields, target='cpu', data_type="do
         return ast
     else:
         raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
+
+
+def create_staggered_kernel(staggered_field, expressions, subexpressions=(), target='cpu', **kwargs):
+    """Kernel that updates a staggered field.
+
+    Args:
+        staggered_field: field that has one index coordinate and
+                where e.g. f[0,0](0) is interpreted as value at the left cell boundary, f[1,0](0) the right cell
+                boundary and f[0,0](1) the southern cell boundary etc.
+        expressions: sequence of expressions of length dim, defining how the east, southern, (bottom) cell boundary
+                     should be update
+        subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
+        target: 'cpu' or 'gpu'
+        kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
+    Returns:
+        AST
+    """
+    assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
+    assert staggered_field.index_dimensions == 1, 'Staggered field must have exactly one index dimension'
+    dim = staggered_field.spatial_dimensions
+
+    counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
+    conditions = [counters[i] < staggered_field.shape[i] - 1 for i in range(dim)]
+    assert len(expressions) == dim
+    final_assignments = []
+    for d in range(dim):
+        cond = sp.And(*[conditions[i] for i in range(dim) if d != i])
+        a_coll = AssignmentCollection([Assignment(staggered_field(d), expressions[d])], list(subexpressions))
+        a_coll = a_coll.new_filtered([staggered_field(d)])
+        sp_assignments = [SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments]
+        final_assignments.append(Conditional(cond, Block(sp_assignments)))
+    ghost_layers = [(1, 0)] * dim
+
+    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
+
+    if target == 'cpu':
+        remove_conditionals_in_staggered_kernel(ast)
+
+    return ast
diff --git a/transformations.py b/transformations.py
index 831e4af89..d7e7f160b 100644
--- a/transformations.py
+++ b/transformations.py
@@ -582,16 +582,27 @@ def cut_loop(loop_node, cutting_points):
 
 def is_condition_necessary(condition, pre_condition, symbol):
     """
-    Determines if a logical condition of a single variable is already contained in a stronger preCondition
-    so if from preCondition follows that condition is always true, then this condition is not necessary
-    :param condition: sympy relational of one variable
-    :param pre_condition: logical expression that is known to be true
-    :param symbol: the single symbol of interest
-    :return: returns  not (preCondition => condition) where "=>" is logical implication
+    Determines if a logical condition of a single variable is already contained in a stronger pre_condition
+    so if from pre_condition follows that condition is always true, then this condition is not necessary
+
+    Args:
+        condition: sympy relational of one variable
+        pre_condition: logical expression that is known to be true
+        symbol: the single symbol of interest
+
+    Returns:
+        returns  not (pre_condition => condition) where "=>" is logical implication
     """
     from sympy.solvers.inequalities import reduce_rational_inequalities
     from sympy.logic.boolalg import to_dnf
 
+    def normalize_relational(e):
+        if isinstance(e, sp.Rel):
+            return e.func(e.lhs - e.rhs, 0)
+        else:
+            new_args = [normalize_relational(a) for a in e.args]
+            return e.func(*new_args) if new_args else e
+
     def to_dnf_list(expr):
         result = to_dnf(expr)
         if isinstance(result, sp.Or):
@@ -599,8 +610,12 @@ def is_condition_necessary(condition, pre_condition, symbol):
         elif isinstance(result, sp.And):
             return [result.args]
         else:
-            return result
+            return [result]
 
+    condition = normalize_relational(condition)
+    pre_condition = normalize_relational(pre_condition)
+    a1 = to_dnf_list(pre_condition)
+    a2 = to_dnf_list(condition)
     t1 = reduce_rational_inequalities(to_dnf_list(sp.And(condition, pre_condition)), symbol)
     t2 = reduce_rational_inequalities(to_dnf_list(pre_condition), symbol)
     return t1 != t2
@@ -619,12 +634,11 @@ def simplify_boolean_expression(expr, single_variable_ranges):
 
     def visit(e):
         if isinstance(e, Relational):
-            symbols = e.atoms(sp.Symbol)
+            symbols = e.atoms(sp.Symbol).intersection(single_variable_ranges.keys())
             if len(symbols) == 1:
                 symbol = symbols.pop()
-                if symbol in single_variable_ranges:
-                    if not is_condition_necessary(e, single_variable_ranges[symbol], symbol):
-                        return sp.true
+                if not is_condition_necessary(e, single_variable_ranges[symbol], symbol):
+                    return sp.true
             return e
         else:
             new_args = [visit(a) for a in e.args]
@@ -635,24 +649,23 @@ def simplify_boolean_expression(expr, single_variable_ranges):
 
 def simplify_conditionals(node, loop_conditionals=MappingProxyType({})):
     """Simplifies/Removes conditions inside loops that depend on the loop counter."""
-    loop_conditionals = loop_conditionals.copy()
     if isinstance(node, ast.LoopOverCoordinate):
         ctr_sym = node.loop_counter_symbol
+        loop_conditionals = loop_conditionals.copy()
         loop_conditionals[ctr_sym] = sp.And(ctr_sym >= node.start, ctr_sym < node.stop)
-        simplify_conditionals(node.body)
-        del loop_conditionals[ctr_sym]
+        simplify_conditionals(node.body, loop_conditionals)
     elif isinstance(node, ast.Conditional):
         node.condition_expr = simplify_boolean_expression(node.condition_expr, loop_conditionals)
         simplify_conditionals(node.true_block)
         if node.false_block:
-            simplify_conditionals(node.false_block)
+            simplify_conditionals(node.false_block, loop_conditionals)
         if node.condition_expr == sp.true:
             node.parent.replace(node, [node.true_block])
         if node.condition_expr == sp.false:
             node.parent.replace(node, [node.false_block] if node.false_block else [])
     elif isinstance(node, ast.Block):
         for a in list(node.args):
-            simplify_conditionals(a)
+            simplify_conditionals(a, loop_conditionals)
     elif isinstance(node, ast.SympyAssignment):
         return node
     else:
@@ -829,6 +842,22 @@ def insert_casts(node):
     return node.func(*args)
 
 
+def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
+    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
+
+    all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
+    assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
+    inner_loop = all_inner_loops.pop()
+
+    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
+        cut_loop(loop, [loop.stop-1])
+
+    simplify_conditionals(function_node.body)
+    cleanup_blocks(function_node.body)
+    move_constants_before_loop(function_node.body)
+    cleanup_blocks(function_node.body)
+
+
 # --------------------------------------- Helper Functions -------------------------------------------------------------
 
 
@@ -836,9 +865,12 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
     """
     Creates a default symbol name to type mapping.
     If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
-    :param eqs: list of equations
-    :param default_type: the type for non-boolean symbols
-    :return: dictionary, mapping symbol name to type
+
+    Args:
+        eqs: list of equations
+        default_type: the type for non-boolean symbols
+    Returns:
+        dictionary, mapping symbol name to type
     """
     result = defaultdict(lambda: default_type)
     for eq in eqs:
@@ -863,6 +895,16 @@ def get_next_parent_of_type(node, parent_type):
     return None
 
 
+def parents_of_type(node, parent_type, include_current=False):
+    """Similar to get_next_parent_of_type, but as generator"""
+    parent = node if include_current else node.parent
+    while parent is not None:
+        if isinstance(parent, parent_type):
+            yield parent
+        parent = parent.parent
+    return None
+
+
 def get_optimal_loop_ordering(fields):
     """
     Determines the optimal loop order for a given set of fields.
@@ -886,9 +928,10 @@ def get_optimal_loop_ordering(fields):
 
 
 def get_loop_hierarchy(ast_node):
-    """Determines the loop structure around a given AST node.
-    :param ast_node: the AST node
-    :return: list of coordinate ids, where the first list entry is the innermost loop
+    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.
+
+    Returns:
+        sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop
     """
     result = []
     node = ast_node
-- 
GitLab