Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sympy as sp
import numpy as np
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
# Code template used by PhiloxTwoDoubles.get_code: declares the two result
# variables (type and name taken from the result symbols) and then calls
# ``philox_double2`` with the comma-joined ``parameters`` followed by both
# result variables. Placeholders are filled in via ``str.format``.
philox_two_doubles_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name});
"""
# Code template used by PhiloxFourFloats.get_code: declares the four result
# variables and then calls ``philox_float4`` with the comma-joined
# ``parameters`` followed by all four result variables.
philox_four_floats_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
{result_symbols[2].dtype} {result_symbols[2].name};
{result_symbols[3].dtype} {result_symbols[3].name};
philox_float4({parameters},
{result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name});
"""
class PhiloxTwoDoubles(CustomCodeNode):
    """Custom code node that emits a ``philox_double2`` call.

    The node defines two freshly named ``float64`` result symbols and
    requires the ``"philox_rand.h"`` header.

    Args:
        dim: number of loop counter symbols passed to the call.
        time_step: first argument of the generated call.
        keys: extra arguments appended after the loop counters.
    """

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
        # The result symbols must exist before the base class is initialised,
        # because they are handed over as ``symbols_defined``.
        self.result_symbols = tuple([TypedSymbol(sp.Dummy().name, np.float64),
                                     TypedSymbol(sp.Dummy().name, np.float64)])
        symbolic_keys = list(filter(lambda key: isinstance(key, sp.Symbol), keys))
        super().__init__("", symbols_read=symbolic_keys, symbols_defined=self.result_symbols)

        self.headers = ['"philox_rand.h"']
        self.keys = list(keys)
        self._dim = dim
        self._time_step = time_step
        self._args = (time_step,) + tuple(sp.sympify(keys))

    @property
    def args(self):
        """Tuple of the time step followed by the sympified keys."""
        return self._args

    @property
    def undefined_symbols(self):
        """Set of symbols among ``args`` plus the loop counters of all ``dim`` coordinates."""
        symbols = set()
        for arg in self.args:
            if isinstance(arg, sp.Symbol):
                symbols.add(arg)
        for coordinate in range(self._dim):
            symbols.add(LoopOverCoordinate.get_loop_counter_symbol(coordinate))
        return symbols

    def get_code(self, dialect, vector_instruction_set):
        """Return the code snippet declaring the results and calling ``philox_double2``.

        Raises:
            NotImplementedError: unless the dialect is ``'cuda'``, or ``'c'``
                without a vector instruction set.
        """
        counters = [LoopOverCoordinate.get_loop_counter_symbol(coordinate)
                    for coordinate in range(self._dim)]
        call_args = [self._time_step] + counters + self.keys
        # The call always receives exactly six arguments: pad with zeros,
        # then cut off anything beyond the sixth entry.
        # NOTE(review): the cut silently drops trailing entries when
        # 1 + dim + len(keys) > 6 -- confirm this is intended.
        call_args.extend([0] * (6 - len(call_args)))
        call_args = call_args[:6]

        if dialect != 'cuda' and (dialect != 'c' or vector_instruction_set is not None):
            raise NotImplementedError("Not yet implemented for this backend")
        joined_args = ', '.join(map(str, call_args))
        return philox_two_doubles_call.format(parameters=joined_args,
                                              result_symbols=self.result_symbols)

    def __repr__(self):
        return "{}, {} <- PhiloxRNG".format(*self.result_symbols)
class PhiloxFourFloats(CustomCodeNode):
    """Custom code node that emits a ``philox_float4`` call.

    The node defines four freshly named ``float32`` result symbols and
    requires the ``"philox_rand.h"`` header.

    Args:
        dim: number of loop counter symbols passed to the call.
        time_step: first argument of the generated call.
        keys: extra arguments appended after the loop counters.
    """

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
        # The result symbols must exist before the base class is initialised,
        # because they are handed over as ``symbols_defined``.
        fresh_symbols = []
        for _ in range(4):
            fresh_symbols.append(TypedSymbol(sp.Dummy().name, np.float32))
        self.result_symbols = tuple(fresh_symbols)
        symbolic_keys = list(filter(lambda key: isinstance(key, sp.Symbol), keys))
        super().__init__("", symbols_read=symbolic_keys, symbols_defined=self.result_symbols)

        self.headers = ['"philox_rand.h"']
        self.keys = list(keys)
        self._dim = dim
        self._time_step = time_step
        self._args = (time_step,) + tuple(sp.sympify(keys))

    @property
    def args(self):
        """Tuple of the time step followed by the sympified keys."""
        return self._args

    @property
    def undefined_symbols(self):
        """Set of symbols among ``args`` plus the loop counters of all ``dim`` coordinates."""
        symbols = set()
        for arg in self.args:
            if isinstance(arg, sp.Symbol):
                symbols.add(arg)
        for coordinate in range(self._dim):
            symbols.add(LoopOverCoordinate.get_loop_counter_symbol(coordinate))
        return symbols

    def get_code(self, dialect, vector_instruction_set):
        """Return the code snippet declaring the results and calling ``philox_float4``.

        Raises:
            NotImplementedError: unless the dialect is ``'cuda'``, or ``'c'``
                without a vector instruction set.
        """
        counters = [LoopOverCoordinate.get_loop_counter_symbol(coordinate)
                    for coordinate in range(self._dim)]
        call_args = [self._time_step] + counters + self.keys
        # The call always receives exactly six arguments: pad with zeros,
        # then cut off anything beyond the sixth entry.
        # NOTE(review): the cut silently drops trailing entries when
        # 1 + dim + len(keys) > 6 -- confirm this is intended.
        call_args.extend([0] * (6 - len(call_args)))
        call_args = call_args[:6]

        if dialect != 'cuda' and (dialect != 'c' or vector_instruction_set is not None):
            raise NotImplementedError("Not yet implemented for this backend")
        joined_args = ', '.join(map(str, call_args))
        return philox_four_floats_call.format(parameters=joined_args,
                                              result_symbols=self.result_symbols)

    def __repr__(self):
        return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)
def random_symbol(assignment_list, rng_node=PhiloxTwoDoubles, *args, **kwargs):
    """Endless generator of result symbols taken from RNG nodes.

    A new ``rng_node(*args, **kwargs)`` is created whenever the result
    symbols of the previous node are used up. A node is inserted at the
    front of ``assignment_list`` right before its first symbol is yielded,
    so the list is only modified once a symbol of that node is requested.

    Args:
        assignment_list: list that receives each node at index 0.
        rng_node: callable creating a node exposing ``result_symbols``.
        *args, **kwargs: forwarded to ``rng_node``.

    Yields:
        The result symbols of each node, in order.
    """
    while True:
        generator_node = rng_node(*args, **kwargs)
        for position, symbol in enumerate(generator_node.result_symbols):
            if position == 0:
                # First symbol of this node: register the node itself.
                assignment_list.insert(0, generator_node)
            yield symbol