indexing.py 7.2 KB
Newer Older
1
2
3
4
5
6
7
import numpy as np
import sympy as sp
import pystencils as ps

from pystencils.data_types import TypedSymbol, create_type
from pystencils.backends.cbackend import CustomCodeNode

Frederik Hennig's avatar
Frederik Hennig committed
8
from lbmpy.advanced_streaming.utility import get_accessor, inverse_dir_index, is_inplace, Timestep
9

10
11
12
from itertools import product


13
class BetweenTimestepsIndexing:
14

15
16
17
    #   ==============================================
    #       Symbols for usage in kernel definitions
    #   ==============================================
18

Frederik Hennig's avatar
Frederik Hennig committed
19
    @property
    def proxy_fields(self):
        """Create the proxy fields ``f_out`` and ``f_in`` for use in kernel definitions.

        Both fields are described with ``q`` entries (the stencil size) and the
        spatial dimension of the stencil. Accesses to them are replaced by
        accesses to the actual PDF field in :meth:`substitute_proxies`.

        Returns:
            The result of ``ps.fields`` for the description string; it is
            unpacked as ``f_out, f_in`` elsewhere in this class.
        """
        q, dim = self._q, self._dim
        description = f"f_out({q}), f_in({q}): [{dim}D]"
        return ps.fields(description)
22

Frederik Hennig's avatar
Frederik Hennig committed
23
    @property
    def dir_symbol(self):
        """Return the typed symbol ``dir``, typed with the index data type."""
        symbol_type = create_type(self._index_dtype)
        return TypedSymbol('dir', symbol_type)
26

Frederik Hennig's avatar
Frederik Hennig committed
27
    @property
    def inverse_dir_symbol(self):
        """Return the indexed base ``invdir``.

        :meth:`substitute_proxies` treats a proxy field index of the form
        ``invdir[i]`` as a request for the inverse of direction ``i``.
        """
        inverse_base = sp.IndexedBase('invdir')
        return inverse_base

    #   =============================
    #       Constructor and State
    #   =============================

Frederik Hennig's avatar
Frederik Hennig committed
35
    def __init__(self, pdf_field, stencil, prev_timestep=Timestep.BOTH, streaming_pattern='pull',
                 index_dtype=np.int32, offsets_dtype=np.int32):
        """Set up the access translation between two consecutive time steps.

        Args:
            pdf_field: PDF field handed to the accessors' ``write``/``read``
                methods; proxy field accesses are redirected to it.
            stencil: Sequence of directions. ``len(stencil)`` is used as ``q``
                and ``len(stencil[0])`` as the spatial dimension.
            prev_timestep: Time step whose accessor supplies the outward
                (write) accesses. The accessor of ``prev_timestep.next()``
                supplies the inward (read) accesses.
            streaming_pattern: Streaming pattern name passed to ``get_accessor``.
            index_dtype: Type specification for the index translation arrays.
            offsets_dtype: Type specification for the offset translation arrays.

        Raises:
            ValueError: If ``prev_timestep`` is ``Timestep.BOTH`` while the
                streaming pattern is in-place.
        """
        if prev_timestep == Timestep.BOTH:
            if is_inplace(streaming_pattern):
                raise ValueError('Cannot create index arrays for both kinds of timesteps for inplace streaming pattern '
                                 + streaming_pattern)

        # PDFs leave a cell through the previous step's write accesses and
        # enter it through the following step's read accesses.
        writing_accessor = get_accessor(streaming_pattern, prev_timestep)
        reading_accessor = get_accessor(streaming_pattern, prev_timestep.next())
        self._accesses = {
            'out': writing_accessor.write(pdf_field, stencil),
            'in': reading_accessor.read(pdf_field, stencil),
        }

        self._pdf_field = pdf_field
        self._stencil = stencil
        self._q = len(stencil)
        self._dim = len(stencil[0])
        self._coordinate_names = ['x', 'y', 'z'][:self._dim]

        self._index_dtype = create_type(index_dtype)
        self._offsets_dtype = create_type(offsets_dtype)

        # Filled lazily by `_array_symbols` with every (direction, inverse)
        # pair that actually needs a translation array.
        self._required_arrays = set()
        # Must come last: relies on the accesses, stencil and sizes set above.
        self._trivial_translations = self._collect_trivial_translations()
60

61
    def _index_array_symbol(self, f_dir, inverse):
        """Return the typed symbol naming the index translation array.

        Args:
            f_dir: Either ``'in'`` or ``'out'``.
            inverse: If truthy, the name refers to the inverse-direction array.

        Returns:
            A ``TypedSymbol`` named ``f_<f_dir>[_inv]_dir_idx`` with the index
            data type.
        """
        assert f_dir in ('in', 'out')
        suffix = '_inv' if inverse else ''
        return TypedSymbol(f"f_{f_dir}{suffix}_dir_idx", self._index_dtype)
66

67
    def _offset_array_symbols(self, f_dir, inverse):
        """Return the typed symbols naming the offset translation arrays.

        One symbol is created per spatial coordinate (``x``, ``y``, ``z``
        truncated to the stencil's dimension).

        Args:
            f_dir: Either ``'in'`` or ``'out'``.
            inverse: If truthy, the names refer to the inverse-direction arrays.

        Returns:
            A list of ``TypedSymbol`` objects named
            ``f_<f_dir>[_inv]_offsets_<coordinate>``.
        """
        assert f_dir in ['in', 'out']
        inv = '_inv' if inverse else ''
        name_base = f"f_{f_dir}{inv}_offsets_"
        # Use the offsets dtype (not the index dtype): TranslationArraysNode
        # declares these very arrays with `_offsets_dtype`, so the symbol type
        # has to agree with the emitted declaration when the two dtypes differ.
        return [TypedSymbol(name_base + d, self._offsets_dtype) for d in self._coordinate_names]
73
74

    def _array_symbols(self, f_dir, inverse, index):
        """Build the index and offset expressions for a variable direction index.

        For a trivial translation the index is passed through unchanged with
        all-zero offsets. Otherwise the translation arrays are subscripted with
        ``index`` and the pair ``(f_dir, inverse)`` is recorded as required.

        Args:
            f_dir: Either ``'in'`` or ``'out'``.
            inverse: Whether the inverse direction is requested.
            index: Expression used to subscript the translation arrays.

        Returns:
            A dict with the keys ``'index'`` and ``'offsets'`` (a tuple with
            one entry per spatial dimension).
        """
        key = (f_dir, inverse)
        if key in self._trivial_translations:
            return {'index': index, 'offsets': (0, ) * self._dim}

        index_array = self._index_array_symbol(f_dir, inverse)
        offset_arrays = self._offset_array_symbols(f_dir, inverse)
        # Remember that the code node has to define the arrays for this pair.
        self._required_arrays.add(key)

        def lookup(array_symbol):
            return sp.IndexedBase(array_symbol, shape=(1,))[index]

        return {
            'index': lookup(index_array),
            'offsets': tuple(lookup(symbol) for symbol in offset_arrays),
        }

    #   =================================
    #       Proxy fields substitution
    #   =================================

    def substitute_proxies(self, assignments):
        """Replace accesses to the proxy fields by accesses to the PDF field.

        Args:
            assignments: A single ``ps.Assignment``, a
                ``ps.AssignmentCollection``, or anything accepted by the
                ``ps.AssignmentCollection`` constructor.

        Returns:
            The result of ``new_with_substitutions`` on the assignment
            collection, with every ``f_out`` / ``f_in`` access substituted.
        """
        if isinstance(assignments, ps.Assignment):
            collection = ps.AssignmentCollection([assignments])
        elif isinstance(assignments, ps.AssignmentCollection):
            collection = assignments
        else:
            collection = ps.AssignmentCollection(assignments)

        out_proxy, in_proxy = self.proxy_fields
        invdir_base = self.inverse_dir_symbol

        def is_constant(expr):
            return isinstance(sp.sympify(expr), sp.Integer)

        replacements = {}
        for access in collection.atoms(ps.Field.Access):
            if access.field == out_proxy:
                direction = 'out'
            elif access.field == in_proxy:
                direction = 'in'
            else:
                # Not one of the proxy fields: leave the access untouched.
                continue

            pdf_index = access.index[0]
            inverted = bool(isinstance(pdf_index, sp.Indexed) and pdf_index.base == invdir_base)
            if inverted:
                # Unwrap `invdir[i]`; a constant `i` is inverted right away.
                pdf_index = pdf_index.indices[0]
                if is_constant(pdf_index):
                    pdf_index = inverse_dir_index(self._stencil, pdf_index)

            if is_constant(pdf_index):
                # Constant direction: use the precomputed access directly.
                target = self._accesses[direction][pdf_index]
            else:
                # Variable direction: go through the translation arrays.
                arrays = self._array_symbols(direction, inverted, pdf_index)
                target = self._pdf_field[arrays['offsets']](arrays['index'])

            replacements[access] = target.get_shifted(*(access.offsets))

        return collection.new_with_substitutions(replacements)

    #   =================
    #       Internals
    #   =================

    def _get_translated_indices_and_offsets(self, f_dir, inv):
        """Collect the translated index and the offsets of every direction.

        Args:
            f_dir: Key into the stored accesses, ``'in'`` or ``'out'``.
            inv: If truthy, entry ``i`` describes the access of the direction
                inverse to ``i``.

        Returns:
            A tuple ``(indices, offsets)``: ``indices`` holds the first index
            entry of each access, ``offsets`` holds one list per spatial
            dimension with that dimension's offset of each access.
        """
        field_accesses = self._accesses[f_dir]
        if inv:
            field_accesses = [field_accesses[inverse_dir_index(self._stencil, direction)]
                              for direction in range(len(self._stencil))]

        translated_indices = [access.index[0] for access in field_accesses]
        translated_offsets = [[access.offsets[axis] for access in field_accesses]
                              for axis in range(self._dim)]
        return translated_indices, translated_offsets

    def _collect_trivial_translations(self):
        """Find the ``(f_dir, inverse)`` pairs that need no translation arrays.

        A translation is trivial when every direction maps onto its own index
        and all offsets are zero.

        Returns:
            A set of ``(f_dir, inverse)`` tuples.
        """
        identity_indices = list(range(self._q))
        zero_offsets = [[0] * self._q for _ in range(self._dim)]

        trivial_pairs = set()
        for key in product(('in', 'out'), (False, True)):
            indices, offsets = self._get_translated_indices_and_offsets(*key)
            is_identity = indices == identity_indices and offsets == zero_offsets
            if is_identity:
                trivial_pairs.add(key)
        return trivial_pairs

