rng.py 3.58 KB
Newer Older
1
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
2
3
import sympy as sp

4
5
6
7
8
from pystencils import TypedSymbol
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode


Michael Kuron's avatar
Michael Kuron committed
9
def _get_rng_template(name, data_type, num_vars):
10
11
12
13
14
15
    if data_type is np.float32:
        c_type = "float"
    elif data_type is np.float64:
        c_type = "double"
    template = "\n"
    for i in range(num_vars):
16
        template += f"{{result_symbols[{i}].dtype}} {{result_symbols[{i}].name}};\n"
Michael Kuron's avatar
Michael Kuron committed
17
18
    template += ("{}_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \
        .format(name, c_type, num_vars, *tuple(range(num_vars)))
19
    return template
20
21


22
def _get_rng_code(template, dialect, vector_instruction_set, args, result_symbols):
23
    if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
24
        return template.format(parameters=', '.join(str(a) for a in args),
25
26
27
28
29
                               result_symbols=result_symbols)
    else:
        raise NotImplementedError("Not yet implemented for this backend")


Michael Kuron's avatar
Michael Kuron committed
30
class RNGBase(CustomCodeNode):
31

32
33
34
    id = 0

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
Michael Kuron's avatar
Michael Kuron committed
35
36
        if keys is None:
            keys = (0,) * self._num_keys
37
38
        if offsets is None:
            offsets = (0,) * dim
Michael Kuron's avatar
Michael Kuron committed
39
        if len(keys) != self._num_keys:
40
            raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
41
42
43
44
45
46
47
48
49
50
        if len(offsets) != dim:
            raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")
        coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
        if dim < 3:
            coordinates.append(0)

        self._args = sp.sympify([time_step, *coordinates, *keys])
        self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
                                    for i in range(self._num_vars))
        symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
51
        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
52

53
        self.headers = [f'"{self._name}_rand.h"']
54
55

        RNGBase.id += 1
56
57
58
59
60
61

    @property
    def args(self):
        return self._args

    def get_code(self, dialect, vector_instruction_set):
Michael Kuron's avatar
Michael Kuron committed
62
        template = _get_rng_template(self._name, self._data_type, self._num_vars)
63
        return _get_rng_code(template, dialect, vector_instruction_set, self.args, self.result_symbols)
64
65

    def __repr__(self):
Michael Kuron's avatar
Michael Kuron committed
66
67
        return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
                                                                                  self._name.capitalize())
Michael Kuron's avatar
Michael Kuron committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81


class PhiloxTwoDoubles(RNGBase):
    _name = "philox"
    _data_type = np.float64
    _num_vars = 2
    _num_keys = 2


class PhiloxFourFloats(RNGBase):
    _name = "philox"
    _data_type = np.float32
    _num_vars = 4
    _num_keys = 2
82
83


Michael Kuron's avatar
Michael Kuron committed
84
85
class AESNITwoDoubles(RNGBase):
    _name = "aesni"
86
87
    _data_type = np.float64
    _num_vars = 2
Michael Kuron's avatar
Michael Kuron committed
88
    _num_keys = 4
89

90

Michael Kuron's avatar
Michael Kuron committed
91
92
class AESNIFourFloats(RNGBase):
    _name = "aesni"
93
94
    _data_type = np.float32
    _num_vars = 4
Michael Kuron's avatar
Michael Kuron committed
95
    _num_keys = 4
Martin Bauer's avatar
Martin Bauer committed
96
97


98
99
def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs):
    counter = 0
Martin Bauer's avatar
Martin Bauer committed
100
    while True:
101
        node = rng_node(*args, keys=(counter, seed), **kwargs)
Martin Bauer's avatar
Martin Bauer committed
102
103
104
105
106
107
        inserted = False
        for symbol in node.result_symbols:
            if not inserted:
                assignment_list.insert(0, node)
                inserted = True
            yield symbol
108
        counter += 1