From a822ffc940f27e766d5426acbc9c78ee83dcd6a7 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 11 Oct 2019 08:09:07 +0200
Subject: [PATCH] Better boolean support in sympy printer

- bugfix: the loop counter of a vectorized loop is now correctly stored as a SIMD vector with entries i, i+1, i+2, ...
- lays the basis for in-kernel boundary handling
---
 pystencils/astnodes.py                       | 18 ++++++++++++++++--
 pystencils/backends/cbackend.py              | 15 +++++++++++----
 pystencils/backends/simd_instruction_sets.py | 19 ++++++++++++++++++-
 pystencils/cpu/vectorization.py              | 13 +++++++++----
 pystencils/sympyextensions.py                |  2 +-
 pystencils/transformations.py                |  7 ++++++-
 6 files changed, 61 insertions(+), 13 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 47f1fd7d1..3f67248f6 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -293,6 +293,10 @@ class Block(Node):
         for a in self.args:
             a.subs(subs_dict)
 
+    def fast_subs(self, subs_dict, skip=None):  # substitute in every child node, in place
+        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
+        return self  # the block itself is mutated; returning it gives fast_subs() callers a node back
+
     def insert_front(self, node):
         if isinstance(node, collections.abc.Iterable):
             node = list(node)
@@ -408,6 +412,16 @@ class LoopOverCoordinate(Node):
         if hasattr(self.step, "subs"):
             self.step = self.step.subs(subs_dict)
 
+    def fast_subs(self, subs_dict, skip=None):  # in-place substitution in the body and the loop bounds
+        self.body = fast_subs(self.body, subs_dict, skip)
+        if isinstance(self.start, sp.Basic):  # bounds that are not sympy objects are left untouched
+            self.start = fast_subs(self.start, subs_dict, skip)
+        if isinstance(self.stop, sp.Basic):
+            self.stop = fast_subs(self.stop, subs_dict, skip)
+        if isinstance(self.step, sp.Basic):
+            self.step = fast_subs(self.step, subs_dict, skip)
+        return self  # the loop node itself is mutated and handed back to the caller
+
     @property
     def args(self):
         result = [self.body]
@@ -538,7 +552,7 @@ class SympyAssignment(Node):
 
     @property
     def args(self):
-        return [self._lhs_symbol, self.rhs]
+        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
 
     @property
     def symbols_defined(self):
@@ -603,7 +617,7 @@ class ResolvedFieldAccess(sp.Indexed):
                                    self.args[1].subs(old, new),
                                    self.field, self.offsets, self.idx_coordinate_values)
 
-    def fast_subs(self, substitutions):
+    def fast_subs(self, substitutions, skip=None):
         if self in substitutions:
             return substitutions[self]
         return ResolvedFieldAccess(self.args[0].subs(substitutions),
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 556065251..0ae5e3640 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -5,7 +5,6 @@ import numpy as np
 import sympy as sp
 from sympy.core import S
 from sympy.printing.ccode import C89CodePrinter
-
 from pystencils.astnodes import KernelFunction, Node
 from pystencils.cpu.vectorization import vec_all, vec_any
 from pystencils.data_types import (
@@ -457,7 +456,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         elif isinstance(expr, cast_func):
             arg, data_type = expr.args
             if type(data_type) is VectorType:
-                return self.instruction_set['makeVec'].format(self._print(arg))
+                if isinstance(arg, sp.Tuple):
+                    is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
+                    printed_args = [self._print(a) for a in arg]
+                    instruction = 'makeVecBool' if is_boolean else 'makeVec'
+                    return self.instruction_set[instruction].format(*printed_args)
+                else:
+                    is_boolean = get_type_of_expression(arg) == create_type("bool")
+                    instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
+                    return self.instruction_set[instruction].format(self._print(arg))
         elif expr.func == fast_division:
             result = self._scalarFallback('_print_Function', expr)
             if not result:
@@ -542,12 +549,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         if result:
             return result
 
-        one = self.instruction_set['makeVec'].format(1.0)
+        one = self.instruction_set['makeVecConst'].format(1.0)
 
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
             return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
         elif expr.exp == -1:
-            one = self.instruction_set['makeVec'].format(1.0)
+            one = self.instruction_set['makeVecConst'].format(1.0)
             return self.instruction_set['/'].format(one, self._print(expr.base))
         elif expr.exp == 0.5:
             return self.instruction_set['sqrt'].format(self._print(expr.base))
diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py
index f72b0fec3..4415b9ab6 100644
--- a/pystencils/backends/simd_instruction_sets.py
+++ b/pystencils/backends/simd_instruction_sets.py
@@ -21,7 +21,10 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
 
         'sqrt': 'sqrt[0]',
 
+        'makeVecConst': 'set[]',
         'makeVec': 'set[]',
+        'makeVecBool': 'set[]',
+        'makeVecConstBool': 'set[]',
         'makeZero': 'setzero[]',
 
         'loadU': 'loadu[0]',
@@ -68,8 +71,17 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         function_shortcut = function_shortcut.strip()
         name = function_shortcut[:function_shortcut.index('[')]
 
-        if intrinsic_id == 'makeVec':
+        if intrinsic_id == 'makeVecConst':
             arg_string = "({})".format(",".join(["{0}"] * result['width']))
+        elif intrinsic_id == 'makeVec':
+            params = ["{" + str(i) + "}" for i in reversed(range(result['width']))]
+            arg_string = "({})".format(",".join(params))
+        elif intrinsic_id == 'makeVecBool':  # one "((cond) ? -1.0 : 0.0)" per entry, highest index first
+            params = ["(({{{i}}}) ? -1.0 : 0.0)".format(i=i) for i in reversed(range(result['width']))]
+            arg_string = "({})".format(",".join(params))
+        elif intrinsic_id == 'makeVecConstBool':
+            params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])]
+            arg_string = "({})".format(",".join(params))
         else:
             args = function_shortcut[function_shortcut.index('[') + 1: -1]
             arg_string = "("
