"""Substitution of proxy-field accesses for advanced streaming patterns (indexing.py)."""
import numpy as np
import sympy as sp
import pystencils as ps

from pystencils.data_types import TypedSymbol, create_type

from lbmpy.fieldaccess import StreamPullTwoFieldsAccessor, \
    StreamPushTwoFieldsAccessor, \
    AAEvenTimeStepAccessor, \
    AAOddTimeStepAccessor, \
    EsoTwistEvenTimeStepAccessor, \
    EsoTwistOddTimeStepAccessor

class AdvancedStreamingIndexing:
    """Translate proxy-field accesses into accesses of a real PDF field.

    Boundary code can be written against the two proxy fields returned by
    :meth:`proxy_fields` (``f_out`` and ``f_in``).

    :meth:`substitute_proxies` replaces those accesses with the accesses of
    the selected streaming pattern (``kernel_type``) on ``pdf_field``.
    Accesses indexed by the symbolic direction are mapped onto lookup arrays.
    The names of those arrays are collected in ``self.required_arrays`` so
    that calling code can provide them.
    """

    #   =======================================
    #       Symbols for usage in boundaries
    #   =======================================

    def proxy_fields(self):
        """Return the proxy fields ``(f_out, f_in)`` for use in boundary definitions.

        Both fields have:

        * ``q`` index entries, where ``q`` is the number of stencil directions;
        * ``d`` spatial dimensions, where ``d`` is the length of a stencil
          direction.
        """
        q = len(self.stencil)
        d = len(self.stencil[0])
        # The trailing 'D' declares `d` as the number of spatial dimensions.
        # A bare '[d]' is parsed by pystencils as a fixed 1D shape of size d.
        return ps.fields(f"f_out({q}), f_in({q}): [{d}D]")

    def dir_symbol(self):
        """Return the int64-typed symbol ``dir`` used as a symbolic direction index."""
        return TypedSymbol('dir', create_type(np.int64))

    def inverse_dir_symbol(self):
        """Return the indexed base ``invdir`` (``invdir[dir]`` denotes the inverse direction)."""
        return sp.IndexedBase('invdir')

    #   =============================
    #       Constructor and State
    #   =============================

    def __init__(self, pdf_field, stencil, between_timesteps='both', kernel_type='pull'):
        """Set up the outward/inward access lists for one streaming pattern.

        Parameters:
            pdf_field: the field that the proxy accesses are mapped onto.
            stencil: sequence of direction vectors. Its length is ``q``, and
                the length of each entry gives the spatial dimension.
            between_timesteps: one of ``'both'``, ``'odd_to_even'`` or
                ``'even_to_odd'``. It selects which accessor provides the
                writes (outward) and which provides the reads (inward).
            kernel_type: one of ``'pull'``, ``'push'``, ``'aa'`` or
                ``'esotwist'``.

        Raises:
            ValueError: if ``between_timesteps`` or ``kernel_type`` is
                invalid, or if ``'both'`` is requested for an in-place
                pattern (``'aa'`` / ``'esotwist'``).
        """
        if between_timesteps not in ['both', 'odd_to_even', 'even_to_odd']:
            raise ValueError(
                "Invalid value of parameter 'between_timesteps'. Must be one of 'both', 'odd_to_even' or 'even_to_odd'.",
                between_timesteps)

        if between_timesteps == 'both' and kernel_type in ['aa', 'esotwist']:
            raise ValueError('Cannot create an offset info for both kinds of timesteps for kernel type ' + kernel_type)

        odd_to_even = (between_timesteps == 'odd_to_even')

        even_accessors = {
            'pull': StreamPullTwoFieldsAccessor,
            'push': StreamPushTwoFieldsAccessor,
            'aa': AAEvenTimeStepAccessor,
            'esotwist': EsoTwistEvenTimeStepAccessor,
        }

        odd_accessors = {
            'pull': StreamPullTwoFieldsAccessor,
            'push': StreamPushTwoFieldsAccessor,
            'aa': AAOddTimeStepAccessor,
            'esotwist': EsoTwistOddTimeStepAccessor,
        }

        try:
            even_accessor = even_accessors[kernel_type]
            odd_accessor = odd_accessors[kernel_type]
        except KeyError:
            raise ValueError(
                "Invalid value of parameter 'kernel_type'.", kernel_type) from None

        if between_timesteps == 'both':
            # Only two-field patterns reach this point; they use the same
            # accessor for both parities.
            assert even_accessor == odd_accessor

        # odd_to_even: writes come from the odd accessor, reads from the even one.
        # even_to_odd (and 'both'): the other way round.
        writing_accessor = odd_accessor if odd_to_even else even_accessor
        reading_accessor = even_accessor if odd_to_even else odd_accessor

        self.outward_accesses = writing_accessor.write(pdf_field, stencil)
        self.inward_accesses = reading_accessor.read(pdf_field, stencil)

        self.pdf_field = pdf_field
        self.stencil = stencil
        self.directions = ['x', 'y', 'z'][:len(stencil[0])]

        #   Collection of translation arrays required in generated code
        self.required_arrays = set()

    def _index_array_symbol(self, f_dir, inverse, dir_symbol):
        """Return ``f_{f_dir}[_inv]_dir_idx[dir_symbol]`` and record the array name."""
        assert f_dir in ['in', 'out']
        inv = '_inv' if inverse else ''
        name = f"f_{f_dir}{inv}_dir_idx"
        self.required_arrays.add(name)
        return sp.IndexedBase(name)[dir_symbol]

    def _offset_array_symbols(self, f_dir, inverse, dir_symbol):
        """Return one ``f_{f_dir}[_inv]_offsets_<axis>[dir_symbol]`` per spatial axis.

        The array names are recorded in ``self.required_arrays``.
        """
        assert f_dir in ['in', 'out']
        inv = '_inv' if inverse else ''
        name_base = f"f_{f_dir}{inv}_offsets_"
        names = [name_base + d for d in self.directions]
        self.required_arrays |= set(names)
        return [sp.IndexedBase(n)[dir_symbol] for n in names]

    #   =================================
    #       Proxy fields substitution
    #   =================================

    def substitute_proxies(self, assignments):
        """Replace proxy-field accesses in ``assignments`` by PDF-field accesses.

        Parameters:
            assignments: a ``ps.AssignmentCollection``, a list/tuple of
                assignments, or a single assignment.

        Returns:
            A new assignment collection with the substitutions applied.

        Handled proxy accesses (all with zero spatial offset):

        * ``f_out(i)`` / ``f_in(i)`` with a concrete integer ``i``: replaced
          by entry ``i`` of the outward / inward access list.
        * ``f_out(dir)``: replaced by ``pdf_field`` accessed through the
          ``f_out_*`` lookup arrays.
        * ``f_in(invdir[dir])``: replaced by ``pdf_field`` accessed through
          the ``f_in_inv_*`` lookup arrays.

        All other proxy accesses, and accesses to other fields, are left
        unchanged.
        """
        # Access lists of the streaming pattern selected in __init__.
        outward_accesses = self.outward_accesses
        inward_accesses = self.inward_accesses

        f_out, f_in = self.proxy_fields()
        dir_symbol = self.dir_symbol()
        inv_dir = self.inverse_dir_symbol()

        if isinstance(assignments, (list, tuple)):
            assignments = ps.AssignmentCollection(list(assignments))
        elif not isinstance(assignments, ps.AssignmentCollection):
            assignments = ps.AssignmentCollection([assignments])

        accessor_subs = dict()

        for fa in assignments.atoms(ps.Field.Access):
            if fa.field != f_out and fa.field != f_in:
                # Accesses to non-proxy fields are left untouched.
                continue

            # Dimension-independent check for a center (zero-offset) access.
            if not all(offset == 0 for offset in fa.offsets):
                # TODO: non-trivial neighbour accesses -> add the neighbour
                #       offset to the streaming pattern offsets.
                continue

            index = fa.index[0]

            if fa.field == f_out:
                if isinstance(index, (int, sp.Integer)):
                    accessor_subs[fa] = outward_accesses[int(index)]
                elif index == dir_symbol:
                    offsets = self._offset_array_symbols('out', False, dir_symbol)
                    idx_symbol = self._index_array_symbol('out', False, dir_symbol)
                    # Field.__getitem__ expects a tuple of spatial offsets.
                    accessor_subs[fa] = self.pdf_field[tuple(offsets)](idx_symbol)
                # TODO: other index expressions (e.g. the inverse direction)
                #       are not handled yet and are left unchanged.
            else:
                if isinstance(index, (int, sp.Integer)):
                    accessor_subs[fa] = inward_accesses[int(index)]
                elif index == inv_dir[dir_symbol]:
                    offsets = self._offset_array_symbols('in', True, dir_symbol)
                    idx_symbol = self._index_array_symbol('in', True, dir_symbol)
                    # Field.__getitem__ expects a tuple of spatial offsets.
                    accessor_subs[fa] = self.pdf_field[tuple(offsets)](idx_symbol)
                # TODO: other index expressions are not handled yet and are
                #       left unchanged.

        return assignments.new_with_substitutions(accessor_subs)

#   end class AdvancedStreamingIndexing