From a3cb16347b486c7bbf3d64fe04d2a52c4e030bdc Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 3 Apr 2019 15:30:27 +0200
Subject: [PATCH] Bugfix in vectorization, in case conditionals are pulled
 before loop

---
 pystencils/backends/cbackend.py          |  9 ++++++---
 pystencils_tests/test_conditional_vec.py | 20 ++++++++++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 0147efaba..07f73f57c 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -16,7 +16,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift
     bitwise_or, modulo_ceil
 from pystencils.astnodes import Node, KernelFunction
 from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
-    vector_memory_access, reinterpret_cast_func
+    vector_memory_access, reinterpret_cast_func, get_base_type
 
 __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
 
@@ -517,6 +517,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         result = self._print(expr.args[-1][0])
         for true_expr, condition in reversed(expr.args[:-1]):
-            # noinspection SpellCheckingInspection
-            result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
+            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
+                result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result)
+            else:
+                # noinspection SpellCheckingInspection
+                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
         return result
diff --git a/pystencils_tests/test_conditional_vec.py b/pystencils_tests/test_conditional_vec.py
index c1d11f141..b66cffcb8 100644
--- a/pystencils_tests/test_conditional_vec.py
+++ b/pystencils_tests/test_conditional_vec.py
@@ -41,3 +41,23 @@ def test_vec_all():
     before = data_arr.copy()
     kernel(data=data_arr)
     np.testing.assert_equal(data_arr, before)
+
+
+def test_boolean_before_loop():
+    t1, t2 = sp.symbols('t1, t2')
+    f_arr = np.ones((10, 10))
+    g_arr = np.zeros_like(f_arr)
+    f, g = ps.fields("f, g : double[2D]", f=f_arr, g=g_arr)
+
+    a = [
+        ps.Assignment(t1, t2 > 0),
+        ps.Assignment(g[0, 0],
+                      sp.Piecewise((f[0, 0], t1), (42, True)))
+    ]
+    ast = ps.create_kernel(a, cpu_vectorize_info={'instruction_set': 'avx'})
+    kernel = ast.compile()
+    kernel(f=f_arr, g=g_arr, t2=1.0)
+    print(g)
+    np.testing.assert_array_equal(g_arr, 1.0)
+    kernel(f=f_arr, g=g_arr, t2=-1.0)
+    np.testing.assert_array_equal(g_arr, 42.0)
-- 
GitLab