@@ -111,6 +123,11 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
         result['bool'] = "__mmask%d" % (size,)
 
+        params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
+        result['makeVecBool'] = "__mmask8(({}) )".format(params)
+        params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
+        result['makeVecConstBool'] = "__mmask8(({}) )".format(params)
+
     if instruction_set == 'avx' and data_type == 'float':
         result['rsqrt'] = "_mm256_rsqrt_ps({0})"
 
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 6bf3a26de..d2f206722 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -18,12 +18,12 @@ from pystencils.transformations import (
 
 # noinspection PyPep8Naming
 class vec_any(sp.Function):
-    nargs = (1, )
+    nargs = (1,)
 
 
 # noinspection PyPep8Naming
 class vec_all(sp.Function):
-    nargs = (1, )
+    nargs = (1,)
 
 
 def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
@@ -53,7 +53,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
     """
     if instruction_set is None:
         return
-    
+
     all_fields = kernel_ast.fields_accessed
     if nontemporal is None or nontemporal is False:
         nontemporal = {}
@@ -101,7 +101,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
             if len(loop_nodes) == 0:
                 continue
             loop_node = loop_nodes[0]
-        
+
         # Find all array accesses (indexed) that depend on the loop counter as offset
         loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
         substitutions = {}
@@ -130,6 +130,11 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
 
         loop_node.step = vector_width
         loop_node.subs(substitutions)
+        vector_loop_counter = cast_func(tuple(loop_counter_symbol + i for i in range(vector_width)),
+                                        VectorType(loop_counter_symbol.dtype, vector_width))
+
+        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
+                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
 
 
 def insert_vector_casts(ast_node):
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index afdf0fde3..7d25f49c7 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -160,7 +160,7 @@ def fast_subs(expression: T, substitutions: Dict,
         if skip and skip(expr):
             return expr
         if hasattr(expr, "fast_subs"):
-            return expr.fast_subs(substitutions)
+            return expr.fast_subs(substitutions, skip)
         if expr in substitutions:
             return substitutions[expr]
         if not hasattr(expr, 'args'):
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 1bfb0511a..f43b5bdc6 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -481,6 +481,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                         if isinstance(field.dtype, StructType):
                             assert field.index_dimensions == 1
                             accessed_field_name = field_access.index[0]
+                            if isinstance(accessed_field_name, sp.Symbol):
+                                accessed_field_name = accessed_field_name.name
                             assert isinstance(accessed_field_name, str)
                             coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                         else:
@@ -504,7 +506,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                                              field_access.offsets, field_access.index)
 
             if isinstance(get_base_type(field_access.field.dtype), StructType):
-                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
+                accessed_field_name = field_access.index[0]
+                if isinstance(accessed_field_name, sp.Symbol):
+                    accessed_field_name = accessed_field_name.name
+                new_type = field_access.field.dtype.get_element_type(accessed_field_name)
                 result = reinterpret_cast_func(result, new_type)
 
             return visit_sympy_expr(result, enclosing_block, sympy_assignment)
-- 
GitLab