transformations.py 40.6 KB
Newer Older
1
import warnings
2
from collections import defaultdict, OrderedDict
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
from types import MappingProxyType
5
6
import sympy as sp
from sympy.logic.boolalg import Boolean
7
from sympy.tensor import IndexedBase
8
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.field import Field, FieldType, offset_component_to_direction_string
Martin Bauer's avatar
Martin Bauer committed
10
11
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
    pointer_arithmetic_func, get_type_of_expression, collate_types
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
13
import pystencils.astnodes as ast
14
15


Martin Bauer's avatar
Martin Bauer committed
16
def filtered_tree_iteration(node, node_type):
    """Yield all descendants of ``node`` that are instances of ``node_type``.

    The tree is walked depth-first in pre-order via each node's ``args`` attribute; ``node`` itself is never
    yielded. An explicit stack of iterators replaces recursion, so deep trees do not hit the recursion limit.
    """
    pending = [iter(node.args)]
    while pending:
        for child in pending[-1]:
            if isinstance(child, node_type):
                yield child
            # descend into the child before continuing with its siblings (pre-order)
            pending.append(iter(child.args))
            break
        else:
            # current level exhausted -> go back up one level
            pending.pop()
21
22


23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def unify_shape_symbols(body, common_shape, fields):
    """Make all variable-shaped fields use the symbols of ``common_shape`` for their spatial sizes.

    A kernel over variable-sized arrays requires all arrays to have the same size (this is ensured when the
    kernel is called), so inside the kernel only one symbol per dimension is needed, e.g. ``shape_arr1[0]`` and
    ``shape_arr2[0]`` can both be represented by the same symbol.

    Args:
        body: ast node of the kernel part where the substitutions are made; ``body.subs`` is called with the
              collected replacements (only when there is at least one)
        common_shape: shape of the field that was chosen as reference
        fields: all fields whose shapes should be replaced by ``common_shape``; fixed-shape fields are skipped
    """
    replacements = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if field.has_fixed_shape:
            continue
        replacements.update(
            (own, shared) for shared, own in zip(common_shape, field.spatial_shape) if own != shared
        )
    if replacements:
        body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
47
def get_common_shape(field_set):
    """Return the common spatial shape of a collection of pystencils Fields.

    All fields must be either fixed-shaped or variable-shaped. For fixed-shaped fields all spatial shapes must be
    identical. For variable-shaped fields no equality check is performed here; the shape whose first component has
    the smallest string representation is returned.

    Raises:
        ValueError: if fixed- and variable-shaped fields are mixed, or fixed-shaped fields differ in shape
    """
    fixed_fields = [f for f in field_set if f.has_fixed_shape]
    variable_fields = [f for f in field_set if not f.has_fixed_shape]

    if fixed_fields and variable_fields:
        fixed_field_names = ",".join([f.name for f in fixed_fields])
        var_field_names = ",".join([f.name for f in variable_fields])
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if not variable_fields and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # deterministic pick among (possibly differently named) variable shapes
    return min(shape_set, key=lambda candidate: str(candidate[0]))


