import sympy as sp import numpy as np from pystencils import TypedSymbol from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import CustomCodeNode philox_two_doubles_call = """ {result_symbols[0].dtype} {result_symbols[0].name}; {result_symbols[1].dtype} {result_symbols[1].name}; philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name}); """ philox_four_floats_call = """ {result_symbols[0].dtype} {result_symbols[0].name}; {result_symbols[1].dtype} {result_symbols[1].name}; {result_symbols[2].dtype} {result_symbols[2].name}; {result_symbols[3].dtype} {result_symbols[3].name}; philox_float4({parameters}, {result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name}); """ class PhiloxTwoDoubles(CustomCodeNode): def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)): self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2)) symbols_read = [s for s in keys if isinstance(s, sp.Symbol)] super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self._time_step = time_step self.headers = ['"philox_rand.h"'] self.keys = list(keys) self._args = (time_step, *sp.sympify(keys)) self._dim = dim @property def args(self): return self._args @property def undefined_symbols(self): result = {a for a in self.args if isinstance(a, sp.Symbol)} loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(self._dim)] result.update(loop_counters) return result def get_code(self, dialect, vector_instruction_set): parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(self._dim)] + self.keys while len(parameters) < 6: parameters.append(0) parameters = parameters[:6] assert len(parameters) == 6 if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): return philox_two_doubles_call.format(parameters=', '.join(str(p) for p in parameters), result_symbols=self.result_symbols) else: raise NotImplementedError("Not yet implemented for this backend") def __repr__(self): return "{}, {} <- PhiloxRNG".format(*self.result_symbols) class PhiloxFourFloats(CustomCodeNode): def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)): self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4)) symbols_read = [s for s in keys if isinstance(s, sp.Symbol)] super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) self._time_step = time_step self.headers = ['"philox_rand.h"'] self.keys = list(keys) self._args = (time_step, *sp.sympify(keys)) self._dim = dim @property def args(self): return self._args @property def undefined_symbols(self): result = {a for a in self.args if isinstance(a, sp.Symbol)} loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(self._dim)] result.update(loop_counters) return result def get_code(self, dialect, vector_instruction_set): parameters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(self._dim)] + self.keys while len(parameters) < 6: parameters.append(0) parameters = parameters[:6] assert len(parameters) == 6 if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): return philox_four_floats_call.format(parameters=', '.join(str(p) for p in parameters), result_symbols=self.result_symbols) else: raise NotImplementedError("Not yet implemented for this backend") def __repr__(self): return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols) def random_symbol(assignment_list, rng_node=PhiloxTwoDoubles, *args, **kwargs): while True: node = rng_node(*args, **kwargs) inserted = False for symbol in node.result_symbols: if not inserted: assignment_list.insert(0, node) inserted = True yield symbol