transformations.py 45.6 KB
Newer Older
1
import warnings
2
from collections import defaultdict, OrderedDict, namedtuple
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
from types import MappingProxyType
5
6
import sympy as sp
from sympy.logic.boolalg import Boolean
7
from sympy.tensor import IndexedBase
8
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.field import Field, FieldType
Martin Bauer's avatar
Martin Bauer committed
10
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
11
    pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
13
import pystencils.astnodes as ast
14
15


Martin Bauer's avatar
Martin Bauer committed
16
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yield all descendants of ``node`` that are instances of ``node_type`` (depth-first, pre-order).

    Only the descendants reachable through ``args`` are visited; ``node`` itself is never yielded.

    Args:
        node: root of the tree to search; must provide an ``args`` sequence
        node_type: type (or tuple of types) of the nodes to yield
        stop_type: optional type (or tuple of types). Descendants of this type that are not themselves
                   ``node_type`` are neither yielded nor descended into.
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # prune: do not look inside subtrees rooted at a stop node
            continue

        # stop_type must be forwarded, otherwise pruning would only apply to the first level
        yield from filtered_tree_iteration(arg, node_type, stop_type)
24
25


26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def unify_shape_symbols(body, common_shape, fields):
    """Make all variable-sized fields refer to the same symbols for their array sizes.

    A kernel with variable array sizes requires all passed arrays to have the same size; this is
    checked when the kernel is called. Inside the kernel it is therefore enough to use one symbol
    per dimension instead of one per field, e.g. ``shape_arr1[0]`` and ``shape_arr2[0]`` are equal
    and should be represented by the same symbol.

    Args:
        body: ast node for the kernel part where the substitutions are made; its ``subs`` method is
              called with the substitution dict (modified in-place)
        common_shape: shape of the field that was chosen as representative
        fields: all fields whose shape symbols should be replaced by ``common_shape``
    """
    replacements = {}
    for fld in fields:
        assert len(fld.spatial_shape) == len(common_shape)
        if fld.has_fixed_shape:
            # fixed shapes are plain numbers, nothing to unify
            continue
        for own_size, shared_size in zip(fld.spatial_shape, common_shape):
            if own_size != shared_size:
                replacements[own_size] = shared_size

    if replacements:
        body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
50
def get_common_shape(field_set):
    """Return the spatial shape shared by all pystencils Fields in ``field_set``.

    Either all fields have a fixed shape, in which case those shapes must be identical, or all fields
    have a variable shape. In the variable case, one representative shape is picked deterministically
    (the one whose first entry has the smallest string representation); the shape symbols are expected
    to be unified afterwards, e.g. with ``unify_shape_symbols``.

    Args:
        field_set: sized collection (typically a set) of pystencils Fields

    Returns:
        the common spatial shape

    Raises:
        ValueError: if ``field_set`` is empty, if fixed-shaped and variable-shaped fields are mixed,
                    or if fixed-shaped fields differ in shape
    """
    if not field_set:
        # previously this surfaced as the misleading "Differently sized field accesses ... set()"
        raise ValueError("Cannot determine a common shape: no fields given")

    fixed_fields = [f for f in field_set if f.has_fixed_shape]
    variable_fields = [f for f in field_set if not f.has_fixed_shape]

    if fixed_fields and variable_fields:
        fixed_field_names = ",".join([f.name for f in fixed_fields])
        var_field_names = ",".join([f.name for f in variable_fields])
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if not variable_fields and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # deterministic choice, independent of set iteration order
    return min(shape_set, key=lambda e: str(e[0]))


Martin Bauer's avatar
Martin Bauer committed
74
75
76
77
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        function_name: name of generated C function
        iteration_slice: if not None, iteration is done only over this slice of the field. A non-slice
                         entry fixes the corresponding coordinate: no loop is created for it, the loop
                         counter symbol is assigned that value at the front of the body instead.
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
                      or a single int used for all coordinates. If None, the maximum
                      ``required_ghost_layers`` of all field accesses is used for all dimensions.
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout);
                    if None it is determined by ``get_optimal_loop_ordering``

    Returns:
        tuple ``(kernel_function, loop_strides, loop_vars)``:
        the :class:`ast.KernelFunction` wrapping the nested loops,
        ``(stop - start) / step`` for every created loop (innermost first), and
        the loop counter symbols of the created loops (innermost first), each multiplied by the
        number of buffer accesses in ``body``
    """
    # inspect the participating accesses; buffers are treated separately and do not shape the loops
    field_accesses = body.atoms(Field.Access)
    non_buffer_fields = [access.field for access in field_accesses if not FieldType.is_buffer(access.field)]
    fields = set(non_buffer_fields)
    num_buffer_accesses = len(field_accesses) - len(non_buffer_fields)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        gl = max([access.required_ghost_layers for access in field_accesses])
        ghost_layers = [(gl, gl)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    loop_strides = []
    loop_vars = []
    current_body = body
    # build loops from the innermost outwards, each new loop wrapping the previous body
    for coordinate in reversed(loop_order):
        if iteration_slice is None:
            start = ghost_layers[coordinate][0]
            stop = shape[coordinate] - ghost_layers[coordinate][1]
            step = 1
        else:
            component = iteration_slice[coordinate]
            if type(component) is not slice:
                # fixed coordinate: assign the counter instead of looping over it
                counter = ast.LoopOverCoordinate.get_loop_counter_symbol(coordinate)
                current_body.insert_front(ast.SympyAssignment(counter, sp.sympify(component)))
                continue
            start, stop, step = component.start, component.stop, component.step

        loop = ast.LoopOverCoordinate(current_body, coordinate, start, stop, step)
        current_body = ast.Block([loop])
        loop_strides.append((stop - start) / step)
        loop_vars.append(loop.loop_counter_symbol)

    # NOTE(review): counters are scaled by the number of buffer accesses, presumably because each
    # iteration consumes that many buffer entries — confirm against the buffer-resolution caller
    loop_vars = [num_buffer_accesses * var for var in loop_vars]
    kernel = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
    return kernel, loop_strides, loop_vars
141
142


Martin Bauer's avatar
Martin Bauer committed
143
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Integer offsets are encoded literally in the name (``_<coordinate><offset>``, with ``-`` written
    as ``m``). Non-integer offsets are folded into a digest suffix that is deterministic across
    interpreter runs, so that generated code is reproducible.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, x*fstride_myfield[0] + fstride_myfield[0])
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    import hashlib

    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []

    for coordinate_id, coordinate_value in coordinates.items():
        offset += field.strides[coordinate_id] * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the neighbor offset of the access is resolved as well
            access_offset = field_access.offsets[coordinate_id]
            offset += field.strides[coordinate_id] * access_offset
            if type(access_offset) is int:
                name += "_%d%d" % (coordinate_id, access_offset)
            else:
                list_to_hash.append(access_offset)
        else:
            if type(coordinate_value) is int:
                name += "_%d%d" % (coordinate_id, coordinate_value)
            else:
                list_to_hash.append(coordinate_value)

    if list_to_hash:
        # builtin hash() of sympy objects depends on the per-process string-hash seed, which made
        # the generated pointer names (and thus the generated code) differ between runs.
        digest = hashlib.sha256(str(tuple(list_to_hash)).encode("utf-8")).hexdigest()
        name += "_" + digest[:16].upper()

    # '-' is not allowed in C identifiers
    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset
193
194


Martin Bauer's avatar
Martin Bauer committed
195
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Coordinates not mentioned in any group are collected (sorted) in a final group.

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of tuples that can be passed to :func:`resolve_field_accesses`

    Raises:
        ValueError: for unknown specifications, coordinates that do not exist, or coordinates given twice

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    total_dimensions = spatial_dimensions + index_dimensions
    inner_to_outer = list(reversed(loop_order))

    def resolve(element):
        """Translate one specification entry into the list of coordinate ids it denotes."""
        if type(element) is int:
            return [element]
        if element.startswith("spatial"):
            suffix = element[len("spatial"):]
            if suffix.startswith("Inner"):
                return [inner_to_outer[int(suffix[len("Inner"):])]]
            if suffix.startswith("Outer"):
                # inner_to_outer[-1] is the outermost loop, so OuterK counts back from the end.
                # (Previously [-K] was used, which made spatialOuter0 the *innermost* loop.)
                return [inner_to_outer[-1 - int(suffix[len("Outer"):])]]
            if suffix == "all":
                return list(range(spatial_dimensions))
            raise ValueError("Could not parse " + suffix)
        if element.startswith("index"):
            return [spatial_dimensions + int(element[len("index"):])]
        raise ValueError("Unknown specification %s" % (element,))

    result = []
    specified_coordinates = set()
    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            for coordinate in resolve(element):
                if coordinate < 0 or coordinate >= total_dimensions:
                    raise ValueError("Coordinate %d does not exist" % (coordinate,))
                if coordinate in specified_coordinates:
                    raise ValueError("Coordinate %d specified two times" % (coordinate,))
                new_group.append(coordinate)
                specified_coordinates.add(coordinate)
        result.append(new_group)

    rest = sorted(set(range(total_dimensions)) - specified_coordinates)
    if rest:
        result.append(rest)

    return result


Martin Bauer's avatar
Martin Bauer committed
270
271
def substitute_array_accesses_with_constants(ast_node):
    """Replace array accesses (``sp.Indexed``) that are not field accesses by constants.

    Benchmarks showed that using an array access as loop bound or in pointer computations causes
    some compilers to do fewer optimizations. For every such access a ``SympyAssignment`` defining a
    new typed constant is inserted before the node being processed (unless the constant is already
    defined in that block), and the access is replaced by the constant.

    This transformation should run after field accesses have been resolved (since they introduce
    array accesses) and before constants are moved before the loops. The AST is modified in place.
    """

    def handle_sympy_expression(expr, parent_block):
        """Return ``expr`` with non-field array accesses replaced by constants.

        The assignments defining those constants are inserted into ``parent_block`` before ``ast_node``.
        """
        if not isinstance(expr, sp.Expr):
            return expr

        # all indexed expressions except already resolved field accesses
        accesses = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)]

        # expression consisting of a single array access: nothing to gain, keep as-is
        if len(accesses) == 1 and expr == accesses[0]:
            return expr

        definitions = []
        replacements = {}
        for access in accesses:
            base, idx = access.args
            array_symbol = base.args[0]
            element_type = deepcopy(get_base_type(array_symbol.dtype))
            element_type.const = False
            constant = TypedSymbol(array_symbol.name + str(idx), element_type)
            definitions.append(ast.SympyAssignment(constant, access))
            replacements[access] = constant

        defined_in_block = parent_block.symbols_defined
        for definition in sorted(definitions, key=lambda assignment: assignment.lhs.name):
            if definition.lhs not in defined_in_block:
                parent_block.insert_before(definition, ast_node)

        return expr.subs(replacements)

    if isinstance(ast_node, ast.SympyAssignment):
        ast_node.rhs = handle_sympy_expression(ast_node.rhs, ast_node.parent)
        ast_node.lhs = handle_sympy_expression(ast_node.lhs, ast_node.parent)
    elif isinstance(ast_node, ast.LoopOverCoordinate):
        # loop bounds first, then descend into the loop body
        for bound in ("start", "stop", "step"):
            setattr(ast_node, bound, handle_sympy_expression(getattr(ast_node, bound), ast_node.parent))
        substitute_array_accesses_with_constants(ast_node.body)
    else:
        for child in ast_node.args:
            substitute_array_accesses_with_constants(child)
321

Martin Bauer's avatar
Martin Bauer committed
322

Martin Bauer's avatar
Martin Bauer committed
323
324
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitute accesses to buffer fields by :class:`ast.ResolvedFieldAccess` nodes.

    Only accesses whose field is a buffer (``FieldType.is_buffer``) are transformed. Each one becomes an
    access to the buffer's data pointer at ``base_buffer_index`` plus the (single) index coordinate.
    Accesses to other fields are left untouched.

    Args:
        ast_node: AST root; the lhs/rhs of all contained ``SympyAssignment`` nodes are rewritten in place
        base_buffer_index: index expression added to every buffer access
        read_only_field_names: collection of buffer names whose data pointer is declared const

    Returns:
        None -- ``ast_node`` is modified in place

    Raises:
        RuntimeError: if a buffer access uses more than one index dimension
    """

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field

            dtype = PointerType(buffer.dtype, const=buffer.name in read_only_field_names, restrict=True)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(buffer.name)), dtype)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            # already resolved accesses must not be rebuilt
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild without re-evaluation so that the expression structure is preserved
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for child in sub_ast.args:
                visit_node(child)

    return visit_node(ast_node)
368

369

Martin Bauer's avatar
Martin Bauer committed
370
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Every ``Field.Access`` in the lhs/rhs of a ``SympyAssignment`` is replaced by an
    :class:`ast.ResolvedFieldAccess` on the field's data pointer (or on an intermediate base pointer).
    Intermediate base pointers are defined by assignments inserted before the assignment that uses
    them, unless the enclosing block already defines them.

    Args:
        ast_node: the AST root; it is modified in place
        read_only_field_names: collection of field names whose data pointers are declared const
        field_to_base_pointer_info: mapping from field name to a list of coordinate groups (see
                                    :func:`parse_base_pointer_info`). The offset of the first group is
                                    written directly into the access; every further group gets its own
                                    intermediate base pointer. Fields without an entry use a single group
                                    containing all coordinates.
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns
        None -- the transformation is applied to ``ast_node`` in place
    """
    # sorted copies give a deterministic iteration order
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            dtype = PointerType(field.dtype, const=field.name in read_only_field_names, restrict=True)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(field.name)), dtype)

            def create_coordinate_dict(group_param):
                """Map each coordinate id of the group to the value it is indexed with."""
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field.name in field_to_fixed_coordinates:
                            coordinates[e] = field_to_fixed_coordinates[field.name][e]
                        else:
                            coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            # struct fields: the single index is a member name, resolved to its offset
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # build the chain of intermediate pointers from the last group down to the second one
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group's offset goes directly into the access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            # already resolved accesses must not be rebuilt
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild without re-evaluation so that the expression structure is preserved
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for child in sub_ast.args:
                visit_node(child)

    return visit_node(ast_node)
462
463


Martin Bauer's avatar
Martin Bauer committed
464
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    Every :class:`ast.Block` below (and including) ``ast_node`` is processed. Each child of a block is hoisted into
    the outermost enclosing block that can be reached without passing a node whose ``symbols_defined`` (or, for an
    :class:`ast.Conditional`, whose condition symbols) intersect the child's ``undefined_symbols``.
    The AST is modified in place.

    If the target block already contains an assignment to the same left-hand side, the child is dropped when both
    right-hand sides are equal and kept in its original block otherwise.

    Args:
        ast_node: root of the (sub-)AST to transform
    """
    def find_block_to_move_to(node):
        """Find the outermost block that ``node`` can safely be moved to.

        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to.

        Args:
            node: node (typically a SympyAssignment) whose parent is a Block

        Returns:
            tuple ``(block_to_insert_to, child_of_block_to_insert_before)``
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                # remember the outermost block reached so far and which of its children leads down to ``node``
                last_block = element
                last_block_child = prev_element

            if isinstance(element, ast.Conditional):
                # for conditionals, the symbols of the condition are the critical ones
                critical_symbols = element.condition_expr.atoms(sp.Symbol)
            else:
                critical_symbols = element.symbols_defined
            if node.undefined_symbols.intersection(critical_symbols):
                # ``node`` depends on a symbol of this element -> cannot move any further up
                break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block):
        """Return the SympyAssignment among the direct children of ``target_block`` with the same lhs, else None."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if arg.lhs == assignment.lhs:
                return arg
        return None

    def get_blocks(node, result_list):
        """Append all Blocks in the subtree of ``node`` to ``result_list``, parents before their children."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        # the children are taken out of ``block`` and re-inserted one by one, either back into ``block`` or higher up
        # NOTE(review): assumes take_child_nodes() detaches the children but keeps their ``parent`` pointer, which
        # find_block_to_move_to relies on - confirm in astnodes
        children = block.take_child_nodes()
        for child in children:
            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
            else:
                if isinstance(child, ast.SympyAssignment):
                    exists_already = check_if_assignment_already_in_block(child, target)
                else:
                    exists_already = False

                if not exists_already:
                    target.insert_before(child, child_to_insert_before)
                elif exists_already and exists_already.rhs == child.rhs:
                    # an identical assignment is already in the target block: the child is not re-inserted
                    pass
                else:
                    block.append(child)  # don't move in this case - better would be to rename symbol
532
533


Martin Bauer's avatar
Martin Bauer committed
534
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    For every symbol group, a copy of the innermost loop is created. The copy contains only the assignments that are
    needed to compute the symbols of that group (in the original statement order). Pure (non-field) symbols of a group
    are written to temporary arrays indexed by the inner loop counter. Later groups read these arrays instead of
    recomputing the values. Each temporary array is allocated at the front of the block containing the outermost loop
    and freed at its end. The AST is modified in place.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the
                       symbol_groups and which no symbol in a symbol group depends on, are not updated!
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    inner_loop = [l for l in all_loops if l.is_innermost_loop]
    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loop[0]
    assert type(inner_loop.body) is ast.Block
    outer_loop = [l for l in all_loops if l.is_outermost_loop]
    assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loop[0]

    # pure symbol of an already processed group -> temporary-array element that replaces it
    symbols_with_temporary_array = OrderedDict()
    # lhs -> assignment of the inner loop body (assumes the body consists of SympyAssignments only - TODO confirm)
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    # field accesses and values already stored in a temporary array need not be recomputed
                    if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        for symbol in symbol_group:
            if type(symbol) is not Field.Access:
                assert type(symbol) is TypedSymbol
                new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]

        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs in symbols_resolved:
                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group:
                    # symbols of the current group are stored in their temporary array instead of a local variable
                    assert type(assignment.lhs) is TypedSymbol
                    new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
                    new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    # allocation goes to the front, deallocation to the end of the block that contains the outermost loop
    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)
600
601


Martin Bauer's avatar
Martin Bauer committed
602
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into len(cutting_points)+1 new loops that range from
    old_begin to cutting_points[0], cutting_points[0] to cutting_points[1], ..., cutting_points[-1] to old_end.
    A piece that covers exactly one iteration is emitted as a copy of the loop body with the loop counter
    substituted by its value; empty pieces are dropped.

    Modifies the ast in place

    Args:
        loop_node: loop to cut; only loops with a step of 1 are supported
        cutting_points: loop counter values at which a new piece starts

    Returns:
        list of new loop nodes

    Raises:
        NotImplementedError: if the step of ``loop_node`` is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    boundaries = list(cutting_points)
    boundaries.append(loop_node.stop)

    pieces = []
    segment_begin = loop_node.start
    for segment_end in boundaries:
        length = segment_end - segment_begin
        if length == 1:
            # a single iteration needs no loop: copy the body and fix the counter to its only value
            unrolled = deepcopy(loop_node.body)
            unrolled.subs({loop_node.loop_counter_symbol: segment_begin})
            pieces.append(unrolled)
        elif length != 0:
            pieces.append(ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                                 segment_begin, segment_end, loop_node.step))
        segment_begin = segment_end

    loop_node.parent.replace(loop_node, pieces)
    return pieces
632
633


634
635
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    Every condition is first simplified with :func:`sympy.simplify`.
    - If a condition becomes ``sympy.true``, the conditional is replaced by its true block.
    - If it becomes ``sympy.false``, the conditional is replaced by its false block, or removed if there is none.
    The AST is modified in place.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default. If islpy is not available, a warning is issued and this step is skipped.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        if conditional.condition_expr == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
        elif conditional.condition_expr == sp.false:
            conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
        elif loop_counter_simplification:
            try:
                # noinspection PyUnresolvedReferences
                from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            except ImportError:
                warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
            else:
                # only the optional import is guarded - errors raised by the analysis itself must not be
                # misreported as a missing ISLpy installation
                simplify_loop_counter_dependent_conditional(conditional)


def cleanup_blocks(node: ast.Node) -> None:
    """Curly brace removal below ``node``, done in place.

    All descendants are visited recursively (assignments are not descended into). A :class:`ast.Block` that ends up
    with at most one child and whose parent is itself a Block is spliced into the parent. Empty blocks therefore
    vanish, and single-child blocks are replaced by that child.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a snapshot, because nested blocks may splice themselves into ``node`` during recursion
    for child in list(node.args):
        cleanup_blocks(child)

    if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
673
674


Martin Bauer's avatar
Martin Bauer committed
675
def symbol_name_to_variable_name(symbol_name):
    """Turn a sympy symbol name into a name usable as C/C++ variable.

    Sympy accepts ``^`` in symbol names, C/C++ does not: every ``^`` is replaced by ``_``,
    all other characters are kept as they are.
    """
    pieces = symbol_name.split("^")
    return "_".join(pieces)
678
679


680
681
682
683
684
685
686
687
688
689
690
691
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Test the following conditions:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exact the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)

    While checking, untyped sympy symbols are replaced by :class:`TypedSymbol` s, whose names are made C-compatible
    with :func:`symbol_name_to_variable_name` (identically for reads and writes), and number constants are casted.
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        """
        Args:
            type_for_symbol: mapping from symbol name to type; the key ``'_constant'`` gives the type of numbers
            check_independence_condition: if True, fields may only be read at the offset at which they are written
        """
        self._type_for_symbol = type_for_symbol
        self._defined_pure_symbols = set()
        self._accessed_pure_symbols = set()

        # FieldAndIndex -> set of offsets at which this field component is written
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition

    def process_assignment(self, assignment):
        """Check ``assignment`` and return it as :class:`ast.SympyAssignment` with typed symbols."""
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs = self.process_expression(assignment.rhs)
        new_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(new_lhs, new_rhs)

    def process_expression(self, rhs):
        """Record all reads in ``rhs`` and return it with typed symbols and casted number constants."""
        self._update_accesses_rhs(rhs)
        if isinstance(rhs, Field.Access):
            self.fields_read.add(rhs.field)
            return rhs
        elif isinstance(rhs, TypedSymbol):
            return rhs
        elif isinstance(rhs, sp.Symbol):
            return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
        elif isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        elif isinstance(rhs, sp.Mul):
            # factors of -1/1 are kept as plain numbers instead of being casted
            new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs
        elif isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
        else:
            new_args = [self.process_expression(arg) for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs

    @property
    def fields_written(self):
        """Set of all fields that are written at least once."""
        # entries with empty offset sets are created by read lookups on the defaultdict
        return {k.field for k, v in self._field_writes.items() if v}

    def _process_lhs(self, lhs):
        """Record the write to ``lhs`` and return it as typed symbol (field accesses are returned unchanged)."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, Field.Access) or isinstance(lhs, TypedSymbol):
            return lhs
        # same name mangling as for reads in process_expression, so that definition and uses match
        return TypedSymbol(symbol_name_to_variable_name(lhs.name), self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        if isinstance(lhs, Field.Access):
            fai = self.FieldAndIndex(lhs.field, lhs.index)
            self._field_writes[fai].add(lhs.offsets)
            if len(self._field_writes[fai]) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
        elif isinstance(lhs, sp.Symbol):
            if lhs in self._defined_pure_symbols:
                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
            if lhs in self._accessed_pure_symbols:
                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
            self._defined_pure_symbols.add(lhs)

    def _update_accesses_rhs(self, rhs):
        if isinstance(rhs, Field.Access):
            # field accesses are never pure symbols, regardless of whether the independence check is enabled
            if self.check_independence_condition:
                writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
                # _update_accesses_lhs already guarantees at most one write offset per field component
                assert len(writes) <= 1
                for write_offset in writes:
                    if write_offset != rhs.offsets:
                        raise ValueError("Violation of loop independence condition. Field "
                                         "{} is read at {} and written at {}".format(rhs.field, rhs.offsets,
                                                                                     write_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self._accessed_pure_symbols.add(rhs)


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

    Additionally returns sets of all fields which are read/written

    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                         If a string (or any object without ``__getitem__``) is passed, it is used as default type
                         and the mapping is built with :func:`typing_from_sympy_inspection`.
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols

    Raises:
        ValueError: if an unsupported object (e.g. a LoopOverCoordinate) is encountered, or if one of the constraints
                    checked by :class:`KernelConstraintsCheck` is violated
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)

    def visit(obj):
        if isinstance(obj, (list, tuple)):
            return [visit(e) for e in obj]
        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            return check.process_assignment(obj)
        elif isinstance(obj, ast.Conditional):
            # The condition is read before either branch runs, so it is checked first. Otherwise a branch writing a
            # symbol that the condition reads would escape the "written after it has been read" SSA check.
            # The branches keep their previous relative order (false block before true block).
            condition = check.process_expression(obj.condition_expr)
            false_block = None if obj.false_block is None else visit(obj.false_block)
            true_block = visit(obj.true_block)
            return ast.Conditional(condition, true_block=true_block, false_block=false_block)
        elif isinstance(obj, ast.Block):
            return ast.Block([visit(e) for e in obj.args])
        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj
        else:
            raise ValueError("Invalid object in kernel " + str(type(obj)))

    typed_equations = visit(eqs)

    return check.fields_read, check.fields_written, typed_equations
809
810


Martin Bauer's avatar
Martin Bauer committed
811
def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary.

    All arguments of ``node`` are processed recursively first. Then, depending on the kind of node:

    - ``Add``, ``Mul``, ``Or``, ``And``, ``Pow`` and relational nodes:
        - The argument types are collated into a common target type.
        - Every argument of a different type is wrapped in a ``cast_func``.
        - If the target type is a pointer (only allowed for ``Add``), a ``pointer_arithmetic_func`` is built instead.
    - :class:`ast.SympyAssignment`: the rhs is casted to the type of the lhs (pointer-typed lhs not handled yet).
    - ``Piecewise``: all branch expressions are casted to their common type.
    - :class:`ast.Block` / :class:`ast.LoopOverCoordinate`: children are replaced in place and the node is returned.
    - atomic expressions, ``cast_func`` and :class:`ast.ResolvedFieldAccess` nodes are returned unchanged.

    Args:
        node: the head node of the ast

    Returns:
        modified AST
    """
    def cast(zipped_args_types, target_dtype):
        """Adds casts to the arguments if their type differs from the target type

        Args:
            zipped_args_types: a zipped list of args and types
            target_dtype: The target data type

        Returns:
            args with possible casts
        """
        casted_args = []
        for argument, data_type in zipped_args_types:
            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
                casted_args.append(cast_func(argument, target_dtype))
            else:
                casted_args.append(argument)
        return casted_args

    def pointer_arithmetic(expr_args):
        """Creates a valid pointer arithmetic function

        The argument of pointer type (at most one is allowed) becomes the base pointer. All other arguments must be
        (unsigned) integers and are summed up into the offset.

        Args:
            expr_args: Arguments of the add expression, as ``(arg, data_type)`` pairs

        Returns:
            pointer_arithmetic_func
        """
        pointer = None
        new_args = []
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None
                pointer = arg
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                new_args.append(arg)
        # NOTE(review): without any integer argument the empty list itself is passed on as offset
        new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
        return pointer_arithmetic_func(pointer, new_args)

    if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
        return node
    args = []
    for arg in node.args:
        args.append(insert_casts(arg))
    # TODO indexed, LoopOverCoordinate
    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        target = collate_types(types)
        zipped = list(zip(args, types))
        if target.func is PointerType:
            assert node.func is sp.Add
            return pointer_arithmetic(zipped)
        else:
            return node.func(*cast(zipped, target))
    elif node.func is ast.SympyAssignment:
        lhs = args[0]
        rhs = args[1]
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return node.func(*args)  # TODO fix, not complete
        else:
            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
    elif node.func is ast.ResolvedFieldAccess:
        return node
    elif node.func is ast.Block:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is ast.LoopOverCoordinate:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is sp.Piecewise:
        # args are the (already processed) expression/condition pairs; only the expressions are casted
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        zipped = list(zip(expressions, types))
        casted_expressions = cast(zipped, target)
        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]

    return node.func(*args)


900
901
902
903
904
905
906
907
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element

    Each loop from the (single) innermost loop outwards is cut at ``stop - 1``, which peels off its last
    iteration. Conditionals that thereby become always true/false are then removed, using the loop counter
    analysis of :func:`simplify_conditionals`. Finally, constants are hoisted and redundant blocks are cleaned up.
    The kernel is modified in place.

    Args:
        function_node: kernel whose AST contains exactly one innermost loop
    """
    innermost_candidates = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_candidates) == 1, "Transformation works only on kernels with exactly one inner loop"
    innermost = innermost_candidates[-1]

    # walk outwards, starting with the innermost loop itself
    for enclosing_loop in parents_of_type(innermost, ast.LoopOverCoordinate, include_current=True):
        cut_loop(enclosing_loop, [enclosing_loop.stop - 1])

    # drop now-trivial conditionals, then flatten the resulting blocks
    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    cleanup_blocks(function_node.body)

    # hoist iteration-independent code, then flatten again
    move_constants_before_loop(function_node.body)
    cleanup_blocks(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
917
918
919
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
920
def typing_from_sympy_inspection(eqs, default_type="double"):
    """Build a default mapping from symbol name to type.

    A symbol whose assigned value is a sympy Boolean expression is typed as ``'bool'``; every other name maps to
    ``default_type`` (usually ``'double'``). AST nodes contained in ``eqs`` are not inspected.

    Args:
        eqs: list of equations
        default_type: the type for non-boolean symbols

    Returns:
        dictionary (a ``defaultdict``), mapping symbol name to type
    """
    symbol_types = defaultdict(lambda: default_type)
    for eq in eqs:
        if isinstance(eq, ast.Node):
            continue
        rhs = eq.rhs
        # A bare symbol on the rhs gives no information about the type of the lhs without further context,
        # so the lhs keeps the default type in that case.
        assigned_boolean_expression = isinstance(rhs, Boolean) and not isinstance(rhs, sp.Symbol)
        if assigned_boolean_expression:
            symbol_types[eq.lhs.name] = "bool"
    return symbol_types


Martin Bauer's avatar
Martin Bauer committed
942
def get_next_parent_of_type(node, parent_type):
    """Return the closest ancestor of ``node`` that is an instance of ``parent_type``.

    The search starts at ``node.parent`` (``node`` itself is not considered) and follows the ``parent`` links
    upwards. None is returned if the root is reached without a match.
    """
    candidate = node.parent
    while candidate is not None and not isinstance(candidate, parent_type):
        candidate = candidate.parent
    return candidate


956
def parents_of_type(node, parent_type, include_current=False):
    """Lazily yield all ancestors of ``node`` that are instances of ``parent_type``, innermost first.

    Args:
        node: start node
        parent_type: type (or tuple of types) to filter for
        include_current: if True, ``node`` itself is considered as well
    """
    def _ancestor_chain(current):
        # ``parent`` is read only when the next element is requested, so callers may modify the tree between steps
        while current is not None:
            yield current
            current = current.parent

    first = node if include_current else node.parent
    yield from (ancestor for ancestor in _ancestor_chain(first) if isinstance(ancestor, parent_type))


Martin Bauer's avatar
Martin Bauer committed
965
def get_optimal_loop_ordering(fields):
    """
    Determines the optimal loop order for a given set of fields.

    All fields must have the same number of spatial dimensions and the same memory layout; the common layout is the
    loop order.

    Args:
        fields: sequence of fields

    Returns:
        list of coordinate ids, where the first list entry should be the outermost loop

    Raises:
        ValueError: if the fields differ in their number of spatial dimensions or in their layout
    """
    assert len(fields) > 0
    reference_dimensions = next(iter(fields)).spatial_dimensions
    if any(f.spatial_dimensions != reference_dimensions for f in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                         + str({f.name: f.spatial_shape for f in fields}))

    distinct_layouts = {f.layout for f in fields}
    if len(distinct_layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " +
                         str({f.name: f.layout for f in fields}))
    (layout,) = distinct_layouts
    return list(layout)
989
990


Martin Bauer's avatar
Martin Bauer committed
991
def get_loop_hierarchy(ast_node):
992
993
994
995
    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.

    Returns:
        sequence of LoopOverCoordinate nodes, starting from outer loop to innermost loop
Martin Bauer's avatar
Martin Bauer committed
996
    """
997
    result = []
Martin Bauer's avatar
Martin Bauer committed
998
    node = ast_node
999
    while node is not None:
Martin Bauer's avatar
Martin Bauer committed
1000
        node = get_next_parent_of_type(node, ast.LoopOverCoordinate)