Skip to content
Snippets Groups Projects
Commit 7750926f authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Remove code duplication from Philox generation

parent 6ce27ece
No related merge requests found
......@@ -52,8 +52,7 @@ QUALIFIERS void _philox4x32bumpkey(uint32* key)
QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y)
{
unsigned long long z = (unsigned long long)x ^
((unsigned long long)y << (53 - 32));
uint64 z = (uint64)x ^ ((uint64)y << (53 - 32));
return z * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0);
}
......
......@@ -5,21 +5,18 @@ from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
philox_two_doubles_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name});
"""
philox_four_floats_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
{result_symbols[2].dtype} {result_symbols[2].name};
{result_symbols[3].dtype} {result_symbols[3].name};
philox_float4({parameters},
{result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name});
"""
def _get_philox_template(data_type, num_vars):
if data_type is np.float32:
c_type = "float"
elif data_type is np.float64:
c_type = "double"
template = "\n"
for i in range(num_vars):
template += "{{result_symbols[{}].dtype}} {{result_symbols[{}].name}};\n".format(i, i)
template += ("philox_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
.format(c_type, num_vars, *tuple(range(num_vars)))
return template
def _get_philox_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
......@@ -39,10 +36,10 @@ def _get_philox_code(template, dialect, vector_instruction_set, time_step, offse
raise NotImplementedError("Not yet implemented for this backend")
class PhiloxTwoDoubles(CustomCodeNode):
class PhiloxBase(CustomCodeNode):
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))
self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, self._data_type) for _ in range(self._num_vars))
symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
......@@ -68,47 +65,22 @@ class PhiloxTwoDoubles(CustomCodeNode):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
def get_code(self, dialect, vector_instruction_set):
return _get_philox_code(philox_two_doubles_call, dialect, vector_instruction_set,
template = _get_philox_template(self._data_type, self._num_vars)
return _get_philox_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
def __repr__(self):
return "{}, {} <- PhiloxRNG".format(*self.result_symbols)
class PhiloxFourFloats(CustomCodeNode):
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4))
symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self._time_step = time_step
self._offsets = offsets
self.headers = ['"philox_rand.h"']
self.keys = tuple(keys)
self._args = sp.sympify((dim, time_step, offsets, keys))
self._dim = dim
@property
def args(self):
return self._args
return (", ".join(['{}'] * self._num_vars) + " <- PhiloxRNG").format(*self.result_symbols)
@property
def undefined_symbols(self):
result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
for i in range(self._dim)]
result.update(loop_counters)
return result
def fast_subs(self, _):
return self # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them
class PhiloxTwoDoubles(PhiloxBase):
_data_type = np.float64
_num_vars = 2
def get_code(self, dialect, vector_instruction_set):
return _get_philox_code(philox_four_floats_call, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
def __repr__(self):
return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)
class PhiloxFourFloats(PhiloxBase):
_data_type = np.float32
_num_vars = 4
def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment