From 8833346ccb99e761504a8ec45eee3ee04be2495c Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 14 May 2018 17:52:38 +0200
Subject: [PATCH] Fixes in vectorization to also support float kernels

---
 transformations.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/transformations.py b/transformations.py
index d108b7011..8b288bb4c 100644
--- a/transformations.py
+++ b/transformations.py
@@ -707,7 +707,7 @@ class KernelConstraintsCheck:
         new_lhs = self._process_lhs(assignment.lhs)
         return ast.SympyAssignment(new_lhs, new_rhs)
 
-    def process_expression(self, rhs):
+    def process_expression(self, rhs, type_constants=True):
         self._update_accesses_rhs(rhs)
         if isinstance(rhs, Field.Access):
             self.fields_read.add(rhs.field)
@@ -716,19 +716,19 @@ class KernelConstraintsCheck:
             return rhs
         elif isinstance(rhs, sp.Symbol):
             return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
-        elif isinstance(rhs, sp.Number):
+        elif type_constants and isinstance(rhs, sp.Number):
             return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
         elif isinstance(rhs, sp.Mul):
-            new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
+            new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
             return rhs.func(*new_args) if new_args else rhs
         elif isinstance(rhs, sp.Indexed):
             return rhs
         else:
             if isinstance(rhs, sp.Pow):
                 # don't process exponents -> they should remain integers
-                return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
+                return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
             else:
-                new_args = [self.process_expression(arg) for arg in rhs.args]
+                new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
                 return rhs.func(*new_args) if new_args else rhs
 
     @property
@@ -796,7 +796,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
             return check.process_assignment(obj)
         elif isinstance(obj, ast.Conditional):
             false_block = None if obj.false_block is None else visit(obj.false_block)
-            return ast.Conditional(check.process_expression(obj.condition_expr),
+            return ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
                                    true_block=visit(obj.true_block), false_block=false_block)
         elif isinstance(obj, ast.Block):
             return ast.Block([visit(e) for e in obj.args])
-- 
GitLab