transformations.py 52 KB
Newer Older
1
import warnings
2
from collections import defaultdict, OrderedDict, namedtuple
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
from types import MappingProxyType
5
6
import pickle
import hashlib
7
8
import sympy as sp
from sympy.logic.boolalg import Boolean
9
from pystencils.simp.assignment_collection import AssignmentCollection
10
from pystencils.assignment import Assignment
11
from pystencils.field import AbstractField, FieldType, Field
12
13
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
    cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
14
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
15
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
16
import pystencils.astnodes as ast
17
18


19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before, is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope; the outermost scope can never be popped

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # symbols that were read while no visible scope defined them
        self.free_parameters = set()
        # stack of scopes, outermost first; always contains at least one scope
        self._defined = [set()]

    def access_symbol(self, symbol):
        """Record a read of ``symbol``; it becomes a free parameter if no visible scope defines it."""
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define ``symbol`` in the innermost scope."""
        self._defined[-1].add(symbol)

    def is_defined(self, symbol):
        """Return True if ``symbol`` is defined in any currently open scope."""
        return any(symbol in scopes for scopes in self._defined)

    def is_defined_locally(self, symbol):
        """Return True if ``symbol`` is defined in the innermost scope."""
        return symbol in self._defined[-1]

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Close the innermost scope, forgetting every symbol defined in it.

        Raises:
            IndexError: if only the outermost scope is left. The check happens before anything is removed,
                        so the object stays usable after the error.
        """
        if self.depth <= 1:
            raise IndexError("Cannot pop the outermost scope of NestedScopes")
        self._defined.pop()

    @property
    def depth(self):
        """Number of currently open scopes (at least 1)."""
        return len(self._defined)


Martin Bauer's avatar
Martin Bauer committed
73
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yield all (transitive) children of ``node`` that are instances of ``node_type``.

    The tree is traversed depth-first through the ``args`` attribute of every node. Matching children are
    yielded and their own children are searched as well.

    Args:
        node: root of the traversal (the root itself is never yielded)
        node_type: type (or tuple of types) of the nodes to yield
        stop_type: optional type (or tuple of types); children of this type that do not match ``node_type``
                   are not descended into, at any depth of the tree
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # do not search below a stop node
            continue

        # forward stop_type so that it is honoured at every level, not only directly below the root
        yield from filtered_tree_iteration(arg, node_type, stop_type)
81
82


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def generic_visit(term, visitor):
    """Apply ``visitor`` to the leaf expressions of ``term`` and rebuild the surrounding container.

    Containers are handled structurally:
      - AssignmentCollection: main assignments and subexpressions are visited, and a copy is returned
      - list: every element is visited
      - Assignment: only the right-hand side is visited; the left-hand side is kept
      - sympy Matrix: every entry is visited
    Anything else is passed directly to ``visitor`` and its result is returned.
    """
    def recurse(sub_term):
        return generic_visit(sub_term, visitor)

    if isinstance(term, AssignmentCollection):
        return term.copy(recurse(term.main_assignments), recurse(term.subexpressions))
    if isinstance(term, list):
        return list(map(recurse, term))
    if isinstance(term, Assignment):
        return Assignment(term.lhs, recurse(term.rhs))
    if isinstance(term, sp.Matrix):
        return term.applyfunc(recurse)
    return visitor(term)


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
def unify_shape_symbols(body, common_shape, fields):
    """Replace per-field array-size symbols by the symbols of one common shape.

    When a kernel is created for variable array sizes, all passed arrays must have the same size, which is
    checked when the kernel is called. Inside the kernel this means only one symbol per dimension is needed
    instead of one per field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should be
    represented by the same symbol.

    Args:
        body: ast node for the kernel part where the substitution is made; ``body.subs`` is called on it
              (the ast node is expected to be modified in-place)
        common_shape: shape of the field that was chosen
        fields: all fields whose shapes should be replaced by common_shape
    """
    substitutions = {}
    for field in fields:
        own_shape = field.spatial_shape
        assert len(own_shape) == len(common_shape)
        if field.has_fixed_shape:
            # fixed sizes are plain numbers, there is nothing to unify
            continue
        substitutions.update((own, common)
                             for common, own in zip(common_shape, own_shape)
                             if own != common)

    if substitutions:
        body.subs(substitutions)


Martin Bauer's avatar
Martin Bauer committed
122
def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists.

    Either all fields have a fixed shape (then all shapes must be equal), or all fields have a variable shape
    (then the shape of one of them is chosen deterministically, ordered by the string form of its entries).

    Raises:
        ValueError: if ``field_set`` is empty, if fixed- and variable-shaped fields are mixed, or if
                    fixed-shaped fields have different shapes
    """
    if not field_set:
        raise ValueError("Cannot determine a common shape: no fields were given")

    fixed_fields = [f for f in field_set if f.has_fixed_shape]
    variable_fields = [f for f in field_set if not f.has_fixed_shape]

    if fixed_fields and variable_fields:
        fixed_field_names = ",".join(f.name for f in fixed_fields)
        var_field_names = ",".join(f.name for f in variable_fields)
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if fixed_fields and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # Order primarily by the first entry (as before). Comparing the remaining entries too makes the choice
    # independent of set iteration order when first entries coincide.
    return min(shape_set, key=lambda shape: tuple(str(entry) for entry in shape))


146
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
             if None, the number of ghost layers is determined automatically and assumed to be equal for
             all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info. The loop-node is a Block holding the outermost loop, or ``body``
        itself if every coordinate is fixed by a non-slice entry of ``iteration_slice``.
    """
    # only relative accesses determine the iteration space
    relative_accesses = {acc for acc in body.atoms(AbstractField.AbstractAccess) if not acc.is_absolute_access}

    # buffers are treated separately and therefore do not take part in shape / loop-order decisions
    fields = {acc.field for acc in relative_accesses if not FieldType.is_buffer(acc.field)}

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        gl = max([acc.required_ghost_layers for acc in relative_accesses])
        ghost_layers = [(gl, gl)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    # build the loop nest from the inside out
    inner = body
    for coordinate in reversed(loop_order):
        if iteration_slice is not None:
            component = iteration_slice[coordinate]
            if type(component) is not slice:
                # a single fixed index: no loop, just assign the loop counter
                counter = ast.LoopOverCoordinate.get_loop_counter_symbol(coordinate)
                inner.insert_front(ast.SympyAssignment(counter, sp.sympify(component)))
                continue
            start, stop, step = component.start, component.stop, component.step
        else:
            bounds = ghost_layers[coordinate]
            start = bounds[0]
            stop = shape[coordinate] - bounds[1]
            step = 1
        inner = ast.Block([ast.LoopOverCoordinate(inner, coordinate, start, stop, step)])

    return inner, ghost_layers
202
203


Martin Bauer's avatar
Martin Bauer committed
204
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name_parts = []
    non_int_labels = []

    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the access offset is part of the address and names the pointer
            access_offset = field_access.offsets[coordinate_id]
            offset += stride * access_offset
            label = access_offset
        else:
            # index coordinate: the coordinate value itself names the pointer
            label = coordinate_value

        if type(label) is int:
            name_parts.append("_%d%d" % (coordinate_id, label))
        else:
            # symbolic labels cannot be spelled out, they are folded into a hash suffix below
            non_int_labels.append(label)

    suffix = "".join(name_parts)
    if non_int_labels:
        suffix += hashlib.md5(pickle.dumps(non_int_labels)).hexdigest()[:16]
    suffix = suffix.replace("-", 'm')

    return TypedSymbol(previous_ptr.name + suffix, previous_ptr.dtype), offset
254
255


Martin Bauer's avatar
Martin Bauer committed
256
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are defined dependent on the loop
    ordering. This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of tuples that can be passed to :func:`resolve_field_accesses`. Coordinates not mentioned in the
        specification are collected, in ascending order, in an additional last group.

    Raises:
        ValueError: for unknown specifications, coordinates out of range, or coordinates given twice

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    total_dimensions = spatial_dimensions + index_dimensions
    inner_to_outer = list(reversed(loop_order))
    specified_coordinates = set()
    result = []

    def add_coordinate(group, coordinate):
        """Validate ``coordinate`` and append it to ``group``."""
        if not 0 <= coordinate < total_dimensions:
            raise ValueError("Coordinate %d does not exist" % (coordinate,))
        if coordinate in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (coordinate,))
        specified_coordinates.add(coordinate)
        group.append(coordinate)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_coordinate(new_group, element)
            elif element.startswith("spatial"):
                spatial_spec = element[len("spatial"):]
                if spatial_spec.startswith("Inner"):
                    index = int(spatial_spec[len("Inner"):])
                    add_coordinate(new_group, inner_to_outer[index])
                elif spatial_spec.startswith("Outer"):
                    index = int(spatial_spec[len("Outer"):])
                    # inner_to_outer[-1] is the outermost loop, so "Outer0" has to map to index -1
                    add_coordinate(new_group, inner_to_outer[-1 - index])
                elif spatial_spec == "all":
                    for coordinate in range(spatial_dimensions):
                        add_coordinate(new_group, coordinate)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_coordinate(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    rest = set(range(total_dimensions)) - specified_coordinates
    if rest:
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Compute the linearized buffer index as a function of the loop counter symbols (for buffer fields).

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        # derive counters and extents from the loops in the AST, ordered innermost first;
        # the asserts require the loops to form one single chain of nested loops
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        loops = loops[::-1]
        enclosing_loops = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        assert len(loops) == len(enclosing_loops)
        assert all(found is parent for found, parent in zip(loops, enclosing_loops))

        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loops]
        loop_counters = [loop.loop_counter_symbol for loop in loops]

    all_accesses = ast_node.atoms(AbstractField.AbstractAccess)
    buffer_access_count = len({acc for acc in all_accesses if FieldType.is_buffer(acc.field)})
    # every cell occupies one slot per distinct buffer access
    scaled_counters = [counter * buffer_access_count for counter in loop_counters]

    base_buffer_index = scaled_counters[0]
    stride = 1
    for position, counter in enumerate(scaled_counters[1:]):
        extent = loop_iterations[position]
        # true division above yields floats for numeric bounds; strides must be integral
        stride *= int(extent) if isinstance(extent, float) else extent
        base_buffer_index += counter * stride
    return base_buffer_index


Martin Bauer's avatar
Martin Bauer committed
366
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitute accesses to buffer fields by :class:`ast.ResolvedFieldAccess` nodes with a linear index.

    Accesses to non-buffer fields are left untouched.

    Args:
        ast_node: the AST root; the lhs/rhs of its SympyAssignment nodes and the conditions of its
                  Conditional nodes are replaced in place
        base_buffer_index: linearized index of the current cell, e.g. from :func:`get_base_buffer_index`
        read_only_field_names: names of buffers whose pointer symbol is created with ``const=True``

    Returns:
        None - ``ast_node`` is modified in place

    Raises:
        RuntimeError: if a buffer access uses more than one index dimension
    """

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # keep the expression structure exactly as built, do not let sympy re-simplify
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            # resolve buffer accesses in the condition as well (mirrors resolve_field_accesses);
            # enclosing block / assignment are only passed through here, so no Block assertion is needed
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, sub_ast.parent, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            for child in sub_ast.args:
                visit_node(child)

    return visit_node(ast_node)
410

411

Martin Bauer's avatar
Martin Bauer committed
412
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Args:
        ast_node: the AST root; it is transformed in place
        read_only_field_names: set of field names which are considered read-only; their base pointer symbols
                               are created with ``const=True``
        field_to_base_pointer_info: mapping of field name to a list of coordinate groups indicating which
                                    intermediate base pointers should be created,
                                    for details see :func:`parse_base_pointer_info`
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        None - ``ast_node`` is modified in place: assignments to intermediate base pointers are inserted into
        the enclosing blocks, and the lhs/rhs of assignments and the conditions of Conditionals are replaced
    """
    # sort by field name so that the processing order does not depend on the caller's dict ordering
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets / indices themselves contain field accesses: resolve those first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                # default: a single group containing all coordinates, i.e. no intermediate pointers
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        # NOTE(review): scales the coordinate but not the access offset applied in
                        # create_intermediate_base_pointer - presumably item_size is 1 for non-struct types; confirm
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # all groups except the first become intermediate pointers, created from the last group inwards
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is added directly as offset to the last pointer
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # keep the expression structure exactly as built, do not let sympy re-simplify
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            for child in sub_ast.args:
                visit_node(child)

    return visit_node(ast_node)
525
526


Martin Bauer's avatar
Martin Bauer committed
527
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    Every block below ``ast_node`` is processed in turn. For each assignment the parents are walked upwards; the
    walk stops at the first :class:`ast.Conditional` or at the first node that defines a symbol the assignment
    reads. The assignment is placed into the last block reached on that walk, directly before the child leading
    back to its original position. In that target block three cases are distinguished:

    - an assignment with the same lhs and the same rhs exists: the local copy is dropped
    - an assignment with the same lhs but a different rhs exists: the hoisted assignment gets a fresh dummy lhs,
      and uses of the old lhs in the subsequent children of the current block are renamed to it
    - no assignment with the same lhs exists: if one with the same rhs exists, the local assignment is dropped
      and its lhs is replaced by the outer lhs in the subsequent children; otherwise it is inserted as is

    Args:
        ast_node: root of the AST (sub-)tree to transform; modified in place
    """
    def find_block_to_move_to(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                last_block = element
                last_block_child = prev_element

            if isinstance(element, ast.Conditional):
                # do not move past a conditional - the assignment stays in the block found so far
                break
            if node.undefined_symbols.intersection(element.symbols_defined):
                # this node defines something the assignment depends on -> cannot move further up
                break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        """Returns the first SympyAssignment in ``target_block`` whose rhs (``rhs_or_lhs=True``) or lhs
        (``rhs_or_lhs=False``) equals that of ``assignment``, or None if there is none."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        """Appends all :class:`ast.Block` nodes in the tree below (and including) ``node`` to ``result_list``."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        children = block.take_child_nodes()
        # Maps a local lhs symbol of this block to the symbol that replaces it in all following children.
        # An entry is added whenever a local assignment is dropped in favour of an assignment in a parent block
        # (same rhs, different lhs) or is hoisted under a fresh dummy name (same lhs, different rhs).
        substitute_variables = {}
        for child in children:
            # Before traversing the next child, all symbols are substituted first.
            child.subs(substitute_variables)

            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            # child is a SympyAssignment here (guaranteed by the check above)
            exists_already = check_if_assignment_already_in_block(child, target, False)
            if not exists_already:
                rhs_identical = check_if_assignment_already_in_block(child, target, True)
                if rhs_identical:
                    # there is already an assignment out there with the same rhs
                    # -> replace all lhs symbols in this block with the lhs of the outer assignment
                    # -> remove the local assignment (do not re-append child to the former block)
                    substitute_variables[child.lhs] = rhs_identical.lhs
                else:
                    target.insert_before(child, child_to_insert_before)
            elif exists_already.rhs == child.rhs:
                # identical assignment already present in the target block -> drop the local copy
                pass
            else:
                # this variable already exists in outer block, but with different rhs
                # -> symbol has to be renamed
                assert isinstance(child.lhs, TypedSymbol)
                new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
                substitute_variables[child.lhs] = new_symbol
618
619


Martin Bauer's avatar
Martin Bauer committed
620
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    Pure symbols (non field accesses) of a symbol group are stored in temporary arrays indexed by the inner loop
    counter, so that later loops can read the values computed by earlier ones. The temporary arrays are allocated
    at the front of the block containing the outermost loop and freed at its end.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the
                       symbol_groups and which no symbol in a symbol group depends on, are not updated!
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loop[0]
    assert type(inner_loop.body) is ast.Block
    outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
    assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loop[0]

    # group symbol -> element of its temporary array (pointer of the same name, indexed by the loop counter)
    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    # symbols already stored in a temporary array (earlier groups) are read from there
                    if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
                            new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        for symbol in symbol_group:
            if not isinstance(symbol, AbstractField.AbstractAccess):
                assert type(symbol) is TypedSymbol
                new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts,
                                                                      shape=(1,))[inner_loop.loop_counter_symbol]

        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs in symbols_resolved:
                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
                if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
                    assert type(assignment.lhs) is TypedSymbol
                    # group symbols are written to the temporary array element registered just above
                    new_lhs = symbols_with_temporary_array[assignment.lhs]
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)
688
689


Martin Bauer's avatar
Martin Bauer committed
690
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into len(cutting_points)+1 new loops that range from
    old_begin to cutting_points[0], ..., cutting_points[-1] to old_end.
    A sub-range of length one is not emitted as a loop: a copy of the loop body with the loop counter
    substituted by that single value is inserted instead. Empty sub-ranges are dropped.

    Modifies the ast in place: ``loop_node`` is replaced in its parent by the returned block.

    Args:
        loop_node: the loop to cut; must have a step of 1
        cutting_points: loop counter values at which the loop is split (presumably ascending and inside the
                        loop range - this is not checked here)

    Returns:
        :class:`ast.Block` containing the new loop nodes and unrolled bodies

    Raises:
        NotImplementedError: if the step of ``loop_node`` is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = ast.Block([])
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        if new_end - new_start == 1:
            # single iteration: unroll by substituting the counter value into a copy of the body
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        elif new_end - new_start == 0:
            # empty range: nothing to emit
            pass
        else:
            new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                              new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
    return new_loops
720
721


Martin Bauer's avatar
Martin Bauer committed
722
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    Each condition is first simplified with :func:`sympy.simplify`. A conditional whose condition becomes
    ``true`` is replaced by its true block; one whose condition becomes ``false`` is replaced by its false
    block (or removed entirely if it has none).

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        if conditional.condition_expr == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
        elif conditional.condition_expr == sp.false:
            conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
        elif loop_counter_simplification:
            # Only the import is guarded: an ImportError raised by the analysis itself must not be
            # misreported as a missing ISLpy installation.
            try:
                # noinspection PyUnresolvedReferences
                from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            except ImportError:
                warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
            else:
                simplify_loop_counter_dependent_conditional(conditional)


def cleanup_blocks(node: ast.Node) -> None:
    """Curly brace removal: recursively drops empty blocks and flattens single-child blocks into their parent.

    A block is only flattened when its parent is itself a block. Assignment nodes are leaves and are skipped.

    Args:
        node: root of the sub-tree to clean up; modified in place
    """
    if isinstance(node, ast.SympyAssignment):
        return  # leaf - nothing below it to clean up

    node_is_block = isinstance(node, ast.Block)
    # A child block of a block may replace itself in ``node`` while being cleaned up, so iterate over a
    # snapshot of the children in that case.
    child_nodes = list(node.args) if node_is_block else node.args
    for child_node in child_nodes:
        cleanup_blocks(child_node)

    if node_is_block and len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
761
762


763
764
765
766
767
768
769
770
771
772
773
774
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Tests the following conditions:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exactly the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)

    While checking, assignments are also converted to typed form: untyped symbols become :class:`TypedSymbol`
    instances and (optionally) numeric constants are wrapped in casts.
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        """
        Args:
            type_for_symbol: mapping from symbol name to type; the key ``'_constant'`` gives the type used for
                             numeric constants
            check_independence_condition: if True, every field read is checked against the writes seen so far
        """
        self._type_for_symbol = type_for_symbol

        self.scopes = NestedScopes()
        # (field, index) -> set of offsets at which that field component is written
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition

    def process_assignment(self, assignment):
        """Checks ``assignment`` and returns it as :class:`ast.SympyAssignment` with typed lhs and rhs."""
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs = self.process_expression(assignment.rhs)
        new_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(new_lhs, new_rhs)

    def process_expression(self, rhs, type_constants=True):
        """Registers all reads in ``rhs`` and returns a copy in which untyped symbols are typed.

        Args:
            rhs: sympy expression to process
            type_constants: if True, numbers are wrapped in a cast to the ``'_constant'`` type
        """
        self._update_accesses_rhs(rhs)
        if isinstance(rhs, AbstractField.AbstractAccess):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        elif isinstance(rhs, TypedSymbol):
            return rhs
        elif isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
        elif type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        elif isinstance(rhs, sp.Mul):
            # factors equal to -1 or 1 are passed through untouched (presumably to keep sign/unit factors cast-free)
            new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs
        elif isinstance(rhs, sp.Indexed):
            return rhs
        elif isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
        else:
            new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs

    @property
    def fields_written(self):
        """Set of fields for which at least one write has been recorded."""
        return {k.field for k, v in self._field_writes.items() if v}

    def _process_lhs(self, lhs):
        """Registers the write to ``lhs`` and returns it typed; field accesses and typed symbols are kept as is."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
            return lhs
        return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        """Records a write to ``lhs``.

        Raises:
            ValueError: if a field component is written at two different offsets, if a pure symbol is assigned
                        twice in the current scope, or if a pure symbol is written after it has been read
        """
        if isinstance(lhs, AbstractField.AbstractAccess):
            fai = self.FieldAndIndex(lhs.field, lhs.index)
            self._field_writes[fai].add(lhs.offsets)
            if len(self._field_writes[fai]) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
        elif isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
            if lhs in self.scopes.free_parameters:
                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        """Records a read of ``rhs``.

        Raises:
            ValueError: if the independence check is enabled and a field is read at an offset different from the
                        one it is written at
        """
        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
            writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
            # _update_accesses_lhs raises as soon as a second write offset is recorded
            assert len(writes) <= 1
            for write_offset in writes:
                if write_offset != rhs.offsets:
                    raise ValueError("Violation of loop independence condition. Field "
                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self.scopes.access_symbol(rhs)
853
854
855
856
857


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

    Additionally returns sets of all fields which are read/written

    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                         If a string (or any object without ``__getitem__``) is given instead, the mapping is
                         derived with :func:`typing_from_sympy_inspection`, using it as the default type
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols

    Raises:
        ValueError: if an object that cannot appear in a kernel is encountered, or if a kernel constraint
                    (SSA form / independence condition) is violated
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)

    def visit(obj):
        """Returns a typed copy of ``obj``, registering all reads and writes with ``check``."""
        if isinstance(obj, (list, tuple)):
            return [visit(e) for e in obj]
        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            return check.process_assignment(obj)
        elif isinstance(obj, ast.Conditional):
            check.scopes.push()
            # NOTE: processing order is false block, condition, true block. The order determines in which
            # sequence reads/writes are registered with ``check`` and therefore which SSA errors are raised.
            false_block = None if obj.false_block is None else visit(obj.false_block)
            result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
                                     true_block=visit(obj.true_block), false_block=false_block)
            check.scopes.pop()
            return result
        elif isinstance(obj, ast.Block):
            check.scopes.push()
            result = ast.Block([visit(e) for e in obj.args])
            check.scopes.pop()
            return result
        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj
        else:
            raise ValueError("Invalid object in kernel " + str(type(obj)))

    typed_equations = visit(eqs)

    return check.fields_read, check.fields_written, typed_equations
900
901


Martin Bauer's avatar
Martin Bauer committed
902
def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary.

    The tree is processed bottom-up: all arguments are handled first, then the node itself is rebuilt.

    - Arithmetic, logical and relational nodes get their arguments cast to the collated type of all arguments.
      An addition whose collated type is a pointer is turned into :func:`pointer_arithmetic_func`.
    - The rhs of a :class:`ast.SympyAssignment` is cast to the type of its lhs, unless the lhs is a pointer.
    - The branch expressions of a :class:`sympy.Piecewise` are cast to their collated type.
    - :class:`ast.Block` and :class:`ast.LoopOverCoordinate` nodes are updated in place.

    Args:
        node: the head node of the ast

    Returns:
        modified AST
    """
    def cast(zipped_args_types, target_dtype):
        """Adds casts to the arguments if their type differs from the target type.

        Args:
            zipped_args_types: a zipped list of args and types
            target_dtype: The target data type

        Returns:
            args with possible casts
        """
        casted_args = []
        for argument, data_type in zipped_args_types:
            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
                casted_args.append(cast_func(argument, target_dtype))
            else:
                casted_args.append(argument)
        return casted_args

    def pointer_arithmetic(expr_args):
        """Creates a valid pointer arithmetic function.

        Args:
            expr_args: (argument, type) pairs of the add expression; exactly one of them must be a pointer,
                       all others must be (unsigned) integers

        Returns:
            pointer_arithmetic_func
        """
        pointer = None
        new_args = []
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None
                pointer = arg
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                new_args.append(arg)
        new_args = sp.Add(*new_args) if new_args else new_args
        return pointer_arithmetic_func(pointer, new_args)

    if isinstance(node, (sp.AtomicExpr, cast_func)):
        return node
    args = [insert_casts(arg) for arg in node.args]
    # TODO indexed, LoopOverCoordinate
    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        target = collate_types(types)
        zipped = list(zip(args, types))
        if target.func is PointerType:
            assert node.func is sp.Add
            return pointer_arithmetic(zipped)
        else:
            return node.func(*cast(zipped, target))
    elif node.func is ast.SympyAssignment:
        lhs = args[0]
        rhs = args[1]
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return node.func(*args)  # TODO fix, not complete
        else:
            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
    elif node.func is ast.ResolvedFieldAccess:
        return node
    elif node.func is ast.Block or node.func is ast.LoopOverCoordinate:
        # container nodes: swap in the processed children in place
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is sp.Piecewise:
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        zipped = list(zip(expressions, types))
        casted_expressions = cast(zipped, target)
        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]

    return node.func(*args)


991
992
993
994
995
996
997
998
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element

    The last iteration of the inner loop and of its enclosing loops (as yielded by ``parents_of_type``) is cut
    off. Conditionals that then become trivially true/false are removed, and constants are hoisted out of loops.
    The kernel is modified in place.
    """
    innermost_loops = [candidate for candidate in function_node.atoms(ast.LoopOverCoordinate)
                       if candidate.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = innermost_loops[-1]

    enclosing_loops = parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True)
    for loop in enclosing_loops:
        # split off the final iteration, in which the staggered conditionals can be decided statically
        cut_loop(loop, [loop.stop - 1])

    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    cleanup_blocks(function_node.body)

    move_constants_before_loop(function_node.body)
    cleanup_blocks(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
1008
1009
1010
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
1011
def typing_from_sympy_inspection(eqs, default_type="double"):
Martin Bauer's avatar