Martin Bauer's avatar
Martin Bauer committed
71
72
73
74
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Wrap ``body`` into nested loops derived from its :class:`pystencils.field.Field.Access` nodes.

    Args:
        body: Block object with inner loop contents; its shape symbols are unified in place
        function_name: name of generated C function
        iteration_slice: if not None, iteration is done only over this slice of the field. Components that are
                         not ``slice`` objects create no loop; instead an assignment fixing the loop counter of that
                         coordinate is inserted at the front of the current body
        ghost_layers: a sequence of (lower, upper) pairs, one per coordinate, or a single int used for both sides
                      of every coordinate. If None, the maximum ``required_ghost_layers`` of all field accesses in
                      ``body`` (buffer accesses included) is used on both sides of every coordinate
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout); if None it is
                    determined by ``get_optimal_loop_ordering``

    Returns:
        tuple ``(kernel_function, loop_strides, loop_vars)``: the :class:`KernelFunction` wrapping the loop nest,
        the iteration counts ``(stop - start) / step`` of the created loops, and their loop counter symbols
        multiplied by the number of buffer accesses. Both lists are ordered from innermost to outermost loop.
    """
    accesses = body.atoms(Field.Access)
    # buffer accesses do not take part in shape / loop-order deduction, they are treated separately
    non_buffer_fields = [acc.field for acc in accesses if not FieldType.is_buffer(acc.field)]
    fields = set(non_buffer_fields)
    buffer_access_count = len(accesses) - len(non_buffer_fields)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        needed = max([acc.required_ghost_layers for acc in accesses])
        ghost_layers = [(needed, needed)] * len(loop_order)
    elif isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    loop_strides = []
    counters = []
    inner = body
    for coordinate in reversed(loop_order):  # build the nest from the innermost loop outwards
        if iteration_slice is None:
            start = ghost_layers[coordinate][0]
            stop = shape[coordinate] - ghost_layers[coordinate][1]
            step = 1
        else:
            component = iteration_slice[coordinate]
            if type(component) is not slice:
                # fixed coordinate: no loop, just pin the loop counter to the given value
                counter = ast.LoopOverCoordinate.get_loop_counter_symbol(coordinate)
                inner.insert_front(ast.SympyAssignment(counter, sp.sympify(component)))
                continue
            start, stop, step = component.start, component.stop, component.step

        loop = ast.LoopOverCoordinate(inner, coordinate, start, stop, step)
        inner = ast.Block([loop])
        loop_strides.append((stop - start) / step)
        counters.append(loop.loop_counter_symbol)

    loop_vars = [buffer_access_count * sym for sym in counters]
    kernel = ast.KernelFunction(inner, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
    return kernel, loop_strides, loop_vars
138
139


Martin Bauer's avatar
Martin Bauer committed
140
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Build a partially resolved base pointer for a field access.

    Elements of structured arrays are addressed as :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`, with
    :math:`c_i` the coordinate value and :math:`s_i` the stride of coordinate :math:`i`. Splitting this sum into
    parts allows pulling some of them in front of loops. This function computes the part belonging to the
    coordinates :math:`i \in \mbox{coordinates}` (spatial coordinates additionally include the access offset).

    The name of the returned symbol is ``previous_ptr.name`` plus a suffix: ``_<direction>`` (``_C`` for a zero
    offset) per spatial coordinate with an integer offset, ``_<value>`` per index coordinate with an integer
    value, and a hex hash over all remaining symbolic offsets / values.

    :param field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
    :param coordinates: mapping of coordinate ids to values; stride*value is accumulated for each
    :param previous_ptr: the pointer which is de-referenced; its dtype is reused for the new pointer
    :return: tuple with the new pointer symbol and the calculated offset

    Example:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    field = field_access.field
    offset = 0
    suffix_parts = []
    hashed_parts = []

    for coord, value in coordinates.items():
        stride = field.strides[coord]
        offset += stride * value

        if coord >= field.spatial_dimensions:
            # index coordinate: integer values are spelled out, symbolic ones go into the hash
            if type(value) is int:
                suffix_parts.append("_%d" % (value,))
            else:
                hashed_parts.append(value)
            continue

        access_offset = field_access.offsets[coord]
        offset += stride * access_offset
        if type(access_offset) is int:
            direction = offset_component_to_direction_string(coord, access_offset)
            suffix_parts.append("_" + (direction if direction else "C"))
        else:
            hashed_parts.append(access_offset)

    suffix = "".join(suffix_parts)
    if hashed_parts:
        suffix += "%0.6X" % (abs(hash(tuple(hashed_parts))))

    return TypedSymbol(previous_ptr.name + suffix, previous_ptr.dtype), offset
188
189


Martin Bauer's avatar
Martin Bauer committed
190
def parse_base_pointer_info(base_pointer_specification, loop_order, field):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example ``[[0], [2, 3]]`` creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are more sensibly defined dependent on the
    loop ordering. This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        field: field whose ``spatial_dimensions`` / ``index_dimensions`` bound the valid coordinates

    Returns:
        list of coordinate lists that can be passed to :func:`resolve_field_accesses`. Coordinates not mentioned
        in any group are appended (sorted) as a final group.

    Raises:
        ValueError: for unknown specifications, coordinates out of range, or coordinates specified twice
    """
    total_dimensions = field.spatial_dimensions + field.index_dimensions
    # innermost loop first, so that "spatialInner<i>" maps directly to position i
    inner_to_outer = list(reversed(loop_order))
    specified_coordinates = set()
    result = []

    def add_coordinate(group, coordinate):
        if not 0 <= coordinate < total_dimensions:
            raise ValueError("Coordinate %d does not exist" % (coordinate,))
        if coordinate in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (coordinate,))
        group.append(coordinate)
        specified_coordinates.add(coordinate)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_coordinate(new_group, element)
            elif element.startswith("spatial"):
                spatial_spec = element[len("spatial"):]
                if spatial_spec.startswith("Inner"):
                    index = int(spatial_spec[len("Inner"):])
                    add_coordinate(new_group, inner_to_outer[index])
                elif spatial_spec.startswith("Outer"):
                    index = int(spatial_spec[len("Outer"):])
                    # the outermost loop is the LAST entry of inner_to_outer; using [-index] would map
                    # spatialOuter0 to the innermost loop
                    add_coordinate(new_group, inner_to_outer[-(index + 1)])
                elif spatial_spec == "all":
                    for i in range(field.spatial_dimensions):
                        add_coordinate(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_coordinate(new_group, field.spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    rest = set(range(total_dimensions)) - specified_coordinates
    if rest:
        # sorted for a deterministic coordinate order in the last group
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
259
260
def substitute_array_accesses_with_constants(ast_node):
    """Substitutes all instances of Indexed (array accesses) that are not field accesses with constants.

    Benchmarks showed that using an array access as loop bound or in pointer computations causes some compilers to
    perform fewer optimizations.
    This transformation should run after field accesses have been resolved (since they introduce array accesses)
    and before constants are moved before the loops.

    Args:
        ast_node: AST node to transform; it and its descendants are modified in place
    """

    def handle_sympy_expression(expr, parent_block):
        """Return ``expr`` with non-field array accesses replaced by constants.

        The assignments defining these constants are inserted into ``parent_block`` before ``ast_node``, unless
        ``parent_block`` already defines the constant's symbol.
        """
        if not isinstance(expr, sp.Expr):
            return expr

        # get all indexed expressions that are not field accesses
        indexed_expressions = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)]

        # special case: the expression is a single indexed expression, then nothing has to be done
        if len(indexed_expressions) == 1 and expr == indexed_expressions[0]:
            return expr

        constants_definitions = []
        constant_substitutions = {}
        for indexed_expr in indexed_expressions:
            base, idx = indexed_expr.args
            typed_symbol = base.args[0]
            base_type = deepcopy(get_base_type(typed_symbol.dtype))
            base_type.const = False
            constant_replacing_indexed = TypedSymbol(typed_symbol.name + str(idx), base_type)
            constants_definitions.append(ast.SympyAssignment(constant_replacing_indexed, indexed_expr))
            constant_substitutions[indexed_expr] = constant_replacing_indexed

        # sort by name so that the order of the inserted definitions is deterministic
        constants_definitions.sort(key=lambda e: e.lhs.name)

        already_defined = parent_block.symbols_defined
        for new_assignment in constants_definitions:
            if new_assignment.lhs not in already_defined:
                parent_block.insert_before(new_assignment, ast_node)

        return expr.subs(constant_substitutions)

    if isinstance(ast_node, ast.SympyAssignment):
        ast_node.rhs = handle_sympy_expression(ast_node.rhs, ast_node.parent)
        ast_node.lhs = handle_sympy_expression(ast_node.lhs, ast_node.parent)
    elif isinstance(ast_node, ast.LoopOverCoordinate):
        ast_node.start = handle_sympy_expression(ast_node.start, ast_node.parent)
        ast_node.stop = handle_sympy_expression(ast_node.stop, ast_node.parent)
        ast_node.step = handle_sympy_expression(ast_node.step, ast_node.parent)
        substitute_array_accesses_with_constants(ast_node.body)
    else:
        # Iterate over a snapshot: handling a child inserts constant definitions into its parent block, i.e.
        # possibly into the very argument list being iterated here, which would shift elements under the iterator
        # and re-visit children.
        for a in list(ast_node.args):
            substitute_array_accesses_with_constants(a)
310

Martin Bauer's avatar
Martin Bauer committed
311

Martin Bauer's avatar
Martin Bauer committed
312
313
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitute accesses to buffer fields by direct indexing into the buffer's data pointer.

    Every :class:`pystencils.field.Field.Access` to a buffer field (as decided by ``FieldType.is_buffer``) inside
    the left- and right-hand sides of ``SympyAssignment`` nodes is replaced by a
    :class:`ResolvedFieldAccess` at index ``base_buffer_index`` (plus the index coordinate, if the access has
    one). Accesses to non-buffer fields are left untouched.

    Args:
        ast_node: AST root, modified in place
        base_buffer_index: index expression into the buffer for the current cell
        read_only_field_names: names of fields whose data pointer is declared const

    Raises:
        RuntimeError: if a buffer access has more than one index coordinate
    """
    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field

            dtype = PointerType(buffer.dtype, const=buffer.name in read_only_field_names, restrict=True)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(buffer.name)), dtype)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild without re-evaluating, so the structure of sums/products is kept as is
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            # snapshot: the argument list must not change underneath the iteration
            for a in list(sub_ast.args):
                visit_node(a)

    visit_node(ast_node)
357

358

Martin Bauer's avatar
Martin Bauer committed
359
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Only accesses inside the left- and right-hand sides of ``SympyAssignment`` nodes are resolved. Assignments
    defining intermediate base pointers are inserted into the enclosing block before the assignment that needs
    them (unless the block already defines that pointer).

    :param ast_node: the AST root, modified in place
    :param read_only_field_names: set of field names which are considered read-only (their data pointer is const)
    :param field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created
                                   for details see :func:`parse_base_pointer_info`
    :param field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates
    :return: None, the AST is transformed in place
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                # default: one group with all coordinates, i.e. no intermediate pointers
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            dtype = PointerType(field.dtype, const=field.name in read_only_field_names, restrict=True)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(field.name)), dtype)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field.name in field_to_fixed_coordinates:
                            coordinates[e] = field_to_fixed_coordinates[field.name][e]
                        else:
                            ctr_name = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                            coordinates[e] = TypedSymbol("%s_%d" % (ctr_name, e), 'int')
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # every group after the first becomes an intermediate pointer, starting with the last group;
            # each new pointer is defined relative to the previous one
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is written directly into the access offset
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild without re-evaluating, so the structure of sums/products is kept as is
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            # Iterate over a snapshot: resolving a child assignment inserts base-pointer definitions into its
            # enclosing block, possibly the very argument list being iterated here.
            for a in list(sub_ast.args):
                visit_node(a)

    visit_node(ast_node)
450
451


Martin Bauer's avatar
Martin Bauer committed
452
def move_constants_before_loop(ast_node):
    """Hoist :class:`pystencils.ast.SympyAssignment` nodes out of loops when they do not depend on the iteration.

    Call this after creating the loop structure with :func:`make_loop_over_domain`. Blocks are processed from
    the last one in pre-order back to the root. Each block's children are re-inserted one by one, so while an
    assignment is examined only the preceding children count as already defined. An assignment is moved to the
    outermost enclosing block reachable without crossing a node that defines (or, for a Conditional, tests) one
    of its free symbols. If the target block already contains an assignment to the same left-hand side, the moved
    assignment is dropped (after asserting both right-hand sides are equal).

    :param ast_node: AST root, modified in place
    """
    def find_block_to_move_to(node):
        """Walk up from ``node`` and return ``(target_block, child_of_target_to_insert_before)``."""
        assert isinstance(node, ast.SympyAssignment)
        assert isinstance(node.parent, ast.Block)

        target_block, insert_before = node.parent, node
        child, current = node, node.parent
        while current:
            if isinstance(current, ast.Block):
                target_block, insert_before = current, child

            if isinstance(current, ast.Conditional):
                blocking_symbols = current.condition_expr.atoms(sp.Symbol)
            else:
                blocking_symbols = current.symbols_defined
            if node.undefined_symbols.intersection(blocking_symbols):
                break
            child, current = current, current.parent
        return target_block, insert_before

    def find_existing_assignment(assignment, target_block):
        """Return the SympyAssignment in ``target_block`` with the same lhs as ``assignment``, else None."""
        return next((candidate for candidate in target_block.args
                     if type(candidate) is ast.SympyAssignment and candidate.lhs == assignment.lhs), None)

    def collect_blocks(node, found):
        if isinstance(node, ast.Block):
            found.append(node)
        if isinstance(node, ast.Node):
            for sub_node in node.args:
                collect_blocks(sub_node, found)

    blocks = []
    collect_blocks(ast_node, blocks)
    blocks.reverse()  # inner / later blocks first, root last

    for block in blocks:
        for child in block.take_child_nodes():
            if not isinstance(child, ast.SympyAssignment):
                block.append(child)
                continue

            target, anchor = find_block_to_move_to(child)
            if target == block:  # movement not possible
                target.append(child)
                continue

            existing = find_existing_assignment(child, target)
            if not existing:
                target.insert_before(child, anchor)
            else:
                assert existing.rhs == child.rhs, "Symbol with same name exists already"
520
521


Martin Bauer's avatar
Martin Bauer committed
522
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """Splits the innermost loop into several loops to reduce the number of simultaneous load/store streams.

    Symbols of a group (that are not field accesses) are stored in temporary arrays indexed by the inner loop
    counter, so later loops can read values computed by earlier ones. The arrays are allocated in front of the
    outermost loop and freed after it.

    Args:
        ast_node: AST root; it must contain exactly one innermost and exactly one outermost loop
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the groups
                       and which no symbol in a group depends on, are not updated!
    """
    loops = ast_node.atoms(ast.LoopOverCoordinate)

    innermost_loops = [loop for loop in loops if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost_loops[0]
    assert type(inner_loop.body) is ast.Block

    outermost_loops = [loop for loop in loops if loop.is_outermost_loop]
    assert len(outermost_loops) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost_loops[0]

    def temporary_array_access(symbol):
        # element of a per-symbol temporary array, indexed with the inner loop counter
        array_pointer = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return IndexedBase(array_pointer, shape=(1,))[inner_loop.loop_counter_symbol]

    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # transitively collect all symbols this group depends on; symbols already stored in temporary arrays
        # by a previous group are read from there and need not be recomputed
        pending = list(symbol_group)
        resolved = set()
        while pending:
            current = pending.pop()
            if current in resolved:
                continue
            if current in assignment_map:  # without an assignment in the loop body it is independent already
                pending.extend(dependency for dependency in assignment_map[current].rhs.atoms(sp.Symbol)
                               if type(dependency) is not Field.Access
                               and dependency not in symbols_with_temporary_array)
            resolved.add(current)

        # every non-field symbol of this group gets its own temporary array
        for symbol in symbol_group:
            if type(symbol) is Field.Access:
                continue
            assert type(symbol) is TypedSymbol
            symbols_with_temporary_array[symbol] = temporary_array_access(symbol)

        group_assignments = []
        for assignment in inner_loop.body.args:
            lhs = assignment.lhs
            if lhs not in resolved:
                continue
            new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
            writes_temporary = type(lhs) is not Field.Access and lhs in symbol_group
            if writes_temporary:
                assert type(lhs) is TypedSymbol
                new_lhs = temporary_array_access(lhs)
            else:
                new_lhs = lhs
            group_assignments.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(group_assignments)

    split_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(split_loops))

    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        outer_loop.parent.insert_front(ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop))
        outer_loop.parent.append(ast.TemporaryMemoryFree(tmp_array_pointer))
586
587


Martin Bauer's avatar
Martin Bauer committed
588
def cut_loop(loop_node, cutting_points):
    """Cuts a loop at the given cutting points.

    One loop is transformed into up to len(cutting_points)+1 new pieces that range from
    old_begin to cutting_points[0], cutting_points[0] to cutting_points[1], ..., cutting_points[-1] to old_end.
    A piece covering exactly one iteration is emitted as a copy of the loop body with the loop counter
    substituted by its value. A piece covering zero iterations is dropped.

    The AST is modified in place: ``loop_node`` is replaced in its parent by the new pieces.

    Args:
        loop_node: the loop to cut; only a step of 1 is supported
        cutting_points: sequence of cut positions, expected to be ascending and inside the loop range
                        (this is not checked here)

    Returns:
        list of the new nodes (loops and/or single-iteration bodies) that replaced ``loop_node``

    Raises:
        NotImplementedError: if the loop step is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = []
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        extent = new_end - new_start
        if extent == 1:
            # single iteration: no loop needed, insert the body with the counter replaced by its value
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        elif extent != 0:
            # zero-length pieces (e.g. a cut point equal to the loop start) would only yield an empty loop
            new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                              new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
    return new_loops
607
608


609
610
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals whose condition simplifies to a constant true/false.

    Each condition is simplified with ``sp.simplify`` and stored back on the conditional. A conditional that is
    always true is replaced by its true block. One that is always false is replaced by its false block, or
    removed entirely if there is none.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    def try_loop_counter_simplification(conditional):
        # optional step: the ISL based analysis lives in a separate module that requires islpy
        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(conditional)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")

    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        condition = conditional.condition_expr

        if condition == sp.true:
            replacement = [conditional.true_block]
        elif condition == sp.false:
            replacement = [conditional.false_block] if conditional.false_block else []
        else:
            if loop_counter_simplification:
                try_loop_counter_simplification(conditional)
            continue

        conditional.parent.replace(conditional, replacement)


def cleanup_blocks(node: ast.Node) -> None:
    """Curly brace removal: removes empty blocks and replaces blocks with a single child by that child.

    Works depth-first over the subtree rooted at ``node``. A block is only dissolved into its parent when that
    parent is itself a block.
    """
    if isinstance(node, ast.SympyAssignment):
        return  # leaf statement, nothing to clean up

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # children may dissolve themselves into this block while recursing, so iterate over a snapshot
    for child in list(node.args):
        cleanup_blocks(child)
    if len(node.args) < 2 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
648
649


Martin Bauer's avatar
Martin Bauer committed
650
def symbol_name_to_variable_name(symbol_name):
    """Maps a sympy symbol name to a valid C/C++ variable name.

    Only ``^`` (allowed in sympy symbol names, invalid in C identifiers) is handled; every occurrence becomes ``_``.
    """
    return "_".join(symbol_name.split("^"))
653
654


Martin Bauer's avatar
Martin Bauer committed
655
def type_all_equations(eqs, type_for_symbol):
    """
    Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
    Additionally returns sets of all fields which are read/written

    Symbol names are converted to valid C variable names with ``symbol_name_to_variable_name`` on both the left and
    the right hand side, so a symbol gets the same name where it is defined and where it is used. The type is
    looked up under the original sympy name.

    :param eqs: list of equations (sympy ``Eq``, ``Assignment`` or ``ast.SympyAssignment`` objects, or
                ``ast.Conditional`` / ``ast.Block`` nodes containing them; lists and tuples are traversed)
    :param type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                            If a string (or anything without ``__getitem__``) is passed instead, it is used as the
                            default type for ``typing_from_sympy_inspection``.
    :return: ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
              list of equations where symbols have been replaced by typed symbols
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    fields_written = set()
    fields_read = set()

    def to_typed_symbol(symbol):
        """Converts a plain sympy symbol into a TypedSymbol with a C compatible name.

        Shared by both sides of an assignment; previously only the right hand side sanitized the name, so a
        symbol like ``a^b`` was defined as ``a^b`` but read as ``a_b``.
        """
        return TypedSymbol(symbol_name_to_variable_name(symbol.name), type_for_symbol[symbol.name])

    def process_rhs(term):
        """Replaces Symbols by:
            - TypedSymbol if symbol is not a field access
        Field accesses are recorded in fields_read; other expressions are rebuilt from their processed arguments.
        """
        if isinstance(term, Field.Access):
            fields_read.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return to_typed_symbol(term)
        else:
            new_args = [process_rhs(arg) for arg in term.args]
            return term.func(*new_args) if new_args else term

    def process_lhs(term):
        """Replaces symbol by TypedSymbol and adds field to fields_written"""
        if isinstance(term, Field.Access):
            fields_written.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return to_typed_symbol(term)
        else:
            assert False, "Expected a symbol as left-hand-side"

    def visit(obj):
        """Rebuilds assignments, conditionals and blocks with typed symbols; other objects are returned as-is."""
        if isinstance(obj, (list, tuple)):
            return [visit(e) for e in obj]
        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            new_lhs = process_lhs(obj.lhs)
            new_rhs = process_rhs(obj.rhs)
            return ast.SympyAssignment(new_lhs, new_rhs)
        elif isinstance(obj, ast.Conditional):
            false_block = None if obj.false_block is None else visit(obj.false_block)
            return ast.Conditional(process_rhs(obj.condition_expr),
                                   true_block=visit(obj.true_block), false_block=false_block)
        elif isinstance(obj, ast.Block):
            return ast.Block([visit(e) for e in obj.args])
        else:
            return obj

    typed_equations = visit(eqs)

    return fields_read, fields_written, typed_equations
717
718


Martin Bauer's avatar
Martin Bauer committed
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
def insert_casts(node):
    """Recursively type-checks an expression/AST and inserts casts and pointer arithmetic where necessary.

    :param node: the head node of the ast (or any sympy expression inside it)
    :return: modified ast; ``Block`` and ``LoopOverCoordinate`` nodes are updated in place and returned, atomic
             expressions and ``ResolvedFieldAccess`` nodes are returned unchanged, all other nodes are rebuilt
             from their processed arguments
    """
    def cast(zipped_args_types, target_dtype):
        """Wraps every argument whose type differs from ``target_dtype`` in a ``cast_func``.

        Types are compared via their numpy dtype, so const qualifiers are ignored.

        :param zipped_args_types: iterable of (argument, data type) pairs
        :param target_dtype: the data type all arguments should have
        :return: list of arguments, cast where necessary
        """
        return [cast_func(argument, target_dtype) if data_type.numpy_dtype != target_dtype.numpy_dtype else argument
                for argument, data_type in zipped_args_types]

    def pointer_arithmetic(expr_args):
        """Turns the (argument, type) pairs of an addition into a ``pointer_arithmetic_func``.

        Exactly one pointer-typed argument is allowed; all remaining arguments must be (unsigned) integers and
        are summed up as the offset.

        :param expr_args: (argument, data type) pairs of the add expression
        :return: pointer_arithmetic_func
        """
        pointer = None
        for candidate, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None
                pointer = candidate

        offsets = []
        for candidate, data_type in expr_args:
            if candidate != pointer:
                assert data_type.is_int() or data_type.is_uint()
                offsets.append(candidate)
        offset = sp.Add(*offsets) if len(offsets) > 0 else offsets
        return pointer_arithmetic_func(pointer, offset)

    if isinstance(node, sp.AtomicExpr):
        return node

    args = [insert_casts(child) for child in node.args]
    func = node.func

    # TODO indexed, LoopOverCoordinate
    if func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(child) for child in args]
        assert len(types) > 0
        target = collate_types(types)
        zipped = list(zip(args, types))
        if target.func is PointerType:
            assert func is sp.Add
            return pointer_arithmetic(zipped)
        return func(*cast(zipped, target))

    if func is ast.SympyAssignment:
        lhs, rhs = args[0], args[1]
        lhs_type = get_type_of_expression(lhs)
        if lhs_type.func is PointerType:
            return func(*args)  # TODO fix, not complete
        return func(lhs, *cast([(rhs, get_type_of_expression(rhs))], lhs_type))

    if func is ast.ResolvedFieldAccess:
        return node

    if func is ast.Block or func is ast.LoopOverCoordinate:
        # container nodes are patched in place with their processed children
        for old_child, new_child in zip(node.args, args):
            node.replace(old_child, new_child)
        return node

    if func is sp.Piecewise:
        # all branches of a piecewise expression have to share one common type
        expressions = [expr for expr, _ in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        casted_expressions = cast(list(zip(expressions, types)), target)
        return func(*[pair.func(expr, pair.cond) for pair, expr in zip(args, casted_expressions)])

    return func(*args)


805
806
807
808
809
810
811
812
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions.

    The single innermost loop and every loop enclosing it are cut right before their last iteration. Afterwards
    conditionals are simplified with loop-counter analysis enabled, constants are moved before the loops, and
    redundant blocks are removed.
    """
    innermost_loops = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = innermost_loops[0]

    # walk from the innermost loop outwards and separate the last iteration of each loop
    for enclosing_loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
        cut_loop(enclosing_loop, [enclosing_loop.stop - 1])

    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    cleanup_blocks(function_node.body)
    move_constants_before_loop(function_node.body)
    cleanup_blocks(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
821
822
823
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
824
def typing_from_sympy_inspection(eqs, default_type="double"):
    """
    Creates a default symbol name to type mapping.

    A symbol that gets a sympy Boolean expression assigned is typed 'bool'. Every other name maps to
    ``default_type``. AST nodes inside ``eqs`` are ignored.

    Args:
        eqs: list of equations
        default_type: the type for non-boolean symbols
    Returns:
        dictionary, mapping symbol name to type (a defaultdict, so unknown names yield default_type)
    """
    type_of_symbol = defaultdict(lambda: default_type)
    plain_equations = (eq for eq in eqs if not isinstance(eq, ast.Node))
    for equation in plain_equations:
        value = equation.rhs
        # a bare symbol on the rhs is ambiguous without further information -> leave the default for the lhs
        if not isinstance(value, Boolean) or isinstance(value, sp.Symbol):
            continue
        type_of_symbol[equation.lhs.name] = "bool"
    return type_of_symbol


Martin Bauer's avatar
Martin Bauer committed
846
def get_next_parent_of_type(node, parent_type):
    """Returns the closest ancestor of ``node`` that is an instance of ``parent_type``.

    The ``parent`` chain is followed upwards starting at ``node.parent`` (``node`` itself is never returned).
    If the root is reached without a match, None is returned.
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parent_type):
        ancestor = ancestor.parent
    return ancestor


860
def parents_of_type(node, parent_type, include_current=False):
    """Generator over all ancestors of ``node`` that are instances of ``parent_type``, nearest first.

    If ``include_current`` is set, ``node`` itself is considered as well.
    """
    if include_current:
        candidate = node
    else:
        candidate = node.parent
    while candidate is not None:
        if isinstance(candidate, parent_type):
            yield candidate
        candidate = candidate.parent


Martin Bauer's avatar
Martin Bauer committed
869
def get_optimal_loop_ordering(fields):
    """
    Determines the optimal loop order for a given set of fields.

    All fields must have the same number of spatial dimensions and the same memory layout; otherwise a
    ValueError is raised.

    Args:
        fields: non-empty sequence of fields

    Returns:
        list of coordinate ids (the common layout), where the first list entry should be the outermost loop
    """
    assert len(fields) > 0
    reference = next(iter(fields))
    if any(field.spatial_dimensions != reference.spatial_dimensions for field in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                         + str({f.name: f.spatial_shape for f in fields}))

    distinct_layouts = {field.layout for field in fields}
    if len(distinct_layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " +
                         str({f.name: f.layout for f in fields}))
    (common_layout,) = distinct_layouts
    return list(common_layout)
893
894


Martin Bauer's avatar
Martin Bauer committed
895
def get_loop_hierarchy(ast_node):
    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.

    Returns:
        reversed iterator over the ``coordinate_to_loop_over`` values of all enclosing LoopOverCoordinate
        nodes, starting from the outermost loop and ending with the innermost one
    """
    coordinates = []
    enclosing_loop = get_next_parent_of_type(ast_node, ast.LoopOverCoordinate)
    while enclosing_loop is not None:
        coordinates.append(enclosing_loop.coordinate_to_loop_over)
        enclosing_loop = get_next_parent_of_type(enclosing_loop, ast.LoopOverCoordinate)
    # collected from the inside out, handed out from the outside in
    return reversed(coordinates)