Skip to content
Snippets Groups Projects
Commit 945e6bd3 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Make sure that the RNG counter can be substituted during loop cutting

parent 9f966136
No related merge requests found
...@@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars): ...@@ -19,9 +19,8 @@ def _get_rng_template(name, data_type, num_vars):
return template return template
def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols): def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys)
for i in range(dim)] + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters), return template.format(parameters=', '.join(str(p) for p in parameters),
...@@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode): ...@@ -44,6 +43,7 @@ class RNGBase(CustomCodeNode):
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step self._time_step = time_step
self._offsets = offsets self._offsets = offsets
self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
self.headers = [f'"{self._name}_rand.h"'] self.headers = [f'"{self._name}_rand.h"']
self.keys = tuple(keys) self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, keys)) self._args = sp.sympify((dim, time_step, keys))
...@@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode): ...@@ -61,13 +61,17 @@ class RNGBase(CustomCodeNode):
result.update(loop_counters) result.update(loop_counters)
return result return result
def subs(self, subs_dict) -> None:
for i in range(len(self._coordinates)):
self._coordinates[i] = self._coordinates[i].subs(subs_dict)
def fast_subs(self, *_): def fast_subs(self, *_):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set): def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars) template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set, return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols)
def __repr__(self): def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols, return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
import pytest import pytest
import pystencils as ps import pystencils as ps
from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
# curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0)) # curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0))
...@@ -99,4 +99,13 @@ def test_aesni_float(): ...@@ -99,4 +99,13 @@ def test_aesni_float():
dh.run_kernel(kernel, time_step=124) dh.run_kernel(kernel, time_step=124)
dh.all_to_cpu() dh.all_to_cpu()
arr = dh.gather_array('f') arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all() assert np.logical_and(arr <= 1.0, arr >= 0).all()
\ No newline at end of file
def test_staggered():
"""Make sure that the RNG counter can be substituted during loop cutting"""
dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu")
j = dh.add_array("j", values_per_cell=dh.dim, field_type=ps.FieldType.STAGGERED_FLUX)
a = ps.AssignmentCollection([ps.Assignment(j.staggered_access(n), 0) for n in j.staggered_stencil])
rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim)
a.main_assignments[0] = ps.Assignment(a.main_assignments[0].lhs, next(rng_symbol_gen))
kernel = ps.create_staggered_kernel(a, target=dh.default_target).compile()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment