Commit b1a72be5 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bufix: RNG nodes produced wrong results when substituted

parent 36f757e2
Pipeline #16765 passed with stages
in 3 minutes and 32 seconds
......@@ -26,13 +26,12 @@ class PhiloxTwoDoubles(CustomCodeNode):
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))
symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
self.headers = ['"philox_rand.h"']
self.keys = list(keys)
self._args = (time_step, *sp.sympify(keys))
self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys))
self._dim = dim
@property
......@@ -47,9 +46,12 @@ class PhiloxTwoDoubles(CustomCodeNode):
result.update(loop_counters)
return result
def fast_subs(self, _):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set):
parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
for i in range(self._dim)] + self.keys
for i in range(self._dim)] + list(self.keys)
while len(parameters) < 6:
parameters.append(0)
......@@ -76,8 +78,8 @@ class PhiloxFourFloats(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
self.headers = ['"philox_rand.h"']
self.keys = list(keys)
self._args = (time_step, *sp.sympify(keys))
self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys))
self._dim = dim
@property
......@@ -92,9 +94,12 @@ class PhiloxFourFloats(CustomCodeNode):
result.update(loop_counters)
return result
def fast_subs(self, _):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set):
parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
for i in range(self._dim)] + self.keys
for i in range(self._dim)] + list(self.keys)
while len(parameters) < 6:
parameters.append(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment