rng.py 3.51 KB
Newer Older
1
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
2
3
import sympy as sp

4
5
6
7
8
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode


9
10
11
12
13
14
15
16
17
18
19
def _get_philox_template(data_type, num_vars):
    if data_type is np.float32:
        c_type = "float"
    elif data_type is np.float64:
        c_type = "double"
    template = "\n"
    for i in range(num_vars):
        template += "{{result_symbols[{}].dtype}} {{result_symbols[{}].name}};\n".format(i, i)
    template += ("philox_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
        .format(c_type, num_vars, *tuple(range(num_vars)))
    return template
20
21


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def _get_philox_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
    parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
                                for i in range(dim)] + list(keys)

    while len(parameters) < 6:
        parameters.append(0)
    parameters = parameters[:6]

    assert len(parameters) == 6

    if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
        return template.format(parameters=', '.join(str(p) for p in parameters),
                               result_symbols=result_symbols)
    else:
        raise NotImplementedError("Not yet implemented for this backend")


39
class PhiloxBase(CustomCodeNode):
    """Custom code node that draws ``_num_vars`` values from a philox RNG call.

    The generated code declares one variable per entry of ``result_symbols`` and
    fills them with a single ``philox_<type><n>`` call (header ``philox_rand.h``).
    Concrete subclasses provide the class attributes ``_data_type`` (numpy float
    type of the results) and ``_num_vars`` (number of results per call).
    """

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=(0, 0)):
        """
        Args:
            dim: number of spatial dimensions whose loop counters enter the RNG call.
            time_step: value or symbol passed as the first RNG argument.
            offsets: per-dimension offsets added to the loop counters.
            keys: RNG key values; entries that are sympy symbols are passed to the
                base class as read symbols.
        """
        # Materialize once: both are iterated several times below, and a one-shot
        # iterable would otherwise be empty after the first pass.
        keys = tuple(keys)
        offsets = tuple(offsets)

        # fresh, uniquely named output symbols for this node
        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, self._data_type) for _ in range(self._num_vars))
        symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
        self._time_step = time_step
        self._offsets = offsets
        self.headers = ['"philox_rand.h"']
        self.keys = keys
        self._args = sp.sympify((dim, time_step, keys))
        self._dim = dim

    @property
    def args(self):
        return self._args

    @property
    def undefined_symbols(self):
        """Symbols needed from outside: symbolic time step/offsets/keys plus the first ``dim`` loop counters."""
        result = {a for a in (self._time_step, *self._offsets, *self.keys) if isinstance(a, sp.Symbol)}
        loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                         for i in range(self._dim)]
        result.update(loop_counters)
        return result

    def fast_subs(self, _):
        return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them

    def get_code(self, dialect, vector_instruction_set):
        """Return the source code of the RNG call for ``dialect``."""
        template = _get_philox_template(self._data_type, self._num_vars)
        return _get_philox_code(template, dialect, vector_instruction_set,
                                self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)

    def __repr__(self):
        return ", ".join("{}".format(s) for s in self.result_symbols) + " <- PhiloxRNG"
74
75


76
77
78
class PhiloxTwoDoubles(PhiloxBase):
    """Philox RNG node producing two ``np.float64`` (C ``double``) values per call."""
    _num_vars = 2
    _data_type = np.float64
79

80

81
82
83
class PhiloxFourFloats(PhiloxBase):
    """Philox RNG node producing four ``np.float32`` (C ``float``) values per call."""
    _num_vars = 4
    _data_type = np.float32
Martin Bauer's avatar
Martin Bauer committed
84
85


86
87
def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
    counter = 0
Martin Bauer's avatar
Martin Bauer committed
88
    while True:
89
        node = rng_node(*args, keys=(counter, seed), **kwargs)
Martin Bauer's avatar
Martin Bauer committed
90
91
92
93
94
95
        inserted = False
        for symbol in node.result_symbols:
            if not inserted:
                assignment_list.insert(0, node)
                inserted = True
            yield symbol