156
    def create_code_node(self):
        """Create the code node defining the translation arrays required so far."""
        node_class = BetweenTimestepsIndexing.TranslationArraysNode
        return node_class(self)
158
159
160
161

    class TranslationArraysNode(CustomCodeNode):
        """Custom code node that defines the constant translation arrays."""

        def __init__(self, indexing):
            """Generate the array definitions for all required translations.

            Args:
                indexing: The ``BetweenTimestepsIndexing`` instance whose
                    ``_required_arrays`` determine which arrays are emitted.
            """
            declarations = []
            symbols_defined = set()

            # `_required_arrays` is a set of (str, bool) tuples; its iteration
            # order depends on string hashing and may differ between runs.
            # Sorting makes the generated code reproducible.
            for f_dir, inv in sorted(indexing._required_arrays):
                indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv)

                index_array_symbol = indexing._index_array_symbol(f_dir, inv)
                symbols_defined.add(index_array_symbol)
                acc_indices = ", ".join(str(i) for i in indices)
                declarations.append(
                    self._array_pattern(indexing._index_dtype, index_array_symbol.name, acc_indices))

                # One offset array per spatial dimension.
                offset_array_symbols = indexing._offset_array_symbols(f_dir, inv)
                symbols_defined.update(offset_array_symbols)
                for d, arrsymb in enumerate(offset_array_symbols):
                    acc_offsets = ", ".join(str(o) for o in offsets[d])
                    declarations.append(
                        self._array_pattern(indexing._offsets_dtype, arrsymb.name, acc_offsets))

            super().__init__("".join(declarations), symbols_read=set(), symbols_defined=symbols_defined)

        def _array_pattern(self, dtype, name, content):
            """Format the definition of one constant array.

            Args:
                dtype: Element type; converted with ``str``.
                name: Name of the array.
                content: Comma-separated initialiser values.

            Returns:
                A single line of code terminated by a newline.
            """
            return f"const {str(dtype)} {name} [] = {{ {content} }}; \n"

        def __str__(self):
            return "Variable PDF Access Translation Arrays"

        def __repr__(self):
            return "Variable PDF Access Translation Arrays"
190

191
#   end class BetweenTimestepsIndexing