diff --git a/pystencils/rng.py b/pystencils/rng.py index afc08b033ee23983966fd2ce09c99a8d28580f58..26a92b313de13343c02375f6bfb003fca15b418f 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -26,13 +26,12 @@ class PhiloxTwoDoubles(CustomCodeNode): def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)): self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2)) - symbols_read = [s for s in keys if isinstance(s, sp.Symbol)] super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self._time_step = time_step self.headers = ['"philox_rand.h"'] - self.keys = list(keys) - self._args = (time_step, *sp.sympify(keys)) + self.keys = tuple(keys) + self._args = sp.sympify((dim, time_step, keys)) self._dim = dim @property @@ -47,9 +46,12 @@ class PhiloxTwoDoubles(CustomCodeNode): result.update(loop_counters) return result + def fast_subs(self, _): + return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them + def get_code(self, dialect, vector_instruction_set): parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) - for i in range(self._dim)] + self.keys + for i in range(self._dim)] + list(self.keys) while len(parameters) < 6: parameters.append(0) @@ -76,8 +78,8 @@ class PhiloxFourFloats(CustomCodeNode): super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self._time_step = time_step self.headers = ['"philox_rand.h"'] - self.keys = list(keys) - self._args = (time_step, *sp.sympify(keys)) + self.keys = tuple(keys) + self._args = sp.sympify((dim, time_step, keys)) self._dim = dim @property @@ -92,9 +94,12 @@ class PhiloxFourFloats(CustomCodeNode): result.update(loop_counters) return result + def fast_subs(self, _): + return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them + def get_code(self, dialect, vector_instruction_set): parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) - for i in range(self._dim)] + self.keys + for i in range(self._dim)] + list(self.keys) while len(parameters) < 6: parameters.append(0)