Skip to content
Snippets Groups Projects
Commit a3cb1634 authored by Martin Bauer
Browse files

Bugfix in vectorization, in case conditionals are pulled before loop

parent f35dbf6f
Branches
Tags
No related merge requests found
...@@ -16,7 +16,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift ...@@ -16,7 +16,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift
bitwise_or, modulo_ceil bitwise_or, modulo_ceil
from pystencils.astnodes import Node, KernelFunction from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \ from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access, reinterpret_cast_func vector_memory_access, reinterpret_cast_func, get_base_type
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
...@@ -517,6 +517,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -517,6 +517,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0]) result = self._print(expr.args[-1][0])
for true_expr, condition in reversed(expr.args[:-1]): for true_expr, condition in reversed(expr.args[:-1]):
# noinspection SpellCheckingInspection if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result)
else:
# noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
return result return result
...@@ -41,3 +41,23 @@ def test_vec_all(): ...@@ -41,3 +41,23 @@ def test_vec_all():
before = data_arr.copy() before = data_arr.copy()
kernel(data=data_arr) kernel(data=data_arr)
np.testing.assert_equal(data_arr, before) np.testing.assert_equal(data_arr, before)
def test_boolean_before_loop():
    """Check a vectorized kernel whose Piecewise condition is a separately assigned boolean.

    ``t1`` is assigned the comparison ``t2 > 0``, where ``t2`` is passed to the
    kernel as a scalar argument, and is then used as the condition of a
    ``Piecewise`` that writes to ``g``. The kernel is executed once for each
    branch and the complete output array is compared against the expected
    constant each time.
    """
    t1, t2 = sp.symbols('t1, t2')
    f_arr = np.ones((10, 10))
    g_arr = np.zeros_like(f_arr)
    f, g = ps.fields("f, g : double[2D]", f=f_arr, g=g_arr)

    # g takes the value of f where t1 holds, otherwise the constant 42.
    assignments = [
        ps.Assignment(t1, t2 > 0),
        ps.Assignment(g[0, 0],
                      sp.Piecewise((f[0, 0], t1), (42, True))),
    ]
    ast = ps.create_kernel(assignments, cpu_vectorize_info={'instruction_set': 'avx'})
    kernel = ast.compile()

    # Positive t2: the condition is true, so g must equal f (all ones).
    kernel(f=f_arr, g=g_arr, t2=1.0)
    np.testing.assert_array_equal(g_arr, 1.0)

    # Negative t2: the condition is false, so g must hold the fallback value.
    kernel(f=f_arr, g=g_arr, t2=-1.0)
    np.testing.assert_array_equal(g_arr, 42.0)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment