random.py 4.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sympy as sp
import numpy as np
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode

# C/CUDA snippet emitted by PhiloxTwoDoubles.get_code: declares the two result
# variables and fills them with a single philox_double2(...) call.
# Format placeholders:
#   {parameters}      -- comma-separated argument list (counters and keys)
#   {result_symbols}  -- sequence of symbols exposing `.dtype` and `.name`
philox_two_doubles_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
philox_double2({parameters}, {result_symbols[0].name}, {result_symbols[1].name});
"""

# C/CUDA snippet emitted by PhiloxFourFloats.get_code: same pattern as above,
# but declares four result variables and fills them via philox_float4(...).
philox_four_floats_call = """
{result_symbols[0].dtype} {result_symbols[0].name};
{result_symbols[1].dtype} {result_symbols[1].name};
{result_symbols[2].dtype} {result_symbols[2].name};
{result_symbols[3].dtype} {result_symbols[3].name};
philox_float4({parameters}, 
              {result_symbols[0].name}, {result_symbols[1].name}, {result_symbols[2].name}, {result_symbols[3].name});

"""


class PhiloxTwoDoubles(CustomCodeNode):
    """Code node that draws two double-precision random numbers per cell.

    The generated C/CUDA code (``philox_two_doubles_call``) declares two
    ``float64`` result symbols and fills them with one ``philox_double2`` call
    from ``"philox_rand.h"``.

    The call receives six leading arguments laid out as four counter slots
    followed by two key slots::

        time_step, ctr_0, ctr_1, ctr_2, key_0, key_1

    where ``ctr_i`` is the loop counter of spatial dimension ``i``. Unused
    counter slots (``dim < 3``) and missing keys are filled with ``0``.
    NOTE(review): the 4-counter/2-key split is inferred from the fixed
    six-argument call and the two default keys -- confirm against
    ``philox_rand.h``.

    Args:
        dim: number of spatial dimensions (0..3); one loop counter per
            dimension is passed as a counter.
        time_step: value or symbol for the first counter slot.
        keys: up to two key values or symbols.

    Raises:
        ValueError: if ``dim`` is outside 0..3 or more than two keys are given
            (these used to be silently truncated).
    """

    _NUM_COUNTERS = 4  # time_step + up to three loop counters
    _NUM_KEYS = 2

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
        # Materialize once: a generator passed as `keys` would otherwise be
        # exhausted by the symbols_read comprehension before being stored.
        keys = tuple(keys)
        if not 0 <= dim < self._NUM_COUNTERS:
            raise ValueError("dim must be between 0 and {}, got {}".format(self._NUM_COUNTERS - 1, dim))
        if len(keys) > self._NUM_KEYS:
            raise ValueError("At most {} keys are supported, got {}".format(self._NUM_KEYS, len(keys)))

        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float64) for _ in range(2))

        symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
        self._time_step = time_step
        self.headers = ['"philox_rand.h"']
        self.keys = list(keys)
        self._args = (time_step, *sp.sympify(keys))
        self._dim = dim

    @property
    def args(self):
        """Tuple of the time step followed by the (sympified) keys."""
        return self._args

    @property
    def undefined_symbols(self):
        """Symbols this node reads: symbolic args plus the loop counters of all ``dim`` dimensions."""
        result = {a for a in self.args if isinstance(a, sp.Symbol)}
        loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                         for i in range(self._dim)]
        result.update(loop_counters)
        return result

    def get_code(self, dialect, vector_instruction_set):
        """Return the C/CUDA source that declares and fills ``result_symbols``.

        Raises:
            NotImplementedError: for any backend other than ``'cuda'`` or
                non-vectorized ``'c'``.
        """
        if not (dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None)):
            raise NotImplementedError("Not yet implemented for this backend")

        counters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
                                        for i in range(self._dim)]
        # Pad counters and keys separately. Padding only at the end shifted the
        # keys into unused counter slots whenever dim < 3.
        counters += [0] * (self._NUM_COUNTERS - len(counters))
        keys = self.keys + [0] * (self._NUM_KEYS - len(self.keys))
        parameters = counters + keys

        return philox_two_doubles_call.format(parameters=', '.join(str(p) for p in parameters),
                                              result_symbols=self.result_symbols)

    def __repr__(self):
        return "{}, {} <- PhiloxRNG".format(*self.result_symbols)


class PhiloxFourFloats(CustomCodeNode):
    """Code node that draws four single-precision random numbers per cell.

    The generated C/CUDA code (``philox_four_floats_call``) declares four
    ``float32`` result symbols and fills them with one ``philox_float4`` call
    from ``"philox_rand.h"``.

    The call receives six leading arguments laid out as four counter slots
    followed by two key slots::

        time_step, ctr_0, ctr_1, ctr_2, key_0, key_1

    where ``ctr_i`` is the loop counter of spatial dimension ``i``. Unused
    counter slots (``dim < 3``) and missing keys are filled with ``0``.
    NOTE(review): the 4-counter/2-key split is inferred from the fixed
    six-argument call and the two default keys -- confirm against
    ``philox_rand.h``.

    Args:
        dim: number of spatial dimensions (0..3); one loop counter per
            dimension is passed as a counter.
        time_step: value or symbol for the first counter slot.
        keys: up to two key values or symbols.

    Raises:
        ValueError: if ``dim`` is outside 0..3 or more than two keys are given
            (these used to be silently truncated).
    """

    _NUM_COUNTERS = 4  # time_step + up to three loop counters
    _NUM_KEYS = 2

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), keys=(0, 0)):
        # Materialize once: a generator passed as `keys` would otherwise be
        # exhausted by the symbols_read comprehension before being stored.
        keys = tuple(keys)
        if not 0 <= dim < self._NUM_COUNTERS:
            raise ValueError("dim must be between 0 and {}, got {}".format(self._NUM_COUNTERS - 1, dim))
        if len(keys) > self._NUM_KEYS:
            raise ValueError("At most {} keys are supported, got {}".format(self._NUM_KEYS, len(keys)))

        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, np.float32) for _ in range(4))
        symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]

        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
        self._time_step = time_step
        self.headers = ['"philox_rand.h"']
        self.keys = list(keys)
        self._args = (time_step, *sp.sympify(keys))
        self._dim = dim

    @property
    def args(self):
        """Tuple of the time step followed by the (sympified) keys."""
        return self._args

    @property
    def undefined_symbols(self):
        """Symbols this node reads: symbolic args plus the loop counters of all ``dim`` dimensions."""
        result = {a for a in self.args if isinstance(a, sp.Symbol)}
        loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                         for i in range(self._dim)]
        result.update(loop_counters)
        return result

    def get_code(self, dialect, vector_instruction_set):
        """Return the C/CUDA source that declares and fills ``result_symbols``.

        Raises:
            NotImplementedError: for any backend other than ``'cuda'`` or
                non-vectorized ``'c'``.
        """
        if not (dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None)):
            raise NotImplementedError("Not yet implemented for this backend")

        counters = [self._time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i)
                                        for i in range(self._dim)]
        # Pad counters and keys separately. Padding only at the end shifted the
        # keys into unused counter slots whenever dim < 3.
        counters += [0] * (self._NUM_COUNTERS - len(counters))
        keys = self.keys + [0] * (self._NUM_KEYS - len(self.keys))
        parameters = counters + keys

        return philox_four_floats_call.format(parameters=', '.join(str(p) for p in parameters),
                                              result_symbols=self.result_symbols)

    def __repr__(self):
        return "{}, {}, {}, {} <- PhiloxRNG".format(*self.result_symbols)