transformations.py 41.4 KB
Newer Older
1
from collections import defaultdict, OrderedDict
2
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
3
from types import MappingProxyType
4
5
import sympy as sp
from sympy.logic.boolalg import Boolean
6
from sympy.tensor import IndexedBase
7
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.field import Field, FieldType, offset_component_to_direction_string
Martin Bauer's avatar
Martin Bauer committed
9
10
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
    pointer_arithmetic_func, get_type_of_expression, collate_types
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
12
import pystencils.astnodes as ast
13
14


Martin Bauer's avatar
Martin Bauer committed
15
def filtered_tree_iteration(node, node_type):
    """Lazily walk all descendants of ``node`` (reached via ``.args``) in depth-first pre-order.

    Every descendant that is an instance of ``node_type`` is yielded; ``node`` itself is never yielded.
    A matching child is yielded before anything found in its own subtree.
    """
    for child in node.args:
        if isinstance(child, node_type):
            yield child
        for descendant in filtered_tree_iteration(child, node_type):
            yield descendant
20
21


Martin Bauer's avatar
Martin Bauer committed
22
def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
    ValueError is raised.

    Args:
        field_set: collection of Fields (anything iterable; it is iterated exactly once)

    Returns:
        the common ``spatial_shape``. If all fields have a fixed shape, this is the single shape they share.
        If all fields have a variable shape, the shapes are not compared; the one whose first entry has the
        smallest string representation is returned, so the choice is deterministic.

    Raises:
        ValueError: if no fields are given, if fixed-shaped and variable-shaped fields are mixed,
                    or if fixed-shaped fields have different shapes
    """
    fields = list(field_set)
    if not fields:
        # previously this fell through to the fixed-shape check and reported a misleading
        # "Differently sized field accesses in loop body: set()"
        raise ValueError("Cannot determine a common shape: no fields given")

    fixed_fields = [f for f in fields if f.has_fixed_shape]
    if fixed_fields and len(fixed_fields) != len(fields):
        fixed_field_names = ",".join(f.name for f in fixed_fields)
        var_field_names = ",".join(f.name for f in fields if not f.has_fixed_shape)
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in fields}
    # at this point fields are either all fixed-shaped or all variable-shaped
    if fixed_fields and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # min() returns the first of equally-keyed candidates, exactly like sorted(...)[0]
    return min(shape_set, key=lambda shape: str(shape[0]))


Martin Bauer's avatar
Martin Bauer committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Wraps ``body`` into nested :class:`LoopOverCoordinate` nodes and a CPU :class:`KernelFunction`.

    The :class:`pystencils.field.Field.Access` objects found in ``body`` determine the loop ordering
    (unless given) and the iteration extent.

    Args:
        body: block of nodes that becomes the innermost loop body
        function_name: name of generated C function
        iteration_slice: if not None, iteration is done only over this slice of the field; non-slice entries
                         fix the corresponding coordinate (an assignment to its loop counter is inserted instead
                         of a loop)
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers, or a single
                      int used for all coordinates. If None, the maximum ``required_ghost_layers`` of all field
                      accesses (buffer accesses included) is used for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple ``(kernel_function, loop_strides, loop_vars)``: the :class:`KernelFunction` node, ``(stop - start) / step``
        of every created loop, and the loop counter symbols multiplied by the number of buffer accesses.
        Both lists are ordered from the innermost to the outermost loop.
    """
    accesses = body.atoms(Field.Access)
    # exclude accesses to buffers, because buffers are treated separately
    non_buffer_fields = [access.field for access in accesses if not FieldType.is_buffer(access.field)]
    fields = set(non_buffer_fields)
    buffer_access_count = len(accesses) - len(non_buffer_fields)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(list(fields))

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        needed = max([access.required_ghost_layers for access in accesses])
        ghost_layers = [(needed, needed)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    loop_strides = []
    counter_symbols = []
    current_body = body
    for coordinate in reversed(loop_order):
        if iteration_slice is None:
            start = ghost_layers[coordinate][0]
            stop = shape[coordinate] - ghost_layers[coordinate][1]
            step = 1
        else:
            component = iteration_slice[coordinate]
            if type(component) is not slice:
                # fixed coordinate: no loop, the loop counter is simply assigned the given value
                counter = ast.LoopOverCoordinate.get_loop_counter_symbol(coordinate)
                current_body.insert_front(ast.SympyAssignment(counter, sp.sympify(component)))
                continue
            start, stop, step = component.start, component.stop, component.step

        loop = ast.LoopOverCoordinate(current_body, coordinate, start, stop, step)
        current_body = ast.Block([loop])
        loop_strides.append((stop - start) / step)
        counter_symbols.append(loop.loop_counter_symbol)

    loop_vars = [buffer_access_count * symbol for symbol in counter_symbols]
    kernel = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
    return kernel, loop_strides, loop_vars
112
113


Martin Bauer's avatar
Martin Bauer committed
114
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Integer offsets of spatial coordinates are encoded as direction strings (``C`` for a zero offset), and
    integer index values as ``_<value>`` (a minus sign is written as ``m`` to keep a valid C identifier).
    All remaining symbolic parts are encoded by a hex digest that is identical across interpreter runs.

    :param field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
    :param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
    :param previous_ptr: the pointer which is de-referenced
    :return: tuple with the new pointer symbol and the calculated offset

    Example:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    import hashlib

    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []

    for coordinate_id, coordinate_value in coordinates.items():
        offset += field.strides[coordinate_id] * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
            if type(field_access.offsets[coordinate_id]) is int:
                offset_comp = offset_component_to_direction_string(coordinate_id, field_access.offsets[coordinate_id])
                name += "_"
                name += offset_comp if offset_comp else "C"
            else:
                list_to_hash.append(field_access.offsets[coordinate_id])
        else:
            if type(coordinate_value) is int:
                # '-' is not allowed in C identifiers
                name += ("_%d" % (coordinate_value,)).replace("-", "m")
            else:
                list_to_hash.append(coordinate_value)

    if len(list_to_hash) > 0:
        # The builtin hash() of sympy objects depends on str hashing, which is randomized per interpreter
        # process (PYTHONHASHSEED), so pointer names - and thus the generated code - changed between runs.
        # A digest of the printed expressions is stable across runs.
        digest = hashlib.sha256(str(tuple(list_to_hash)).encode("utf-8")).hexdigest()
        name += digest[:16].upper()

    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)

    return new_ptr, offset
162
163


Martin Bauer's avatar
Martin Bauer committed
164
def parse_base_pointer_info(base_pointer_specification, loop_order, field):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example ``[(0,), (2, 3)]`` creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are more sensibly defined dependent on the
    loop ordering. This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop coordinate,
          spatialOuter1 the loop enclosed by the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        field: field the specification refers to; only ``spatial_dimensions`` and ``index_dimensions`` are used

    Returns:
        list of coordinate groups that can be passed to :func:`resolve_field_accesses`. Coordinates not mentioned
        in any group are collected, in ascending order, in an additional last group.

    Raises:
        ValueError: if an element cannot be parsed, refers to a non-existing coordinate,
                    or a coordinate is specified more than once
    """
    result = []
    specified_coordinates = set()
    num_coordinates = field.spatial_dimensions + field.index_dimensions
    inner_to_outer = list(reversed(loop_order))

    def add_new_element(group, elem):
        if not 0 <= elem < num_coordinates:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        group.append(elem)
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                spatial_spec = element[len("spatial"):]
                if spatial_spec.startswith("Inner"):
                    index = int(spatial_spec[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif spatial_spec.startswith("Outer"):
                    index = int(spatial_spec[len("Outer"):])
                    # inner_to_outer[-1] is the outermost loop, so spatialOuter0 maps to -1
                    # (the former loop_order[-index] mapped spatialOuter0 to the innermost loop)
                    add_new_element(new_group, inner_to_outer[-1 - index])
                elif spatial_spec == "all":
                    for i in range(field.spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, field.spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    rest = set(range(num_coordinates)) - specified_coordinates
    if rest:
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
233
234
def substitute_array_accesses_with_constants(ast_node):
    """Replaces array accesses (``sp.Indexed``) that are not resolved field accesses by constants.

    For every such access a new ``TypedSymbol`` (array symbol name followed by the index) is introduced and
    defined by an assignment inserted in front of ``ast_node`` in its parent block, unless that block already
    defines the symbol.

    Benchmarks showed that using an array access as loop bound or in pointer computations causes some compilers
    to do fewer optimizations.
    This transformation should run after field accesses have been resolved (since they introduce array accesses)
    and before constants are moved before the loops.

    Args:
        ast_node: node to transform in place. Sympy assignments and loop bounds are rewritten; loop bodies and
                  the children of all other nodes are processed recursively.
    """

    def replace_indexed(expr, parent_block):
        """Returns ``expr`` with its array accesses replaced by constants; the defining assignments are
        inserted into ``parent_block`` before ``ast_node``."""
        if not isinstance(expr, sp.Expr):
            return expr

        # all indexed expressions that are not field accesses
        array_accesses = [candidate for candidate in expr.atoms(sp.Indexed)
                          if not isinstance(candidate, ast.ResolvedFieldAccess)]

        # an expression consisting of a single array access is left as it is
        if len(array_accesses) == 1 and expr == array_accesses[0]:
            return expr

        definitions = []
        replacements = {}
        for access in array_accesses:
            array_base, index = access.args
            array_symbol = array_base.args[0]
            # the new constant gets a non-const copy of the array's base type
            element_type = deepcopy(get_base_type(array_symbol.dtype))
            element_type.const = False
            constant = TypedSymbol(array_symbol.name + str(index), element_type)
            definitions.append(ast.SympyAssignment(constant, access))
            replacements[access] = constant

        defined_symbols = parent_block.symbols_defined
        for definition in sorted(definitions, key=lambda assignment: assignment.lhs.name):
            if definition.lhs in defined_symbols:
                continue
            parent_block.insert_before(definition, ast_node)

        return expr.subs(replacements)

    if isinstance(ast_node, ast.SympyAssignment):
        ast_node.rhs = replace_indexed(ast_node.rhs, ast_node.parent)
        ast_node.lhs = replace_indexed(ast_node.lhs, ast_node.parent)
    elif isinstance(ast_node, ast.LoopOverCoordinate):
        ast_node.start = replace_indexed(ast_node.start, ast_node.parent)
        ast_node.stop = replace_indexed(ast_node.stop, ast_node.parent)
        ast_node.step = replace_indexed(ast_node.step, ast_node.parent)
        substitute_array_accesses_with_constants(ast_node.body)
    else:
        for child in ast_node.args:
            substitute_array_accesses_with_constants(child)
284

Martin Bauer's avatar
Martin Bauer committed
285

Martin Bauer's avatar
Martin Bauer committed
286
287
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitutes :class:`pystencils.field.Field.Access` nodes of buffer fields by array indexing.

    Only accesses to fields for which ``FieldType.is_buffer`` holds are replaced; all other field accesses and
    already resolved accesses are left untouched. The access index is ``base_buffer_index`` plus the first index
    of the field access, if it has one. The AST is modified in place.

    Args:
        ast_node: the AST root
        base_buffer_index: expression added to the index of every buffer access
        read_only_field_names: collection of field names whose data pointer type is created with ``const=True``

    Returns:
        None - the AST is transformed in place

    Raises:
        RuntimeError: if a buffer access uses more than one index
    """
    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field

            dtype = PointerType(buffer.dtype, const=buffer.name in read_only_field_names, restrict=True)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(buffer.name)), dtype)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild without re-evaluation so the structure of the expression is kept
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
331

332

Martin Bauer's avatar
Martin Bauer committed
333
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing.

    The AST is modified in place: the left and right hand side of every
    :class:`pystencils.astnodes.SympyAssignment` are rewritten, and assignments defining intermediate base
    pointers are inserted into the enclosing block before the assignment that uses them, unless the block
    already defines them.

    :param ast_node: the AST root
    :param read_only_field_names: collection of field names which are considered read-only; their data pointer
                                  type is created with ``const=True``
    :param field_to_base_pointer_info: mapping of field name to a list of coordinate groups indicating which
                                       intermediate base pointers should be created,
                                       for details see :func:`parse_base_pointer_info`
    :param field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                       counters to index the field these symbols are used as coordinates
    :return: None - the AST is transformed in place
    """
    # sorted copies make the processing order independent of the insertion order of the passed mappings
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            dtype = PointerType(field.dtype, const=field.name in read_only_field_names, restrict=True)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(field.name)), dtype)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field.name in field_to_fixed_coordinates:
                            coordinates[e] = field_to_fixed_coordinates[field.name][e]
                        else:
                            ctr_name = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                            coordinates[e] = TypedSymbol("%s_%d" % (ctr_name, e), 'int')
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # every group except the first becomes an intermediate pointer, created outermost group first
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group's offset is written directly into the access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild without re-evaluation so the structure of the expression is kept
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
424
425


Martin Bauer's avatar
Martin Bauer committed
426
def move_constants_before_loop(ast_node):
    """
    Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`.
    The AST is modified in place.

    :param ast_node: root of the AST to transform
    :return: None
    """
    def find_block_to_move_to(assignment):
        """
        Walks up the parents of ``assignment`` while none of them defines (or, for a conditional, tests)
        a symbol the assignment depends on, remembering the last block passed.

        :param assignment: SympyAssignment inside a Block
        :return: tuple (block to insert into, child of that block to insert before)
        """
        assert isinstance(assignment, ast.SympyAssignment)
        assert isinstance(assignment.parent, ast.Block)

        target_block = assignment.parent
        insert_before = assignment
        current = assignment.parent
        previous = assignment
        while current:
            if isinstance(current, ast.Block):
                target_block = current
                insert_before = previous

            if isinstance(current, ast.Conditional):
                blocking_symbols = current.condition_expr.atoms(sp.Symbol)
            else:
                blocking_symbols = current.symbols_defined
            if assignment.undefined_symbols.intersection(blocking_symbols):
                break

            previous = current
            current = current.parent
        return target_block, insert_before

    def assignment_with_same_lhs(assignment, block):
        """Returns the first plain SympyAssignment in ``block`` assigning ``assignment.lhs``, otherwise None."""
        matches = (arg for arg in block.args
                   if type(arg) is ast.SympyAssignment and arg.lhs == assignment.lhs)
        return next(matches, None)

    def collect_blocks(node, found):
        """Appends every Block in the subtree of ``node`` to ``found`` in depth-first pre-order."""
        if isinstance(node, ast.Block):
            found.append(node)
        if isinstance(node, ast.Node):
            for child in node.args:
                collect_blocks(child, found)

    blocks = []
    collect_blocks(ast_node, blocks)
    # process blocks in reverse pre-order
    blocks.reverse()

    for block in blocks:
        # children are removed from the block and then re-added or moved one by one
        for child in block.take_child_nodes():
            if not isinstance(child, ast.SympyAssignment):
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            existing = assignment_with_same_lhs(child, target)
            if not existing:
                target.insert_before(child, child_to_insert_before)
            else:
                # an equal assignment is already present in the target block: the duplicate is dropped
                assert existing.rhs == child.rhs, "Symbol with same name exists already"
494
495


Martin Bauer's avatar
Martin Bauer committed
496
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams.

    The AST is modified in place: the single innermost loop is replaced by a block that contains one copy of the
    loop per symbol group. Non-field symbols listed in a group are written to temporary arrays (indexed by the
    inner loop counter, with ``inner_loop.stop`` elements) so that later loops can read them. For every such
    array an allocation is inserted at the front of the outermost loop's parent block and a free is appended
    to the end of that block.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the
                       symbol groups and which no symbol in a symbol group depends on, are not updated!

    Raises:
        AssertionError: if the AST does not contain exactly one innermost and exactly one outermost loop, if the
                        inner loop body is not a Block, or if a non-field symbol of a group is not a TypedSymbol
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    inner_loop = [l for l in all_loops if l.is_innermost_loop]
    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loop[0]
    assert type(inner_loop.body) is ast.Block
    outer_loop = [l for l in all_loops if l.is_outermost_loop]
    assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loop[0]

    # maps symbol -> indexed temporary array element that stores its value for the current iteration
    symbols_with_temporary_array = OrderedDict()
    # NOTE(review): assumes every statement of the inner loop body is an assignment (has an ``lhs``)
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols
        # (transitive closure over the rhs of the loop-body assignments; symbols already stored in a temporary
        #  array by an earlier group are not followed, they are read from that array instead)
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    if type(new_symbol) is not Field.Access and new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        # the non-field symbols of this group get a temporary array, so later groups can read their values
        for symbol in symbol_group:
            if type(symbol) is not Field.Access:
                assert type(symbol) is TypedSymbol
                new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]

        # keep the original statement order; only assignments needed by this group are copied
        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs in symbols_resolved:
                # reads of symbols that own a temporary array (including this group's) go through the array
                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbol_group:
                    # group symbols are written into their temporary array instead of a scalar
                    assert type(assignment.lhs) is TypedSymbol
                    new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
                    new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    # one loop per group, all sharing the iteration range of the original inner loop
    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    # allocate the temporary arrays before, and free them after, the whole loop nest
    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        outer_loop.parent.insert_front(ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop))
        outer_loop.parent.append(ast.TemporaryMemoryFree(tmp_array_pointer))
560
561


Martin Bauer's avatar
Martin Bauer committed
562
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into (at most) len(cutting_points)+1 new nodes that range from
    old_begin to cutting_points[0], cutting_points[0] to cutting_points[1], ..., cutting_points[-1] to old_end.
    A segment covering exactly one iteration is emitted as a copy of the loop body in which the loop counter is
    substituted by its value. A segment covering no iteration is dropped. The original loop is replaced in its
    parent by the new nodes.

    Args:
        loop_node: the loop to cut; it must have a step of 1
        cutting_points: positions at which the loop is cut, expected in increasing order

    Raises:
        NotImplementedError: if the step of the loop is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = []
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        segment_length = new_end - new_start
        if segment_length == 0:
            # empty range (e.g. cutting a single-iteration loop at stop-1): no node is needed for it
            pass
        elif segment_length == 1:
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        else:
            new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                              new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
581
582


Martin Bauer's avatar
Martin Bauer committed
583
def is_condition_necessary(condition, pre_condition, symbol):
    """
    Determines if a logical condition of a single variable is already contained in a stronger pre_condition
    so if from pre_condition follows that condition is always true, then this condition is not necessary

    Args:
        condition: sympy relational of one variable
        pre_condition: logical expression that is known to be true
        symbol: the single symbol of interest

    Returns:
        returns not (pre_condition => condition) where "=>" is logical implication
    """
    from sympy.solvers.inequalities import reduce_rational_inequalities
    from sympy.logic.boolalg import to_dnf

    def normalize_relational(e):
        """Rewrites every relational ``a op b`` inside ``e`` as ``a - b op 0``."""
        if isinstance(e, sp.Rel):
            return e.func(e.lhs - e.rhs, 0)
        else:
            new_args = [normalize_relational(a) for a in e.args]
            return e.func(*new_args) if new_args else e

    def to_dnf_list(expr):
        """Converts ``expr`` to the input format of ``reduce_rational_inequalities``:
        a list (OR-ed) of lists (AND-ed) of relationals."""
        result = to_dnf(expr)
        if isinstance(result, sp.Or):
            # each disjunct is either a conjunction or a single relational
            return [list(term.args) if isinstance(term, sp.And) else [term] for term in result.args]
        elif isinstance(result, sp.And):
            return [list(result.args)]
        else:
            # a single relational still has to be wrapped as one conjunction inside one disjunction
            return [[result]]

    condition = normalize_relational(condition)
    pre_condition = normalize_relational(pre_condition)
    # the condition is redundant iff adding it to the pre-condition does not change the solution set
    t1 = reduce_rational_inequalities(to_dnf_list(sp.And(condition, pre_condition)), symbol)
    t2 = reduce_rational_inequalities(to_dnf_list(pre_condition), symbol)
    return t1 != t2


Martin Bauer's avatar
Martin Bauer committed
624
def simplify_boolean_expression(expr, single_variable_ranges):
    """Simplify a boolean expression using known ranges of single variables.

    ``single_variable_ranges`` maps a symbol to a sympy logical expression that involves only that symbol and
    bounds its range, e.g. ``{a: sp.And(a >= 0, a < 10)}`` for a symbol ``a``.
    The expression is first converted to disjunctive normal form. Afterwards every relational that mentions
    exactly one of the range-restricted symbols, and that is already implied by that symbol's range, is
    replaced by ``sp.true``.
    """
    from sympy.core.relational import Relational
    from sympy.logic.boolalg import to_dnf

    known_symbols = single_variable_ranges.keys()

    def simplify(sub_expr):
        if not isinstance(sub_expr, Relational):
            rebuilt_args = [simplify(arg) for arg in sub_expr.args]
            if not rebuilt_args:
                return sub_expr
            return sub_expr.func(*rebuilt_args)

        involved = sub_expr.atoms(sp.Symbol).intersection(known_symbols)
        if len(involved) != 1:
            return sub_expr
        (range_symbol,) = involved
        if is_condition_necessary(sub_expr, single_variable_ranges[range_symbol], range_symbol):
            return sub_expr
        return sp.true

    return simplify(to_dnf(expr))


Martin Bauer's avatar
Martin Bauer committed
650
def simplify_conditionals(node, loop_conditionals=MappingProxyType({})):
    """Simplifies/Removes conditions inside loops that depend on the loop counter.

    Walks the AST below ``node``. For every loop encountered, the range of its counter
    (``start <= counter < stop``) is recorded. The conditions of ``Conditional`` nodes are simplified with all
    ranges recorded so far. A conditional whose condition becomes ``true`` is replaced by its true block. One
    whose condition becomes ``false`` is replaced by its false block, or removed when it has none. The AST is
    modified in place.

    Args:
        node: AST node to process (loop, conditional, block or assignment)
        loop_conditionals: mapping from loop counter symbol to the range condition of that counter, collected from
                           the enclosing loops; it is not modified (nested loops extend a copy)

    Raises:
        ValueError: for node types other than the ones listed above
    """
    if isinstance(node, ast.LoopOverCoordinate):
        ctr_sym = node.loop_counter_symbol
        loop_conditionals = loop_conditionals.copy()
        loop_conditionals[ctr_sym] = sp.And(ctr_sym >= node.start, ctr_sym < node.stop)
        simplify_conditionals(node.body, loop_conditionals)
    elif isinstance(node, ast.Conditional):
        node.condition_expr = simplify_boolean_expression(node.condition_expr, loop_conditionals)
        # both branches still lie inside the same loops, so nested conditionals can use the same ranges
        simplify_conditionals(node.true_block, loop_conditionals)
        if node.false_block:
            simplify_conditionals(node.false_block, loop_conditionals)
        if node.condition_expr == sp.true:
            node.parent.replace(node, [node.true_block])
        elif node.condition_expr == sp.false:
            node.parent.replace(node, [node.false_block] if node.false_block else [])
    elif isinstance(node, ast.Block):
        # iterate over a snapshot: simplified conditionals replace themselves in this block
        for a in list(node.args):
            simplify_conditionals(a, loop_conditionals)
    elif isinstance(node, ast.SympyAssignment):
        return node
    else:
        raise ValueError("Can not handle node", type(node))


Martin Bauer's avatar
Martin Bauer committed
675
def cleanup_blocks(node):
    """Curly brace removal: recursively drop empty blocks and inline blocks that have a single child.

    Only nested blocks (blocks whose parent is itself a Block) are spliced into their parent. Assignments are
    left untouched. The AST is modified in place.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a snapshot: child blocks may splice themselves into this block
    for child in list(node.args):
        cleanup_blocks(child)
    if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
688
689


Martin Bauer's avatar
Martin Bauer committed
690
def symbol_name_to_variable_name(symbol_name):
    """Return ``symbol_name`` with every ``^`` turned into ``_``.

    ``^`` may appear in sympy symbol names but is not allowed in C/C++ variable names.
    """
    return "_".join(symbol_name.split("^"))
693
694


Martin Bauer's avatar
Martin Bauer committed
695
def type_all_equations(eqs, type_for_symbol):
    """
    Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
    Additionally returns sets of all fields which are read/written

    Symbol names are converted with :func:`symbol_name_to_variable_name` on both sides of an assignment, so a
    symbol is defined under the same (C-compatible) name it is read with.

    :param eqs: list of equations (``sp.Eq``, ``Assignment`` or ``ast.SympyAssignment``), possibly nested in
                ``ast.Conditional`` / ``ast.Block`` nodes; other objects are passed through unchanged
    :param type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                            If a string (or any non-subscriptable object) is passed instead, it is used as the
                            default type of :func:`typing_from_sympy_inspection`
    :return: ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
              list of equations where symbols have been replaced by typed symbols
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    fields_written = set()
    fields_read = set()

    def process_rhs(term):
        """Replaces Symbols by:
            - TypedSymbol if symbol is not a field access
        and records every accessed field in fields_read
        """
        if isinstance(term, Field.Access):
            fields_read.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return TypedSymbol(symbol_name_to_variable_name(term.name), type_for_symbol[term.name])
        else:
            new_args = [process_rhs(arg) for arg in term.args]
            return term.func(*new_args) if new_args else term

    def process_lhs(term):
        """Replaces symbol by TypedSymbol and adds field to fields_written"""
        if isinstance(term, Field.Access):
            fields_written.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            # same name conversion as in process_rhs; otherwise a symbol containing '^' would be defined
            # under a different name than the one it is read with
            return TypedSymbol(symbol_name_to_variable_name(term.name), type_for_symbol[term.name])
        else:
            assert False, "Expected a symbol as left-hand-side"

    def visit(obj):
        if isinstance(obj, (list, tuple)):
            return [visit(e) for e in obj]
        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            new_lhs = process_lhs(obj.lhs)
            new_rhs = process_rhs(obj.rhs)
            return ast.SympyAssignment(new_lhs, new_rhs)
        elif isinstance(obj, ast.Conditional):
            false_block = None if obj.false_block is None else visit(obj.false_block)
            return ast.Conditional(process_rhs(obj.condition_expr),
                                   true_block=visit(obj.true_block), false_block=false_block)
        elif isinstance(obj, ast.Block):
            return ast.Block([visit(e) for e in obj.args])
        else:
            return obj

    typed_equations = visit(eqs)

    return fields_read, fields_written, typed_equations
757
758


Martin Bauer's avatar
Martin Bauer committed
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary

    Works bottom-up: the arguments of ``node`` are processed recursively first, then ``node`` is rebuilt from the
    processed arguments.
    - Add/Mul/Pow/logical/relational nodes get their arguments cast to the collated type of all arguments. If
      that type is a pointer (only allowed for Add), pointer arithmetic is generated instead.
    - The rhs of a ``SympyAssignment`` is cast to the type of its lhs, unless the lhs is a pointer.
    - The branch expressions of a ``Piecewise`` are cast to their collated type.
    - ``Block`` and ``LoopOverCoordinate`` nodes are updated in place and returned.

    :param node: the head node of the ast
    :return: modified ast
    """
    def cast(zipped_args_types, target_dtype):
        """
        Adds casts to the arguments if their type differs from the target type
        :param zipped_args_types: a zipped list of args and types
        :param target_dtype: The target data type
        :return: args with possible casts
        """
        casted_args = []
        for argument, data_type in zipped_args_types:
            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
                casted_args.append(cast_func(argument, target_dtype))
            else:
                casted_args.append(argument)
        return casted_args

    def pointer_arithmetic(expr_args):
        """
        Creates a valid pointer arithmetic function
        At most one argument may be of pointer type; all other arguments must be integers (checked by assertions)
        and are summed up to form the offset.
        :param expr_args: Arguments of the add expression
        :return: pointer_arithmetic_func
        """
        pointer = None
        new_args = []
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None
                pointer = arg
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                new_args.append(arg)
        # NOTE(review): without any offset argument the empty list itself is passed as offset
        new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
        return pointer_arithmetic_func(pointer, new_args)

    # atomic expressions (symbols, numbers) have no arguments that could need a cast
    if isinstance(node, sp.AtomicExpr):
        return node
    args = []
    for arg in node.args:
        args.append(insert_casts(arg))
    # TODO indexed, LoopOverCoordinate
    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        target = collate_types(types)
        zipped = list(zip(args, types))
        if target.func is PointerType:
            # a pointer operand is only supported in additions (pointer + integer offsets)
            assert node.func is sp.Add
            return pointer_arithmetic(zipped)
        else:
            return node.func(*cast(zipped, target))
    elif node.func is ast.SympyAssignment:
        lhs = args[0]
        rhs = args[1]
        # the rhs is converted to the type of the variable it is assigned to
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return node.func(*args)  # TODO fix, not complete
        else:
            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
    elif node.func is ast.ResolvedFieldAccess:
        # returned unchanged; the recursively processed args computed above are discarded
        return node
    elif node.func is ast.Block:
        # AST containers are updated in place instead of being rebuilt
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is ast.LoopOverCoordinate:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is sp.Piecewise:
        # cast all branch values to a common type; the conditions are kept as they are
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        zipped = list(zip(expressions, types))
        casted_expressions = cast(zipped, target)
        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]

    # everything else (including Piecewise after the cast above) is rebuilt from the processed arguments
    return node.func(*args)


845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Remove the conditionals of a kernel that iterates over staggered positions.

    Every loop around the single innermost loop, including the innermost loop itself, is cut at ``stop - 1``. Its
    last iteration thus becomes a separate copy of the loop body. Afterwards ``simplify_conditionals``,
    ``cleanup_blocks``, ``move_constants_before_loop`` and again ``cleanup_blocks`` are run on the kernel body.
    ``function_node`` is modified in place.
    """
    innermost_loops = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"

    innermost = innermost_loops[0]
    enclosing_loops = parents_of_type(innermost, ast.LoopOverCoordinate, include_current=True)
    for loop in enclosing_loops:
        last_iteration = loop.stop - 1
        cut_loop(loop, [last_iteration])

    for transformation in (simplify_conditionals, cleanup_blocks, move_constants_before_loop, cleanup_blocks):
        transformation(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
861
862
863
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
864
def typing_from_sympy_inspection(eqs, default_type="double"):
    """
    Creates a default symbol name to type mapping.
    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')

    Plain equations/assignments and ``ast.SympyAssignment`` nodes are inspected. Assignments nested in
    ``ast.Conditional`` (both branches) and ``ast.Block`` nodes are inspected as well; other AST nodes are skipped.

    Args:
        eqs: list of equations
        default_type: the type for non-boolean symbols
    Returns:
        dictionary, mapping symbol name to type
    """
    result = defaultdict(lambda: default_type)

    def inspect(statements):
        for eq in statements:
            if isinstance(eq, ast.Conditional):
                branches = [eq.true_block] + ([eq.false_block] if eq.false_block else [])
                inspect(branches)
                continue
            if isinstance(eq, ast.Block):
                inspect(eq.args)
                continue
            if isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
                continue
            # problematic case here is when rhs is a symbol: then it is impossible to decide here without
            # further information what type the left hand side is - default fallback is the dict value then
            lhs_name = getattr(eq.lhs, "name", None)  # e.g. indexed lhs have no name and cannot be typed here
            if lhs_name is not None and isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
                result[lhs_name] = "bool"

    inspect(eqs)
    return result


Martin Bauer's avatar
Martin Bauer committed
886
def get_next_parent_of_type(node, parent_type):
    """Return the closest ancestor of ``node`` that is an instance of ``parent_type``.

    ``node`` itself is not considered. The chain of ``parent`` attributes is followed up to the root, and
    ``None`` is returned when no ancestor matches.
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parent_type):
        ancestor = ancestor.parent
    return ancestor


898
899
900
901
902
903
904
905
906
907
def parents_of_type(node, parent_type, include_current=False):
    """Generator version of get_next_parent_of_type: yield every ancestor of ``node`` that is a ``parent_type``.

    Ancestors are yielded from the closest one up to the root. With ``include_current=True`` the search starts
    at ``node`` itself instead of at its parent.
    """
    current = node if include_current else node.parent
    while current is not None:
        if isinstance(current, parent_type):
            yield current
        current = current.parent


Martin Bauer's avatar
Martin Bauer committed
908
def get_optimal_loop_ordering(fields):
    """
    Determine the optimal loop order for a set of fields, which is their common memory layout.

    :param fields: non-empty sequence of fields; all of them must have the same number of spatial dimensions and
                   the same layout, otherwise a ValueError is raised
    :return: list of coordinate ids, where the first list entry should be the outermost loop
    """
    assert len(fields) > 0
    reference_dims = next(iter(fields)).spatial_dimensions
    if any(f.spatial_dimensions != reference_dims for f in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                         + str({f.name: f.spatial_shape for f in fields}))

    distinct_layouts = {f.layout for f in fields}
    if len(distinct_layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " +
                         str({f.name: f.layout for f in fields}))
    (layout,) = distinct_layouts
    return list(layout)
928
929


Martin Bauer's avatar
Martin Bauer committed
930
def get_loop_hierarchy(ast_node):
    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.

    Returns:
        iterator over the ``coordinate_to_loop_over`` values of the enclosing LoopOverCoordinate nodes, starting
        from the outer loop to the innermost loop (``ast_node`` itself is not considered)
    """
    coordinates = [loop.coordinate_to_loop_over
                   for loop in parents_of_type(ast_node, ast.LoopOverCoordinate)]
    # the parents are visited from the inside out; reverse to get the outermost loop first
    return reversed(coordinates)