Dear CS10-Gitlab-users, on Thursday, Feb 3 there will be maintenance. That will lead to a downtime of the CS10-Gitlab-service including Subversion and Mattermost chat from 09:30. This might take the whole day since we don't know how long it is going to take. We are sorry for the inconvenience! Best regards, CS10-Admin-Team

Commit a3cb1634 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfix in vectorization, in case conditionals are pulled before loop

parent f35dbf6f
......@@ -16,7 +16,7 @@ from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift
bitwise_or, modulo_ceil
from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access, reinterpret_cast_func
vector_memory_access, reinterpret_cast_func, get_base_type
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
......@@ -517,6 +517,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0])
for true_expr, condition in reversed(expr.args[:-1]):
# noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result)
else:
# noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
return result
......@@ -41,3 +41,23 @@ def test_vec_all():
before = data_arr.copy()
kernel(data=data_arr)
np.testing.assert_equal(data_arr, before)
def test_boolean_before_loop():
t1, t2 = sp.symbols('t1, t2')
f_arr = np.ones((10, 10))
g_arr = np.zeros_like(f_arr)
f, g = ps.fields("f, g : double[2D]", f=f_arr, g=g_arr)
a = [
ps.Assignment(t1, t2 > 0),
ps.Assignment(g[0, 0],
sp.Piecewise((f[0, 0], t1), (42, True)))
]
ast = ps.create_kernel(a, cpu_vectorize_info={'instruction_set': 'avx'})
kernel = ast.compile()
kernel(f=f_arr, g=g_arr, t2=1.0)
print(g)
np.testing.assert_array_equal(g_arr, 1.0)
kernel(f=f_arr, g=g_arr, t2=-1.0)
np.testing.assert_array_equal(g_arr, 42.0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment