transformations.py 47.2 KB
Newer Older
1
import warnings
2
from collections import defaultdict, OrderedDict, namedtuple
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
from types import MappingProxyType
5
6
import pickle
import hashlib
7
8
import sympy as sp
from sympy.logic.boolalg import Boolean
9
from sympy.tensor import IndexedBase
10
from pystencils.assignment import Assignment
11
from pystencils.assignment_collection.nestedscopes import NestedScopes
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.field import Field, FieldType
13
14
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
    cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
15
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
16
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
17
import pystencils.astnodes as ast
18
19


Martin Bauer's avatar
Martin Bauer committed
20
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yields all descendants of ``node`` that are instances of ``node_type``.

    The tree is traversed depth-first in pre-order through the ``args`` attribute of every visited node.
    A matching node is yielded and its own children are searched as well, so nested matches are found.

    Args:
        node: root of the (sub)tree; the root itself is never yielded
        node_type: type (or tuple of types) of the nodes that should be yielded
        stop_type: optional type (or tuple of types); children of such nodes are not visited,
                   unless the node itself matches ``node_type``
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # prune: nothing below a stop node is of interest
            continue

        # pass stop_type on, otherwise pruning would only apply to direct children of the root
        yield from filtered_tree_iteration(arg, node_type, stop_type)
28
29


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def unify_shape_symbols(body, common_shape, fields):
    """Makes all variable-sized fields use the same symbols for their array sizes.

    A kernel with variable array sizes requires that all arrays passed to it have the same size, which is
    checked when the kernel is called. Inside the kernel it is therefore sufficient to use only one symbol per
    coordinate instead of one for each field: e.g. shape_arr1[0] and shape_arr2[0] have to be equal, so both are
    represented by the same symbol.

    Args:
        body: ast node of the kernel part where the substitution is done; ``body.subs`` is called on it
        common_shape: shape of the field that was chosen as representative
        fields: all fields whose shape symbols should be replaced by the entries of common_shape
    """
    replacements = {}
    for fld in fields:
        assert len(fld.spatial_shape) == len(common_shape)
        if fld.has_fixed_shape:
            # fixed shapes are concrete numbers, there is no symbol to unify
            continue
        replacements.update({own: common
                             for common, own in zip(common_shape, fld.spatial_shape)
                             if own != common})

    if replacements:
        body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
54
def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists.

    Args:
        field_set: sized collection of fields; of each field only ``name``, ``has_fixed_shape`` and
                   ``spatial_shape`` are used

    Returns:
        If all fields have a fixed shape: that shape, which has to be identical for all of them.
        If all fields have a variable shape: the shape whose first component comes first when the
        string representations of the first components are sorted.

    Raises:
        ValueError: if field_set is empty, if fixed-shaped and variable-shaped fields are mixed,
                    or if fixed-shaped fields have different shapes
    """
    if len(field_set) == 0:
        raise ValueError("Cannot determine a common shape for an empty set of fields")

    nr_of_fixed_shaped_fields = sum(1 for f in field_set if f.has_fixed_shape)
    all_fields_fixed = nr_of_fixed_shaped_fields == len(field_set)

    if nr_of_fixed_shaped_fields > 0 and not all_fields_fixed:
        fixed_field_names = ",".join(f.name for f in field_set if f.has_fixed_shape)
        var_field_names = ",".join(f.name for f in field_set if not f.has_fixed_shape)
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if all_fields_fixed and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # variable shapes consist of symbols that cannot be ordered directly, so their string form is compared
    return min(shape_set, key=lambda e: str(e[0]))


Martin Bauer's avatar
Martin Bauer committed
78
79
80
81
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        function_name: name of generated C function
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
             or a single int that is used as lower and upper bound for all coordinates.
             If None, the number of ghost layers is determined automatically from the field accesses
             and assumed to be equal for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        :class:`KernelFunction` node (backend 'cpu') that contains the nested loops
    """
    # find correct ordering by inspecting participating FieldAccesses
    field_accesses = body.atoms(Field.Access)
    field_accesses = {e for e in field_accesses if not e.is_absolute_access}

    # exclude accesses to buffers from the field set, because buffers are treated separately
    fields = {e.field for e in field_accesses if not FieldType.is_buffer(e.field)}

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        required_ghost_layers = max(fa.required_ghost_layers for fa in field_accesses)
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    current_body = body
    # loops are created from the inside out, so the loop order is traversed backwards
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            begin = ghost_layers[loop_coordinate][0]
            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([new_loop])
        else:
            slice_component = iteration_slice[loop_coordinate]
            if type(slice_component) is slice:
                sc = slice_component
                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                current_body = ast.Block([new_loop])
            else:
                # coordinate is fixed to a single value: no loop, the loop counter is assigned directly
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                                                 sp.sympify(slice_component))
                current_body.insert_front(assignment)

    ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
    return ast_node
136
137


Martin Bauer's avatar
Martin Bauer committed
138
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []
    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the relative offset of the access is added, and it is what gets encoded in the name
            name_component = field_access.offsets[coordinate_id]
            offset += stride * name_component
        else:
            # index coordinate: the coordinate value itself is encoded in the name
            name_component = coordinate_value

        # exact type check on purpose: everything that is not a plain int goes into the hash
        if type(name_component) is int:
            name += "_%d%d" % (coordinate_id, name_component)
        else:
            list_to_hash.append(name_component)

    if len(list_to_hash) > 0:
        # non-integer components cannot be written into an identifier, a digest of them is appended instead
        name += hashlib.md5(pickle.dumps(list_to_hash)).hexdigest()[:16]

    # '-' from negative offsets is not allowed in identifiers
    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset
188
189


Martin Bauer's avatar
Martin Bauer committed
190
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of coordinate lists that can be passed to :func:`resolve_field_accesses`; coordinates that are not
        mentioned in the specification are appended as a last group, in ascending order

    Raises:
        ValueError: if a specification cannot be parsed, a coordinate does not exist or is specified twice

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    result = []
    specified_coordinates = set()
    # reversed copy: position 0 is the innermost loop, position -1 the outermost loop
    loop_order = list(reversed(loop_order))

    def add_new_element(group, elem):
        if elem >= spatial_dimensions + index_dimensions:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        group.append(elem)
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, loop_order[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # count from the end; a plain loop_order[-index] would map Outer0 to the innermost loop
                    add_new_element(new_group, loop_order[-1 - index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    all_coordinates = set(range(spatial_dimensions + index_dimensions))
    rest = all_coordinates - specified_coordinates
    if rest:
        # sorted, so the result does not depend on set iteration order
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Computes the linearized buffer index for buffer fields, expressed in terms of the loop counter symbols.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        # collect the loop nodes and flip the list, so that it starts with the innermost loop
        loop_nodes = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        loop_nodes.reverse()

        # sanity check: the loops found have to be exactly the loops enclosing the innermost one
        enclosing_loops = list(parents_of_type(loop_nodes[0], ast.LoopOverCoordinate, include_current=True))
        assert len(loop_nodes) == len(enclosing_loops)
        assert all(found is parent for found, parent in zip(loop_nodes, enclosing_loops))

        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loop_nodes]
        loop_counters = [loop.loop_counter_symbol for loop in loop_nodes]

    # every loop counter is scaled by the number of distinct buffer accesses
    nr_of_buffer_accesses = len({acc for acc in ast_node.atoms(Field.Access) if FieldType.is_buffer(acc.field)})
    scaled_counters = [counter * nr_of_buffer_accesses for counter in loop_counters]

    base_buffer_index = scaled_counters[0]
    accumulated_stride = 1
    for position, counter in enumerate(scaled_counters[1:]):
        extent = loop_iterations[position]
        if isinstance(extent, float):
            extent = int(extent)
        accumulated_stride *= extent
        base_buffer_index += counter * accumulated_stride
    return base_buffer_index


Martin Bauer's avatar
Martin Bauer committed
300
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitutes :class:`pystencils.field.Field.Access` nodes of buffer fields by resolved accesses.

    Accesses to fields that are not buffers are left untouched. The AST is modified in place: left and right
    hand sides of all :class:`SympyAssignment` nodes below ``ast_node`` are rewritten.

    Args:
        ast_node: the AST root
        base_buffer_index: index expression that is used for every buffer access; if the access has an index
                           coordinate, that coordinate is added to it
        read_only_field_names: collection of field names; the pointers of these fields are created as const

    Raises:
        RuntimeError: if a buffer access has more than one index coordinate
    """
    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            # already resolved accesses are not descended into
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild these node types without evaluation, so the expression is not restructured
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
344

345

Martin Bauer's avatar
Martin Bauer committed
346
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in place: left and right hand sides of all :class:`SympyAssignment` nodes below
    ``ast_node`` are rewritten, and assignments for intermediate base pointers are inserted before the
    assignment that needs them.

    Args:
        ast_node: the AST root
        read_only_field_names: collection of field names which are considered read-only
        field_to_base_pointer_info: map of field name to a list of coordinate groups indicating which intermediate
                                    base pointers should be created, for details see
                                    :func:`parse_base_pointer_info`. Fields without an entry get a single group
                                    containing all coordinates, i.e. no intermediate pointers.
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        None
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets and indices may contain field accesses themselves, these are resolved first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        # spatial coordinate: absolute accesses start at zero, all others at the fixed
                        # coordinate of the field if one is given, or at the loop counter otherwise
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            # struct fields are indexed by the name of the struct member
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # all groups but the first become intermediate pointers, each one based on the previous pointer
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is written directly into the access as offset to the last pointer
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            # already resolved accesses are not descended into
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild these node types without evaluation, so the expression is not restructured
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
452
453


Martin Bauer's avatar
Martin Bauer committed
454
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    The AST is modified in place, nothing is returned.

    Args:
        ast_node: root of the AST; all blocks below it are processed
    """
    def find_block_to_move_to(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                last_block = element
                last_block_child = prev_element

            # the walk stops at the first parent that the node depends on: a conditional whose
            # condition uses one of the node's symbols, or a parent that defines one of them
            if isinstance(element, ast.Conditional):
                critical_symbols = element.condition_expr.atoms(sp.Symbol)
            else:
                critical_symbols = element.symbols_defined
            if node.undefined_symbols.intersection(critical_symbols):
                break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block):
        """Returns the first SympyAssignment of target_block with the same left hand side, or None."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if arg.lhs == assignment.lhs:
                return arg
        return None

    def get_blocks(node, result_list):
        """Appends all Block nodes at or below node to result_list, parents before their children."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        # the block is emptied first, every child is then put back or inserted into a parent block
        children = block.take_child_nodes()
        for child in children:
            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
            else:
                if isinstance(child, ast.SympyAssignment):
                    exists_already = check_if_assignment_already_in_block(child, target)
                else:
                    exists_already = False

                if not exists_already:
                    target.insert_before(child, child_to_insert_before)
                elif exists_already and exists_already.rhs == child.rhs:
                    # identical assignment is already present in the target block, the child is dropped
                    pass
                else:
                    block.append(child)  # don't move in this case - better would be to rename symbol
522
523


Martin Bauer's avatar
Martin Bauer committed
524
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """Splits the innermost loop into one loop per symbol group.

    The goal is to minimize the amount of simultaneous load/store streams. Group symbols that are not field
    accesses are written to temporary arrays indexed by the inner loop counter; an allocation node for each of
    these arrays is inserted at the front of the outermost loop's parent and the matching free node is appended
    to it.

    Args:
        ast_node: AST root, has to contain exactly one innermost and exactly one outermost loop
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and the symbols they depend on. Symbols which are in none of the
                       symbol groups and which no symbol in a symbol group depends on, are not updated!
    """
    loops = ast_node.atoms(ast.LoopOverCoordinate)
    innermost = [loop for loop in loops if loop.is_innermost_loop]
    assert len(innermost) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost[0]
    assert type(inner_loop.body) is ast.Block
    outermost = [loop for loop in loops if loop.is_outermost_loop]
    assert len(outermost) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost[0]

    def temporary_array_access(symbol):
        # pointer typed symbol of the same name, indexed with the counter of the inner loop
        assert type(symbol) is TypedSymbol
        pointer_symbol = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return IndexedBase(pointer_symbol, shape=(1,))[inner_loop.loop_counter_symbol]

    temporary_arrays = OrderedDict()
    assignment_of = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    loop_bodies = []
    for symbol_group in symbol_groups:
        # collect the group's symbols together with everything they (transitively) depend on
        pending = list(symbol_group)
        resolved = set()
        while pending:
            current = pending.pop()
            if current in resolved:
                continue

            # a symbol without assignment inside the loop body is independent already
            if current in assignment_of:
                pending.extend(dependency for dependency in assignment_of[current].rhs.atoms(sp.Symbol)
                               if type(dependency) is not Field.Access and dependency not in temporary_arrays)
            resolved.add(current)

        for symbol in symbol_group:
            if type(symbol) is not Field.Access:
                temporary_arrays[symbol] = temporary_array_access(symbol)

        group_assignments = []
        for assignment in inner_loop.body.args:
            lhs = assignment.lhs
            if lhs not in resolved:
                continue
            rhs = assignment.rhs.subs(temporary_arrays.items())
            if type(lhs) is not Field.Access and lhs in symbol_group:
                lhs = temporary_array_access(lhs)
            group_assignments.append(ast.SympyAssignment(lhs, rhs))
        loop_bodies.append(group_assignments)

    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(body)) for body in loop_bodies]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    for symbol in temporary_arrays:
        pointer_symbol = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        allocation = ast.TemporaryMemoryAllocation(pointer_symbol, inner_loop.stop, inner_loop.start)
        release = ast.TemporaryMemoryFree(allocation)
        outer_loop.parent.insert_front(allocation)
        outer_loop.parent.append(release)
590
591


Martin Bauer's avatar
Martin Bauer committed
592
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into up to len(cutting_points)+1 pieces that range from
    old_begin to cutting_points[0], ..., cutting_points[-1] to old_end.
    A piece of length one is not emitted as a loop: a copy of the loop body is used instead, on which ``subs`` is
    called to replace the loop counter by the start value of the piece. Pieces of length zero are dropped.

    Modifies the ast in place

    Args:
        loop_node: the loop to cut, its step has to be 1
        cutting_points: sequence of loop counter values at which the loop is cut

    Returns:
        list of the nodes that replace the loop (new loop nodes, or copied bodies for pieces of length one)

    Raises:
        NotImplementedError: if the step of the loop is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    pieces = []
    piece_start = loop_node.start
    for piece_end in [*cutting_points, loop_node.stop]:
        extent = piece_end - piece_start
        if extent == 1:
            single_iteration = deepcopy(loop_node.body)
            single_iteration.subs({loop_node.loop_counter_symbol: piece_start})
            pieces.append(single_iteration)
        elif extent == 0:
            pass  # empty piece, nothing to emit
        else:
            pieces.append(ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                                 piece_start, piece_end, loop_node.step))
        piece_start = piece_end

    loop_node.parent.replace(loop_node, pieces)
    return pieces
622
623


Martin Bauer's avatar
Martin Bauer committed
624
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals whose condition simplifies to a constant truth value.

    The condition of every conditional below ``node`` is replaced by its ``sp.simplify``-ed form. A conditional
    that is always true is replaced by its true block, one that is always false by its false block (or by nothing
    if it has none).

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default. If the import fails only a warning is emitted.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)

        if conditional.condition_expr == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
            continue

        if conditional.condition_expr == sp.false:
            replacement = [conditional.false_block] if conditional.false_block else []
            conditional.parent.replace(conditional, replacement)
            continue

        if not loop_counter_simplification:
            continue

        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(conditional)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: drops empty blocks and replaces blocks with a single child by that child.

    The tree is processed recursively; assignments are leaves and are not descended into. A block is only
    dissolved if its parent is a block as well.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    is_block = isinstance(node, ast.Block)
    # children of a block may remove themselves from it while we iterate, so walk over a copy in that case
    children = list(node.args) if is_block else node.args
    for child in children:
        cleanup_blocks(child)

    if is_block and len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
663
664


665
666
667
668
669
670
671
672
673
674
675
676
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid and replaces untyped symbols by typed ones.

    The following conditions are tested:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exactly the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        # mapping from symbol name to type, the entry '_constant' is used for numeric constants
        self._type_for_symbol = type_for_symbol
        self.check_independence_condition = check_independence_condition

        self.scopes = NestedScopes()
        # FieldAndIndex -> set of offsets at which this field component has been written
        self._field_writes = defaultdict(set)
        self.fields_read = set()

    def process_assignment(self, assignment):
        """Checks a single assignment and returns it as SympyAssignment with typed symbols."""
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        typed_rhs = self.process_expression(assignment.rhs)
        typed_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(typed_lhs, typed_rhs)

    def process_expression(self, rhs, type_constants=True):
        """Registers the read accesses of an expression and returns the expression with typed symbols.

        Args:
            rhs: expression to process
            type_constants: if True, numbers are wrapped in a cast to the type stored under '_constant'
        """
        self._update_accesses_rhs(rhs)

        def recurse(expr):
            return self.process_expression(expr, type_constants)

        if isinstance(rhs, Field.Access):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        if isinstance(rhs, TypedSymbol):
            return rhs
        if isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
        if type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        if isinstance(rhs, sp.Mul):
            # factors of -1 and 1 are passed through unprocessed, i.e. they are never wrapped in a cast
            processed = [arg if arg in (-1, 1) else recurse(arg) for arg in rhs.args]
            return rhs.func(*processed) if processed else rhs
        if isinstance(rhs, sp.Indexed):
            return rhs
        if isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            return sp.Pow(recurse(rhs.args[0]), rhs.args[1])

        processed = [recurse(arg) for arg in rhs.args]
        return rhs.func(*processed) if processed else rhs

    @property
    def fields_written(self):
        """Set of all fields for which at least one write has been registered."""
        return {key.field for key, offsets in self._field_writes.items() if offsets}

    def _process_lhs(self, lhs):
        """Registers the write access and returns the left hand side as typed symbol."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, (Field.Access, TypedSymbol)):
            return lhs
        return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        """Records a write and raises ValueError if it violates the single-write / SSA conditions."""
        if isinstance(lhs, Field.Access):
            written_offsets = self._field_writes[self.FieldAndIndex(lhs.field, lhs.index)]
            written_offsets.add(lhs.offsets)
            if len(written_offsets) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
        elif isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
            if lhs in self.scopes.free_parameters:
                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        """Records a read and raises ValueError if it violates the loop independence condition."""
        if isinstance(rhs, Field.Access) and self.check_independence_condition:
            written_offsets = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
            for written_offset in written_offsets:
                assert len(written_offsets) == 1
                if written_offset != rhs.offsets:
                    raise ValueError("Violation of loop independence condition. Field "
                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets,
                                                                                 written_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self.scopes.access_symbol(rhs)
755
756
757
758
759


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.data_types.TypedSymbol`.

    Additionally returns sets of all fields which are read/written

    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                         If a string or an object without ``__getitem__`` is passed, the mapping is created by
                         :func:`typing_from_sympy_inspection` with this value as default type.
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols

    Raises:
        ValueError: if ``eqs`` contains an object that is not allowed in a kernel (this includes loop nodes)
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)

    def visit(obj):
        if isinstance(obj, (list, tuple)):
            return [visit(element) for element in obj]

        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            return check.process_assignment(obj)

        if isinstance(obj, ast.Conditional):
            check.scopes.push()
            # processing order: false block, condition, true block
            typed_false_block = None if obj.false_block is None else visit(obj.false_block)
            typed_condition = check.process_expression(obj.condition_expr, type_constants=False)
            typed_true_block = visit(obj.true_block)
            typed_conditional = ast.Conditional(typed_condition,
                                                true_block=typed_true_block, false_block=typed_false_block)
            check.scopes.pop()
            return typed_conditional

        if isinstance(obj, ast.Block):
            check.scopes.push()
            typed_block = ast.Block([visit(child) for child in obj.args])
            check.scopes.pop()
            return typed_block

        # all remaining AST nodes, except loops, are passed through unchanged
        if isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj

        raise ValueError("Invalid object in kernel " + str(type(obj)))

    typed_equations = visit(eqs)

    return check.fields_read, check.fields_written, typed_equations
802
803


Martin Bauer's avatar
Martin Bauer committed
804
def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary.

    The function recurses into the arguments of ``node`` first and then treats the node depending on
    ``node.func``. Blocks and loops are updated in place, other nodes are rebuilt from their processed arguments.

    Args:
        node: the head node of the ast

    Returns:
        modified AST
    """
    def cast(zipped_args_types, target_dtype):
        """
        Adds casts to the arguments if their type differs from the target type
        :param zipped_args_types: a zipped list of args and types
        :param target_dtype: The target data type
        :return: args with possible casts
        """
        # only the numpy dtypes are compared, i.e. const qualifiers are ignored
        return [cast_func(argument, target_dtype) if data_type.numpy_dtype != target_dtype.numpy_dtype else argument
                for argument, data_type in zipped_args_types]

    def pointer_arithmetic(expr_args):
        """
        Creates a valid pointer arithmetic function
        :param expr_args: Arguments of the add expression
        :return: pointer_arithmetic_func
        """
        pointer = None
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None
                pointer = arg

        offsets = []
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                offsets.append(arg)
        offset = sp.Add(*offsets) if offsets else offsets
        return pointer_arithmetic_func(pointer, offset)

    if isinstance(node, (sp.AtomicExpr, cast_func)):
        return node

    args = [insert_casts(arg) for arg in node.args]
    func = node.func

    # TODO indexed, LoopOverCoordinate
    if func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        target = collate_types(types)
        args_with_types = list(zip(args, types))
        if target.func is PointerType:
            assert func is sp.Add
            return pointer_arithmetic(args_with_types)
        return func(*cast(args_with_types, target))

    if func is ast.SympyAssignment:
        lhs, rhs = args[0], args[1]
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return func(*args)  # TODO fix, not complete
        return func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))

    if func is ast.ResolvedFieldAccess:
        return node

    if func is ast.Block or func is ast.LoopOverCoordinate:
        # these nodes are modified in place: every child is exchanged for its processed version
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node

    if func is sp.Piecewise:
        # only the expressions of the pieces are cast to a common type, the conditions are kept
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        casted_expressions = cast(list(zip(expressions, types)), target)
        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]

    return func(*args)


893
894
895
896
897
898
899
900
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element

    The innermost loop and all loops enclosing it are cut at ``stop - 1``. Afterwards conditionals are simplified
    (including the loop counter dependent analysis), constants are moved before the loops and superfluous blocks
    are removed.

    Args:
        function_node: kernel to transform in place, it has to contain exactly one innermost loop
    """
    innermost_loops = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = innermost_loops[-1]

    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
        cut_loop(loop, [loop.stop - 1])

    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    for transformation in (cleanup_blocks, move_constants_before_loop, cleanup_blocks):
        transformation(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
910
911
912
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
913
def typing_from_sympy_inspection(eqs, default_type="double"):
    """
    Creates a default symbol name to type mapping.
    A symbol that gets a sympy Boolean expression assigned is typed as 'bool', every other symbol name is mapped
    to the default type. The true and false blocks of conditionals are inspected as well, all other AST nodes
    that are not assignments are skipped.

    Args:
        eqs: list of equations
        default_type: the type for non-boolean symbols
    Returns:
        dictionary (defaultdict), mapping symbol name to type
    """
    type_of = defaultdict(lambda: default_type)
    for entry in eqs:
        if isinstance(entry, ast.Conditional):
            type_of.update(typing_from_sympy_inspection(entry.true_block.args))
            if entry.false_block:
                type_of.update(typing_from_sympy_inspection(entry.false_block.args))
            continue

        if isinstance(entry, ast.Node) and not isinstance(entry, ast.SympyAssignment):
            continue

        # problematic case here is when rhs is a symbol: then it is impossible to decide here without
        # further information what type the left hand side is - default fallback is the dict value then
        rhs_is_boolean_expression = isinstance(entry.rhs, Boolean) and not isinstance(entry.rhs, sp.Symbol)
        if rhs_is_boolean_expression:
            type_of[entry.lhs.name] = "bool"
    return type_of


Martin Bauer's avatar
Martin Bauer committed
940
def get_next_parent_of_type(node, parent_type):
    """Returns the closest ancestor of the given type, or None if the root is reached first.

    The ``parent`` links are followed upwards, starting at the parent of ``node``; the node itself is never
    returned.
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parent_type):
        ancestor = ancestor.parent
    return ancestor


954
def parents_of_type(node, parent_type, include_current=False):
    """Generator for all parent nodes of given type, from the closest one up to the root.

    If ``include_current`` is set, ``node`` itself is a candidate as well.
    """
    if not include_current:
        node = node.parent
    while node is not None:
        if isinstance(node, parent_type):
            yield node
        node = node.parent


Martin Bauer's avatar
Martin Bauer committed
963
def get_optimal_loop_ordering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    The loop order is the common layout of the fields.

    Args:
        fields: non-empty collection of fields

    Returns:
        list of coordinate ids, where the first list entry should be the outermost loop

    Raises:
        ValueError: if the fields differ in their number of spatial dimensions or in their memory layout
    """
    assert len(fields) > 0
    ref_field = next(iter(fields))
    if any(field.spatial_dimensions != ref_field.spatial_dimensions for field in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                         + str({f.name: f.spatial_shape for f in fields}))

    layouts = {field.layout for field in fields}
    if len(layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists "
                         + str({f.name: f.layout for f in fields}))
    common_layout = next(iter(layouts))
    return list(common_layout)
987
988


Martin Bauer's avatar
Martin Bauer committed