diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 0147efabad6e92f49b91cd3fc8fae0711ac0a2d5..07f73f57caaf643a027b66b0f1438c25e4d339b3 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -16,7 +16,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift bitwise_or, modulo_ceil from pystencils.astnodes import Node, KernelFunction from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ - vector_memory_access, reinterpret_cast_func + vector_memory_access, reinterpret_cast_func, get_base_type __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] @@ -517,6 +517,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): - # noinspection SpellCheckingInspection - result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) + if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): + result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result) + else: + # noinspection SpellCheckingInspection + result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) return result diff --git a/pystencils_tests/test_conditional_vec.py b/pystencils_tests/test_conditional_vec.py index c1d11f14162c40e2feb02c4ad4a4d1e6ad23a923..b66cffcb862ca1332aa735373e44c8249f9d64c4 100644 --- a/pystencils_tests/test_conditional_vec.py +++ b/pystencils_tests/test_conditional_vec.py @@ -41,3 +41,23 @@ def test_vec_all(): before = data_arr.copy() kernel(data=data_arr) np.testing.assert_equal(data_arr, before) + + +def test_boolean_before_loop(): + t1, t2 = sp.symbols('t1, t2') + f_arr = np.ones((10, 10)) + g_arr = np.zeros_like(f_arr) + f, g = ps.fields("f, g : double[2D]", f=f_arr, g=g_arr) + + a = [ + ps.Assignment(t1, t2 > 0), + ps.Assignment(g[0, 0], + sp.Piecewise((f[0, 0], t1), (42, True))) + ] + ast = ps.create_kernel(a, cpu_vectorize_info={'instruction_set': 'avx'}) + kernel = ast.compile() + kernel(f=f_arr, g=g_arr, t2=1.0) + print(g) + np.testing.assert_array_equal(g_arr, 1.0) + kernel(f=f_arr, g=g_arr, t2=-1.0) + np.testing.assert_array_equal(g_arr, 42.0)