Commit 754c7767 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfix - Vectorization of in-place LB update wrong

Block.subs method tried to be too smart:
a = field[..]
b = a + b

was "simplified" incorrectly to
b = field[...] + b
parent 0690f5ad
...@@ -212,6 +212,8 @@ class KernelFunction(Node): ...@@ -212,6 +212,8 @@ class KernelFunction(Node):
argument_symbols = self._body.undefined_symbols - self.global_variables argument_symbols = self._body.undefined_symbols - self.global_variables
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
parameters.sort(key=lambda p: p.symbol.name) parameters.sort(key=lambda p: p.symbol.name)
return parameters return parameters
...@@ -252,14 +254,6 @@ class Block(Node): ...@@ -252,14 +254,6 @@ class Block(Node):
return self._nodes return self._nodes
def subs(self, subs_dict) -> None: def subs(self, subs_dict) -> None:
new_args = []
for a in self.args:
if isinstance(a, SympyAssignment) and a.is_declaration and a.rhs in subs_dict.keys():
subs_dict[a.lhs] = subs_dict[a.rhs]
else:
new_args.append(a)
self._nodes = new_args
for a in self.args: for a in self.args:
a.subs(subs_dict) a.subs(subs_dict)
......
...@@ -25,6 +25,26 @@ def test_vector_type_propagation(): ...@@ -25,6 +25,26 @@ def test_vector_type_propagation():
np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3) np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
def test_inplace_update():
shape = (9, 9, 3)
arr = np.ones(shape, order='f')
@ps.kernel
def update_rule(s):
f = ps.fields("f(3) : [2D]", f=arr)
s.tmp0 @= f(0)
s.tmp1 @= f(1)
s.tmp2 @= f(2)
f0, f1, f2 = f(0), f(1), f(2)
f0 @= 2 * s.tmp0
f1 @= 2 * s.tmp0
f2 @= 2 * s.tmp0
ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': 'avx'})
kernel = ast.compile()
kernel(f=arr)
np.testing.assert_equal(arr, 2)
def test_vectorization_fixed_size(): def test_vectorization_fixed_size():
configurations = [] configurations = []
# Fixed size - multiple of four # Fixed size - multiple of four
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment