rng.py 4.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sympy as sp
import numpy as np
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode

# Code template used by PhiloxTwoDoubles.get_code (filled in with str.format).
# It declares both result symbols using their dtype and name, then calls
# philox_double2 with the comma-separated `parameters` followed by the two result names.
# NOTE(review): philox_double2 is not defined in this file - presumably it comes from
# "philox_rand.h" and writes its results into the two trailing arguments; confirm there.
philox_two_doubles_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name});
"""

# Code template used by PhiloxFourFloats.get_code (filled in with str.format).
# It declares the four result symbols using their dtype and name, then calls
# philox_float4 with the comma-separated `parameters` followed by the four result names.
# NOTE(review): philox_float4 is not defined in this file - presumably it comes from
# "philox_rand.h" and writes its results into the four trailing arguments; confirm there.
philox_four_floats_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
{result_symbols[2].dtype} {result_symbols[2].name};
{result_symbols[3].dtype} {result_symbols[3].name};
philox_float4({parameters}, 
              {result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name});

"""


class PhiloxTwoDoubles(CustomCodeNode):
    """Custom code node that emits a ``philox_double2`` call filling two ``float64`` symbols.

    Every instance creates two fresh result symbols (named after new ``sp.Dummy`` objects). The
    generated code declares them and passes them to ``philox_double2`` after the call parameters.

    Args:
        dim: number of loop counter symbols (coordinates ``0 .. dim-1``) forwarded to the call
        time_step: value or symbol used as the first call parameter
        keys: further values or symbols appended to the call parameters after the loop counters
    """

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
        # The result symbols must exist before the base class is initialised, because they are
        # handed to it as the symbols this node defines.
        fresh_symbols = [TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2)]
        self.result_symbols = tuple(fresh_symbols)

        symbolic_keys = list(filter(lambda key: isinstance(key, sp.Symbol), keys))
        super().__init__("", symbols_read=symbolic_keys, symbols_defined=self.result_symbols)

        self.headers = ['"philox_rand.h"']
        self._dim = dim
        self._time_step = time_step
        self.keys = list(keys)
        self._args = (time_step,) + tuple(sp.sympify(keys))

    @property
    def args(self):
        """Tuple of the time step followed by the sympified keys."""
        return self._args

    @property
    def undefined_symbols(self):
        """Set of symbols among ``args`` plus the loop counter symbols of all ``dim`` coordinates."""
        counters = {LoopOverCoordinate.get_loop_counter_symbol(axis) for axis in range(self._dim)}
        symbolic_args = {arg for arg in self.args if isinstance(arg, sp.Symbol)}
        return symbolic_args | counters

    def get_code(self, dialect, vector_instruction_set):
        """Return the code that declares the result symbols and calls ``philox_double2``.

        The call parameters are the time step, the loop counters and the keys, in that order.
        The list is padded with zeros up to six entries; entries beyond the sixth are dropped.

        Raises:
            NotImplementedError: unless ``dialect`` is ``'cuda'``, or ``'c'`` with
                ``vector_instruction_set`` being ``None``.
        """
        counters = [LoopOverCoordinate.get_loop_counter_symbol(axis) for axis in range(self._dim)]
        call_args = [self._time_step, *counters, *self.keys]
        # always hand over exactly six parameters: zero-fill short lists, clip long ones
        call_args = (call_args + [0] * 6)[:6]

        supported = dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None)
        if not supported:
            raise NotImplementedError("Not yet implemented for this backend")

        return philox_two_doubles_call.format(parameters=', '.join(map(str, call_args)),
                                              result_symbols=self.result_symbols)

    def __repr__(self):
        first, second = self.result_symbols
        return "{}, {} <- PhiloxRNG".format(first, second)


class PhiloxFourFloats(CustomCodeNode):
    """Custom code node that emits a ``philox_float4`` call filling four ``float32`` symbols.

    Every instance creates four fresh result symbols (named after new ``sp.Dummy`` objects). The
    generated code declares them and passes them to ``philox_float4`` after the call parameters.

    Args:
        dim: number of loop counter symbols (coordinates ``0 .. dim-1``) forwarded to the call
        time_step: value or symbol used as the first call parameter
        keys: further values or symbols appended to the call parameters after the loop counters
    """

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
        # The result symbols must exist before the base class is initialised, because they are
        # handed to it as the symbols this node defines.
        fresh_symbols = [TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4)]
        self.result_symbols = tuple(fresh_symbols)

        symbolic_keys = list(filter(lambda key: isinstance(key, sp.Symbol), keys))
        super().__init__("", symbols_read=symbolic_keys, symbols_defined=self.result_symbols)

        self.headers = ['"philox_rand.h"']
        self._dim = dim
        self._time_step = time_step
        self.keys = list(keys)
        self._args = (time_step,) + tuple(sp.sympify(keys))

    @property
    def args(self):
        """Tuple of the time step followed by the sympified keys."""
        return self._args

    @property
    def undefined_symbols(self):
        """Set of symbols among ``args`` plus the loop counter symbols of all ``dim`` coordinates."""
        counters = {LoopOverCoordinate.get_loop_counter_symbol(axis) for axis in range(self._dim)}
        symbolic_args = {arg for arg in self.args if isinstance(arg, sp.Symbol)}
        return symbolic_args | counters

    def get_code(self, dialect, vector_instruction_set):
        """Return the code that declares the result symbols and calls ``philox_float4``.

        The call parameters are the time step, the loop counters and the keys, in that order.
        The list is padded with zeros up to six entries; entries beyond the sixth are dropped.

        Raises:
            NotImplementedError: unless ``dialect`` is ``'cuda'``, or ``'c'`` with
                ``vector_instruction_set`` being ``None``.
        """
        counters = [LoopOverCoordinate.get_loop_counter_symbol(axis) for axis in range(self._dim)]
        call_args = [self._time_step, *counters, *self.keys]
        # always hand over exactly six parameters: zero-fill short lists, clip long ones
        call_args = (call_args + [0] * 6)[:6]

        supported = dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None)
        if not supported:
            raise NotImplementedError("Not yet implemented for this backend")

        return philox_four_floats_call.format(parameters=', '.join(map(str, call_args)),
                                              result_symbols=self.result_symbols)

    def __repr__(self):
        first, second, third, fourth = self.result_symbols
        return "{}, {}, {}, {} <- PhiloxRNG".format(first, second, third, fourth)
Martin Bauer's avatar
Martin Bauer committed
112
113
114
115
116
117
118
119
120
121
122


def random_symbol(assignment_list, rng_node=PhiloxTwoDoubles, *args, **kwargs):
    """Endless generator handing out the result symbols of freshly created RNG nodes.

    Whenever the symbols of the current node are used up, a new node is built with
    ``rng_node(*args, **kwargs)``. A node is inserted at the front of ``assignment_list``
    right before its first result symbol is yielded.

    Args:
        assignment_list: list that receives each created node at index 0
        rng_node: callable creating a node that exposes ``result_symbols``
        *args: positional arguments forwarded to ``rng_node``
        **kwargs: keyword arguments forwarded to ``rng_node``

    Yields:
        The result symbols of the created nodes, one at a time.
    """
    while True:
        node = rng_node(*args, **kwargs)
        for position, symbol in enumerate(node.result_symbols):
            if position == 0:
                # register the node only once, at the moment its first symbol is handed out
                assignment_list.insert(0, node)
            yield symbol