Skip to content
Snippets Groups Projects
Commit 2c25776f authored by Stephan Seitz
Browse files

Merge branch 'rng' into 'master'

Make sure that the RNG counter can be substituted during loop cutting

See merge request !190
parents 86b97688 945e6bd3
No related merge requests found
...@@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars): ...@@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars):
return template return template
def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols): def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys)
for i in range(dim)] + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters), return template.format(parameters=', '.join(str(p) for p in parameters),
...@@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode): ...@@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step self._time_step = time_step
self._offsets = offsets self._offsets = offsets
self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
self.headers = [f'"{self._name}_rand.h"'] self.headers = [f'"{self._name}_rand.h"']
self.keys = tuple(keys) self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys)) self._args = sp.sympify((dim, time_step, keys))
...@@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode): ...@@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode):
result.update(loop_counters) result.update(loop_counters)
return result return result
def subs(self, subs_dict) -> None:
    """Apply ``subs_dict`` to every stored coordinate expression, in place.

    Each entry of ``self._coordinates`` is replaced by the result of calling
    its own ``subs(subs_dict)``. Nothing is returned; the node itself is
    modified.

    Args:
        subs_dict: substitution mapping forwarded unchanged to each
            coordinate's ``subs`` method.
    """
    # Build all substituted expressions first, then write them back with a
    # slice assignment: the list object stays the same, and the list is never
    # left half-updated should one of the substitutions raise.
    self._coordinates[:] = [coordinate.subs(subs_dict) for coordinate in self._coordinates]
def fast_subs(self, *_):
    """Return this node unchanged; all positional arguments are ignored."""
    # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
    return self
def get_code(self, dialect, vector_instruction_set):
    """Build the code for this RNG node.

    Looks up the template for this node's name, data type and number of
    result variables, then hands it to ``_get_rng_code`` together with the
    node's state.

    Args:
        dialect: target dialect, forwarded to ``_get_rng_code``.
        vector_instruction_set: vector instruction set (or ``None``),
            forwarded to ``_get_rng_code``.

    Returns:
        Whatever ``_get_rng_code`` produces for these arguments.
    """
    template = _get_rng_template(self._name, self._data_type, self._num_vars)
    # Pass the stored coordinate expressions rather than the raw offsets, so
    # that substitutions applied through ``subs`` show up in the emitted code.
    return _get_rng_code(template, dialect, vector_instruction_set,
                         self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols)
def __repr__(self): def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols, return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
import pytest import pytest
import pystencils as ps import pystencils as ps
from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
# curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0)) # curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0))
...@@ -99,4 +99,13 @@ def test_aesni_float(): ...@@ -99,4 +99,13 @@ def test_aesni_float():
dh.run_kernel(kernel, time_step=124) dh.run_kernel(kernel, time_step=124)
dh.all_to_cpu() dh.all_to_cpu()
arr = dh.gather_array('f') arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all() assert np.logical_and(arr <= 1.0, arr >= 0).all()
\ No newline at end of file
def test_staggered():
    """Make sure that the RNG counter can be substituted during loop cutting"""
    data_handling = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu")
    flux = data_handling.add_array("j", values_per_cell=data_handling.dim,
                                   field_type=ps.FieldType.STAGGERED_FLUX)

    # one assignment per staggered stencil entry, each initially set to zero
    zero_assignments = [ps.Assignment(flux.staggered_access(direction), 0)
                        for direction in flux.staggered_stencil]
    collection = ps.AssignmentCollection(zero_assignments)

    # swap the right-hand side of the first assignment for a random symbol
    random_symbols = random_symbol(collection.subexpressions, dim=data_handling.dim)
    first_assignment = collection.main_assignments[0]
    collection.main_assignments[0] = ps.Assignment(first_assignment.lhs, next(random_symbols))

    # no value check: the test passes when kernel creation and compilation succeed
    kernel = ps.create_staggered_kernel(collection, target=data_handling.default_target).compile()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment