indexing.py 8.1 KB
Newer Older
1
2
3
4
5
6
7
import numpy as np
import sympy as sp
import pystencils as ps

from pystencils.data_types import TypedSymbol, create_type
from pystencils.backends.cbackend import CustomCodeNode

Frederik Hennig's avatar
Frederik Hennig committed
8
from lbmpy.advanced_streaming.utility import get_accessor, inverse_dir_index, is_inplace, Timestep
9

10
11
from itertools import product

12
13
def _array_pattern(dtype, name, content):
    return f"const {str(dtype)} {name} [] = {{ {','.join(str(c) for c in content)} }}; \n"
14

15
class BetweenTimestepsIndexing:
    """Translate PDF accesses between two consecutive time steps of a streaming pattern.

    Kernels are formulated against two proxy fields (see :attr:`proxy_fields`):
    ``f_out``, whose accesses are mapped to the accesses written by the accessor of
    ``prev_timestep``, and ``f_in``, whose accesses are mapped to the accesses read by
    the accessor of ``prev_timestep.next()``. :meth:`substitute_proxies` performs
    this replacement on a set of assignments.

    Accesses with a constant direction are translated directly. Accesses with a
    symbolic direction are translated through constant lookup arrays, which are
    declared by the code node returned from :meth:`create_code_node`.
    """

    #   ==============================================
    #       Symbols for usage in kernel definitions
    #   ==============================================

    @property
    def proxy_fields(self):
        """The proxy fields ``(f_out, f_in)``, each with ``q`` entries, of the stencil's dimension."""
        return ps.fields(f"f_out({self._q}), f_in({self._q}): [{self._dim}D]")

    @property
    def dir_symbol(self):
        """Typed symbol ``dir`` (of the index data type) usable as a symbolic direction index."""
        return TypedSymbol('dir', create_type(self._index_dtype))

    @property
    def inverse_dir_symbol(self):
        """Indexed base ``invdir``; a proxy access ``f[invdir[d]]`` refers to the inverse of direction ``d``."""
        return sp.IndexedBase('invdir')

    #   =============================
    #       Constructor and State
    #   =============================

    def __init__(self, pdf_field, stencil, prev_timestep=Timestep.BOTH, streaming_pattern='pull',
                 index_dtype=np.int32, offsets_dtype=np.int32):
        """
        Args:
            pdf_field: the PDF field that proxy accesses are translated to
            stencil: sequence of lattice direction vectors; its first entry fixes the dimension
            prev_timestep: the time step preceding the kernel
            streaming_pattern: name of the streaming pattern, e.g. ``'pull'``
            index_dtype: data type of direction indices and of the index lookup arrays
            offsets_dtype: data type of the offset lookup arrays

        Raises:
            ValueError: if ``prev_timestep`` is ``Timestep.BOTH`` and the streaming
                pattern is in-place.
        """
        if prev_timestep == Timestep.BOTH and is_inplace(streaming_pattern):
            raise ValueError('Cannot create index arrays for both kinds of timesteps for inplace streaming pattern '
                             + streaming_pattern)

        prev_accessor = get_accessor(streaming_pattern, prev_timestep)
        next_accessor = get_accessor(streaming_pattern, prev_timestep.next())

        outward_accesses = prev_accessor.write(pdf_field, stencil)
        inward_accesses = next_accessor.read(pdf_field, stencil)

        self._accesses = {'out': outward_accesses, 'in': inward_accesses}

        self._pdf_field = pdf_field
        self._stencil = stencil
        self._dim = len(stencil[0])
        self._q = len(stencil)
        self._coordinate_names = ['x', 'y', 'z'][:self._dim]

        self._index_dtype = create_type(index_dtype)
        self._offsets_dtype = create_type(offsets_dtype)

        # (f_dir, inverse) pairs whose lookup arrays are referenced by substituted code;
        # filled by _array_symbols and consumed by TranslationArraysNode.
        self._required_arrays = set()
        self._trivial_translations = self._collect_trivial_translations()

    def _index_array_symbol(self, f_dir, inverse):
        """Symbol of the direction-index lookup array for proxy ``f_dir`` ('in' or 'out')."""
        assert f_dir in ['in', 'out']
        inv = '_inv' if inverse else ''
        name = f"f_{f_dir}{inv}_dir_idx"
        return TypedSymbol(name, self._index_dtype)

    def _offset_array_symbols(self, f_dir, inverse):
        """Symbols of the per-coordinate offset lookup arrays for proxy ``f_dir`` ('in' or 'out')."""
        assert f_dir in ['in', 'out']
        inv = '_inv' if inverse else ''
        name_base = f"f_{f_dir}{inv}_offsets_"
        # TranslationArraysNode declares these arrays with the offsets data type,
        # so the symbols referring to them carry that same type.
        return [TypedSymbol(name_base + d, self._offsets_dtype) for d in self._coordinate_names]

    def _array_symbols(self, f_dir, inverse, index):
        """Return the translated direction index and offsets for a symbolic ``index``.

        Returns:
            dict with ``'index'`` (translated direction index expression) and
            ``'offsets'`` (tuple of per-coordinate offset expressions)

        For trivial translations ``index`` and zero offsets are returned unchanged;
        otherwise lookup-array entries are returned and the arrays are registered as
        required so that :meth:`create_code_node` declares them.
        """
        if (f_dir, inverse) in self._trivial_translations:
            return {'index': index, 'offsets': (0, ) * self._dim}

        index_array_symbol = self._index_array_symbol(f_dir, inverse)
        offset_array_symbols = self._offset_array_symbols(f_dir, inverse)
        self._required_arrays.add((f_dir, inverse))
        return {
            'index': sp.IndexedBase(index_array_symbol, shape=(1,))[index],
            'offsets': tuple(sp.IndexedBase(s, shape=(1,))[index]
                             for s in offset_array_symbols)
        }

    #   =================================
    #       Proxy fields substitution
    #   =================================

    def substitute_proxies(self, assignments):
        """Replace all accesses to the proxy fields by accesses to the PDF field.

        Args:
            assignments: a single ``ps.Assignment``, a sequence of assignments, or an
                ``AssignmentCollection``

        Returns:
            a new ``AssignmentCollection`` with the proxy accesses substituted

        A direction of the form ``invdir[d]`` selects the inverse of direction ``d``.
        Constant directions are translated immediately; symbolic directions are
        translated via lookup arrays (see :meth:`create_code_node`).
        """
        if isinstance(assignments, ps.Assignment):
            assignments = [assignments]

        if not isinstance(assignments, ps.AssignmentCollection):
            assignments = ps.AssignmentCollection(assignments)

        accesses = self._accesses
        f_out, f_in = self.proxy_fields
        inv_dir = self.inverse_dir_symbol

        accessor_subs = dict()

        for fa in assignments.atoms(ps.Field.Access):
            if fa.field == f_out:
                f_dir = 'out'
            elif fa.field == f_in:
                f_dir = 'in'
            else:
                continue

            inv = False
            idx = fa.index[0]
            if isinstance(idx, sp.Indexed) and idx.base == inv_dir:
                # f[invdir[d]]: strip the wrapper and remember the inversion.
                idx = idx.indices[0]
                if isinstance(sp.sympify(idx), sp.Integer):
                    # A constant direction can be inverted right away.
                    idx = inverse_dir_index(self._stencil, idx)
                inv = True

            if isinstance(sp.sympify(idx), sp.Integer):
                accessor_subs[fa] = accesses[f_dir][idx].get_shifted(*(fa.offsets))
            else:
                arr = self._array_symbols(f_dir, inv, idx)
                translated = self._pdf_field[arr['offsets']](arr['index'])
                accessor_subs[fa] = translated.get_shifted(*(fa.offsets))

        return assignments.new_with_substitutions(accessor_subs)

    #   =================
    #       Internals
    #   =================

    def _get_translated_indices_and_offsets(self, f_dir, inv):
        """Return ``(indices, offsets)`` describing the translation for proxy ``f_dir``.

        ``indices[i]`` is the PDF index and ``offsets[d][i]`` the offset along
        coordinate ``d`` accessed for direction ``i`` — or, if ``inv`` is set, for
        the inverse of direction ``i``.
        """
        accesses = self._accesses[f_dir]

        if inv:
            inverse_indices = [inverse_dir_index(self._stencil, i)
                               for i in range(len(self._stencil))]
            accesses = [accesses[idx] for idx in inverse_indices]

        indices = [a.index[0] for a in accesses]
        offsets = [[a.offsets[d] for a in accesses] for d in range(self._dim)]
        return indices, offsets

    def _collect_trivial_translations(self):
        """Return the set of ``(f_dir, inverse)`` pairs whose translation is the identity.

        A translation is trivial if every direction ``i`` maps to PDF index ``i`` with
        zero offsets; such accesses need no lookup arrays.
        """
        trivials = set()
        trivial_indices = list(range(self._q))
        # Independent rows (no aliasing of a single inner list).
        trivial_offsets = [[0] * self._q for _ in range(self._dim)]
        for f_dir, inv in product(['in', 'out'], [False, True]):
            indices, offsets = self._get_translated_indices_and_offsets(f_dir, inv)
            if indices == trivial_indices and offsets == trivial_offsets:
                trivials.add((f_dir, inv))
        return trivials

    def create_code_node(self):
        """Return a code node declaring all lookup arrays required by substituted code so far."""
        return BetweenTimestepsIndexing.TranslationArraysNode(self)

    class TranslationArraysNode(CustomCodeNode):
        """Code node declaring the lookup arrays registered in ``indexing._required_arrays``."""

        def __init__(self, indexing):
            code = ''
            symbols_defined = set()

            # Sorted so that the generated code is independent of set iteration
            # order, which varies between interpreter runs (string hash randomization).
            for f_dir, inv in sorted(indexing._required_arrays):
                indices, offsets = indexing._get_translated_indices_and_offsets(f_dir, inv)

                index_array_symbol = indexing._index_array_symbol(f_dir, inv)
                symbols_defined.add(index_array_symbol)
                code += _array_pattern(indexing._index_dtype, index_array_symbol.name, indices)

                offset_array_symbols = indexing._offset_array_symbols(f_dir, inv)
                symbols_defined |= set(offset_array_symbols)
                for arrsymb, coordinate_offsets in zip(offset_array_symbols, offsets):
                    code += _array_pattern(indexing._offsets_dtype, arrsymb.name, coordinate_offsets)

            super(BetweenTimestepsIndexing.TranslationArraysNode, self).__init__(
                code, symbols_read=set(), symbols_defined=symbols_defined)

        def __str__(self):
            return "Variable PDF Access Translation Arrays"

        def __repr__(self):
            return "Variable PDF Access Translation Arrays"

#   end class BetweenTimestepsIndexing


class NeighbourOffsetArraysForStencil(CustomCodeNode):
    """Code node declaring constant arrays holding a stencil's neighbour offsets.

    One array ``neighbour_offset_<c>`` is declared per coordinate ``c`` (x, y, z, up
    to the stencil's dimension); its i-th entry is coordinate ``c`` of the i-th
    stencil direction.
    """

    @staticmethod
    def symbolic_neighbour_offset_from_dir(dir_idx, dim):
        """Return the tuple of offset-array entries at position ``dir_idx`` for ``dim`` coordinates."""
        array_symbols = NeighbourOffsetArraysForStencil._offset_symbols(dim)
        return tuple(sp.IndexedBase(array_symbol, shape=(1,))[dir_idx] for array_symbol in array_symbols)

    @staticmethod
    def _offset_symbols(dim):
        """Typed symbols naming the offset arrays of the first ``dim`` coordinates."""
        # NOTE(review): these symbols are typed int64, while __init__ declares the
        # arrays with ``offsets_dtype`` (int32 by default). symbols_defined must match
        # the symbols built by symbolic_neighbour_offset_from_dir, so both would have
        # to change together — confirm whether the mismatch is intended.
        coordinates = ['x', 'y', 'z'][:dim]
        return [TypedSymbol(f"neighbour_offset_{c}", create_type(np.int64)) for c in coordinates]

    def __init__(self, stencil, offsets_dtype=np.int32):
        """
        Args:
            stencil: sequence of direction vectors; its first entry fixes the dimension
            offsets_dtype: element data type used in the emitted array declarations
        """
        element_type = create_type(offsets_dtype)
        dim = len(stencil[0])
        offset_symbols = NeighbourOffsetArraysForStencil._offset_symbols(dim)

        declarations = []
        for coord, array_symbol in enumerate(offset_symbols):
            column = (direction[coord] for direction in stencil)
            declarations.append(_array_pattern(element_type, array_symbol.name, column))
        code = "\n" + "".join(declarations)

        super().__init__(code, symbols_read=set(), symbols_defined=set(offset_symbols))