Skip to content
Snippets Groups Projects
Commit 945e6bd3 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Make sure that the RNG counter can be substituted during loop cutting

parent 9f966136
Branches
Tags
No related merge requests found
......@@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars):
return template
def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + [0] * (3 - dim) + list(keys)
def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols):
parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters),
......@@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
self._offsets = offsets
self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
self.headers = [f'"{self._name}_rand.h"']
self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys))
......@@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode):
result.update(loop_counters)
return result
def subs(self, subs_dict) -> None:
for i in range(len(self._coordinates)):
self._coordinates[i] = self._coordinates[i].subs(subs_dict)
def fast_subs(self, *_):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols)
def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
......
......@@ -3,7 +3,7 @@ import numpy as np
import pytest
import pystencils as ps
from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles
from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
# curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0))
......@@ -99,4 +99,13 @@ def test_aesni_float():
dh.run_kernel(kernel, time_step=124)
dh.all_to_cpu()
arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all()
\ No newline at end of file
assert np.logical_and(arr <= 1.0, arr >= 0).all()
def test_staggered():
"""Make sure that the RNG counter can be substituted during loop cutting"""
dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu")
j = dh.add_array("j", values_per_cell=dh.dim, field_type=ps.FieldType.STAGGERED_FLUX)
a = ps.AssignmentCollection([ps.Assignment(j.staggered_access(n), 0) for n in j.staggered_stencil])
rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim)
a.main_assignments[0] = ps.Assignment(a.main_assignments[0].lhs, next(rng_symbol_gen))
kernel = ps.create_staggered_kernel(a, target=dh.default_target).compile()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment