diff --git a/pystencils/rng.py b/pystencils/rng.py index bbc28bbd2c7d98cf931149d7e5572cda7fa33442..27606324a68cd30cd2105447209c6456d16213d4 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars): return template -def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols): - parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] - for i in range(dim)] + [0] * (3 - dim) + list(keys) +def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols): + parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys) if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): return template.format(parameters=', '.join(str(p) for p in parameters), @@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode): super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self._time_step = time_step self._offsets = offsets + self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)] self.headers = [f'"{self._name}_rand.h"'] self.keys = tuple(keys) self._args = sp.sympify((dim, time_step, keys)) @@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode): result.update(loop_counters) return result + def subs(self, subs_dict) -> None: + for i in range(len(self._coordinates)): + self._coordinates[i] = self._coordinates[i].subs(subs_dict) + def fast_subs(self, *_): return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them def get_code(self, dialect, vector_instruction_set): template = _get_rng_template(self._name, self._data_type, self._num_vars) return _get_rng_code(template, dialect, vector_instruction_set, - self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) + self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols) def __repr__(self): return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols, diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index 1b3f89f2f81aa6e079f7bc030f27da74e778dd13..9c7ee4dd0840339331a028d58e39f0650b9c8374 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -3,7 +3,7 @@ import numpy as np import pytest import pystencils as ps -from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles +from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol # curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0)) @@ -99,4 +99,13 @@ def test_aesni_float(): dh.run_kernel(kernel, time_step=124) dh.all_to_cpu() arr = dh.gather_array('f') - assert np.logical_and(arr <= 1.0, arr >= 0).all() \ No newline at end of file + assert np.logical_and(arr <= 1.0, arr >= 0).all() + +def test_staggered(): + """Make sure that the RNG counter can be substituted during loop cutting""" + dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu") + j = dh.add_array("j", values_per_cell=dh.dim, field_type=ps.FieldType.STAGGERED_FLUX) + a = ps.AssignmentCollection([ps.Assignment(j.staggered_access(n), 0) for n in j.staggered_stencil]) + rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim) + a.main_assignments[0] = ps.Assignment(a.main_assignments[0].lhs, next(rng_symbol_gen)) + kernel = ps.create_staggered_kernel(a, target=dh.default_target).compile()