Commit 2c25776f authored by Stephan Seitz
Browse files

Merge branch 'rng' into 'master'

Make sure that the RNG counter can be substituted during loop cutting

See merge request pycodegen/pystencils!190
parents 86b97688 945e6bd3
......@@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars):
return template
def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + [0] * (3 - dim) + list(keys)
def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols):
parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters),
......@@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
self._offsets = offsets
self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
self.headers = [f'"{self._name}_rand.h"']
self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys))
......@@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode):
result.update(loop_counters)
return result
def subs(self, subs_dict) -> None:
    """Apply a substitution to the stored RNG counter coordinates, in place.

    Every entry of ``self._coordinates`` is replaced by the result of calling
    that entry's own ``subs(subs_dict)``.

    Args:
        subs_dict: substitution mapping, forwarded unchanged to each
            coordinate's ``subs`` call.
    """
    # Slice assignment keeps the very same list object alive, so the update
    # remains an in-place mutation exactly like per-index assignment.
    self._coordinates[:] = [coordinate.subs(subs_dict) for coordinate in self._coordinates]
def fast_subs(self, *_):
    """Return this node itself, ignoring every argument."""
    # There is nothing to replace inside this node; re-creating it would
    # destroy the intermediate "dummy" symbols.
    return self
def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols)
def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
......
......@@ -3,7 +3,7 @@ import numpy as np
import pytest
import pystencils as ps
from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles
from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
# curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0))
......@@ -99,4 +99,13 @@ def test_aesni_float():
dh.run_kernel(kernel, time_step=124)
dh.all_to_cpu()
arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all()
\ No newline at end of file
assert np.logical_and(arr <= 1.0, arr >= 0).all()
def test_staggered():
    """Check that the RNG counter can be substituted during loop cutting.

    The test passes when a staggered kernel whose first assignment draws a
    random symbol can be created and compiled without raising.
    """
    handling = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu")
    flux = handling.add_array("j", values_per_cell=handling.dim, field_type=ps.FieldType.STAGGERED_FLUX)

    # Start with one zero-assignment per entry of the staggered stencil ...
    zero_assignments = [ps.Assignment(flux.staggered_access(direction), 0)
                        for direction in flux.staggered_stencil]
    collection = ps.AssignmentCollection(zero_assignments)

    # ... then replace the first right-hand side by a freshly generated random symbol.
    random_symbols = random_symbol(collection.subexpressions, dim=handling.dim)
    first_lhs = collection.main_assignments[0].lhs
    collection.main_assignments[0] = ps.Assignment(first_lhs, next(random_symbols))

    # Creating and compiling the kernel is the actual check; the result is not inspected.
    kernel = ps.create_staggered_kernel(collection, target=handling.default_target).compile()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment