rng.py 4.3 KB
Newer Older
1
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
2
3
import sympy as sp

4
5
6
7
8
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode


Michael Kuron's avatar
Michael Kuron committed
9
def _get_rng_template(name, data_type, num_vars):
10
11
12
13
14
15
    if data_type is np.float32:
        c_type = "float"
    elif data_type is np.float64:
        c_type = "double"
    template = "\n"
    for i in range(num_vars):
16
        template += f"{{result_symbols[{i}].dtype}} {{result_symbols[{i}].name}};\n"
Michael Kuron's avatar
Michael Kuron committed
17
18
    template += ("{}_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
        .format(name, c_type, num_vars, *tuple(range(num_vars)))
19
    return template
20
21


22
23
def _get_rng_code(template, dialect, vector_instruction_set, time_step, coordinates, keys, dim, result_symbols):
    parameters = [time_step] + coordinates + [0] * (3 - dim) + list(keys)
24
25
26
27
28
29
30
31

    if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
        return template.format(parameters=', '.join(str(p) for p in parameters),
                               result_symbols=result_symbols)
    else:
        raise NotImplementedError("Not yet implemented for this backend")


Michael Kuron's avatar
Michael Kuron committed
32
class RNGBase(CustomCodeNode):
    """Custom code node that emits one call to a counter-based random number generator.

    Concrete subclasses configure the generator through class attributes:

    * ``_name`` -- prefix of the generated C function and of the ``"<name>_rand.h"`` header,
    * ``_data_type`` -- ``np.float32`` or ``np.float64``,
    * ``_num_vars`` -- number of random values produced by one call,
    * ``_num_keys`` -- number of key arguments the generator expects.

    The generated call receives the time step, the (offset) loop counters of the
    first ``dim`` dimensions padded with zeros to three entries, and the keys.  The
    results are written into freshly named variables exposed as ``result_symbols``.
    """

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=(0, 0, 0), keys=None):
        """Create the node.

        Args:
            dim: number of dimensions whose loop counters are passed to the generator.
            time_step: first argument of the generated call.
            offsets: exactly three values; ``offsets[i]`` is added to loop counter ``i``
                for ``i < dim``.
            keys: exactly ``_num_keys`` values; defaults to all zeros.

        Raises:
            ValueError: if the number of keys or offsets is wrong.
        """
        if keys is None:
            keys = (0,) * self._num_keys
        if len(keys) != self._num_keys:
            raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
        if len(offsets) != 3:
            raise ValueError(f"Provided {len(offsets)} offsets but need {3}")
        # A fresh sp.Dummy() name per output variable, so separate nodes do not clash.
        self.result_symbols = tuple(TypedSymbol(sp.Dummy().name, self._data_type) for _ in range(self._num_vars))
        symbols_read = [s for s in keys if isinstance(s, sp.Symbol)]
        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
        self._time_step = time_step
        self._offsets = offsets
        self._coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
        self.headers = [f'"{self._name}_rand.h"']
        self.keys = tuple(keys)
        self._args = sp.sympify((dim, time_step, keys))
        self._dim = dim

    @property
    def args(self):
        """Sympified ``(dim, time_step, keys)`` tuple identifying this node."""
        return self._args

    @property
    def undefined_symbols(self):
        """Symbols read by the generated call that this node does not define itself.

        This covers every ``sp.Symbol`` occurring in the time step, the three offsets
        and the keys.  Symbols nested inside compound expressions (e.g. ``seed + 1``)
        are included.  The loop counters of the first ``dim`` dimensions are added
        as well.
        """
        result = set()
        for value in (self._time_step, *self._offsets, *self.keys):
            # Values may be plain Python ints (no symbols) or sympy expressions.
            # atoms() also finds symbols nested in compound expressions, which are
            # printed into the generated call just like bare symbols.
            if isinstance(value, sp.Basic):
                result.update(value.atoms(sp.Symbol))
        loop_counters = [LoopOverCoordinate.get_loop_counter_symbol(i)
                         for i in range(self._dim)]
        result.update(loop_counters)
        return result

    def subs(self, subs_dict) -> None:
        """Apply ``subs_dict`` in place to the counter (coordinate) expressions only.

        The time step and the keys are not substituted.
        """
        self._coordinates[:] = [coordinate.subs(subs_dict) for coordinate in self._coordinates]

    def fast_subs(self, *_):
        return self  # nothing to replace inside this node - would destroy intermediate "dummy" by re-creating them

    def get_code(self, dialect, vector_instruction_set):
        """Return the C/CUDA source that declares the result variables and calls the generator."""
        template = _get_rng_template(self._name, self._data_type, self._num_vars)
        return _get_rng_code(template, dialect, vector_instruction_set,
                             self._time_step, self._coordinates, self.keys, self._dim, self.result_symbols)

    def __repr__(self):
        # LaTeX-style summary, e.g. "a, b \leftarrow PhiloxRNG".
        return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
                                                                                  self._name.capitalize())
Michael Kuron's avatar
Michael Kuron committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92


class PhiloxTwoDoubles(RNGBase):
    """Philox RNG node: one ``philox_double2`` call yielding two doubles, using two keys."""
    _num_keys = 2
    _num_vars = 2
    _data_type = np.float64
    _name = "philox"


class PhiloxFourFloats(RNGBase):
    """Philox RNG node: one ``philox_float4`` call yielding four floats, using two keys."""
    _num_keys = 2
    _num_vars = 4
    _data_type = np.float32
    _name = "philox"
93
94


Michael Kuron's avatar
Michael Kuron committed
95
96
class AESNITwoDoubles(RNGBase):
    """AES-NI RNG node: one ``aesni_double2`` call yielding two doubles, using four keys."""
    _num_keys = 4
    _num_vars = 2
    _data_type = np.float64
    _name = "aesni"
100

101

Michael Kuron's avatar
Michael Kuron committed
102
103
class AESNIFourFloats(RNGBase):
    """AES-NI RNG node: one ``aesni_float4`` call yielding four floats, using four keys."""
    _num_keys = 4
    _num_vars = 4
    _data_type = np.float32
    _name = "aesni"
Martin Bauer's avatar
Martin Bauer committed
107
108


109
110
def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
    """Yield an endless stream of random-value symbols, creating RNG nodes on demand.

    Nodes are created one at a time.  Each node receives the keys
    ``(counter, seed)``, padded with zeros up to the node's ``_num_keys``.
    ``counter`` starts at 0 and increases by one per node, so every node gets
    distinct keys.  A node is inserted at the front of ``assignment_list`` just
    before its first symbol is yielded.  Only nodes whose symbols are actually
    requested therefore end up in the list.

    Args:
        assignment_list: mutable list the RNG nodes are inserted into (at index 0).
        seed: second key passed to every node.
        rng_node: RNG node class to instantiate, e.g. ``PhiloxTwoDoubles``.
        *args, **kwargs: forwarded to ``rng_node`` (e.g. ``dim``).

    Yields:
        The ``result_symbols`` of the created nodes, one at a time.
    """
    # Generators with more than two keys (e.g. the AES-NI nodes) reject a 2-tuple of
    # keys, so fill the remaining key slots with zeros.  Fall back to 2 keys when
    # rng_node has no _num_keys attribute (e.g. a factory function), which keeps the
    # previous behaviour for such callables.
    num_keys = getattr(rng_node, '_num_keys', 2)
    key_padding = (0,) * (num_keys - 2)
    counter = 0
    while True:
        node = rng_node(*args, keys=(counter, seed) + key_padding, **kwargs)
        for index, symbol in enumerate(node.result_symbols):
            if index == 0:
                # Insert lazily: only once the first value of this node is requested.
                assignment_list.insert(0, node)
            yield symbol
        counter += 1