transformations.py 48.4 KB
Newer Older
1
import warnings
2
from collections import defaultdict, OrderedDict, namedtuple
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
from types import MappingProxyType
5

6
7
import sympy as sp
from sympy.logic.boolalg import Boolean
8
from sympy.tensor import IndexedBase
9
from pystencils.assignment import Assignment
10
from pystencils.assignment_collection.nestedscopes import NestedScopes
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.field import Field, FieldType
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
13
    pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
15
import pystencils.astnodes as ast
16
17


Martin Bauer's avatar
Martin Bauer committed
18
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yields all descendants of ``node`` that are instances of ``node_type``.

    The tree is walked depth-first via the ``args`` attribute of each node. The root ``node`` itself is
    never yielded.

    Args:
        node: root of the (sub)tree to search, has to provide an iterable ``args`` attribute
        node_type: type (or tuple of types) of the nodes that are yielded
        stop_type: optional type (or tuple of types); children of this type that do not match ``node_type``
                   are skipped together with their complete sub-tree

    Yields:
        matching nodes, a parent always before its own matching descendants
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # the stop criterion has to be checked on the child, otherwise the sub-tree is never pruned
            continue

        # pass stop_type on, so that pruning also works below the first tree level
        yield from filtered_tree_iteration(arg, node_type, stop_type)
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def unify_shape_symbols(body, common_shape, fields):
    """Makes all variable-sized fields use the same symbols for their array sizes.

    When a kernel is created for variable array sizes, all passed arrays must have the same size, which is
    checked when the kernel is called. Inside the kernel only one symbol per dimension is therefore needed
    instead of one per field: e.g. shape_arr1[0] and shape_arr2[0] must be equal, so both are represented
    by the same symbol.

    Args:
        body: ast node of the kernel part where the substitution is applied; ``body.subs`` is called on it
        common_shape: shape of the field that was chosen as reference
        fields: all fields whose shape symbols should be replaced by common_shape
    """
    replacements = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if field.has_fixed_shape:
            # fixed shapes are plain numbers, there is no symbol to replace
            continue
        replacements.update({own: common
                             for common, own in zip(common_shape, field.spatial_shape)
                             if own != common})

    if replacements:
        body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
52
def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists.

    Args:
        field_set: non-empty collection of fields; each field has to provide the attributes ``name``,
                   ``has_fixed_shape`` and ``spatial_shape``

    Returns:
        the common spatial shape. If all fields have variable shape, the shape whose first component comes
        first when compared as string is returned.

    Raises:
        ValueError: if ``field_set`` is empty, if fixed-shaped and variable-shaped fields are mixed or if
                    fixed-shaped fields have different shapes
    """
    if not field_set:
        # previously this case surfaced as an IndexError from indexing an empty list
        raise ValueError("Cannot determine a common shape for an empty set of fields")

    nr_of_fixed_shaped_fields = sum(1 for f in field_set if f.has_fixed_shape)

    if nr_of_fixed_shaped_fields > 0 and nr_of_fixed_shaped_fields != len(field_set):
        fixed_field_names = ",".join(f.name for f in field_set if f.has_fixed_shape)
        var_field_names = ",".join(f.name for f in field_set if not f.has_fixed_shape)
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if nr_of_fixed_shaped_fields == len(field_set):
        if len(shape_set) != 1:
            raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # choose by the string of the first component, so that the pick does not depend on set iteration order
    return min(shape_set, key=lambda e: str(e[0]))


Martin Bauer's avatar
Martin Bauer committed
76
77
78
79
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        function_name: name of generated C function
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
             or a single integer that is used as lower and upper value for all coordinates.
             If None, the number of ghost layers is determined automatically from the field accesses and
             assumed to be equal for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        :class:`KernelFunction` node (created with ``backend='cpu'``) that wraps the nested loops
    """
    # find correct ordering by inspecting participating FieldAccesses
    field_accesses = body.atoms(Field.Access)
    field_accesses = {e for e in field_accesses if not e.is_absolute_access}

    # exclude accesses to buffers from field_list, because buffers are treated separately
    field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)]
    fields = set(field_list)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    # loops are created from the inside out: the innermost coordinate is wrapped first
    current_body = body
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            begin = ghost_layers[loop_coordinate][0]
            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([new_loop])
        else:
            slice_component = iteration_slice[loop_coordinate]
            if type(slice_component) is slice:
                sc = slice_component
                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                current_body = ast.Block([new_loop])
            else:
                # coordinate is fixed to a single value: no loop, the loop counter is assigned directly
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                                                 sp.sympify(slice_component))
                current_body.insert_front(assignment)

    ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
    return ast_node
134
135


Martin Bauer's avatar
Martin Bauer committed
136
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, x*fstride_myfield[0] + fstride_myfield[0])
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    field = field_access.field
    offset = 0
    name = ""
    # values that are not plain Python ints cannot be written into the name directly, they are hashed
    list_to_hash = []
    for coordinate_id, coordinate_value in coordinates.items():
        offset += field.strides[coordinate_id] * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the relative offset of the access is added, and it is the part that is
            # encoded in the pointer name
            offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
            if type(field_access.offsets[coordinate_id]) is int:
                name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
            else:
                list_to_hash.append(field_access.offsets[coordinate_id])
        else:
            # index coordinate: the coordinate value itself is encoded in the pointer name
            if type(coordinate_value) is int:
                name += "_%d%d" % (coordinate_id, coordinate_value)
            else:
                list_to_hash.append(coordinate_value)

    if len(list_to_hash) > 0:
        # NOTE(review): hash() of the symbolic values is presumably not stable across interpreter runs,
        # so the generated name may differ between runs - confirm if stable names are required
        name += "_%0.6X" % (hash(tuple(list_to_hash)))

    # a minus sign (negative offset or negative hash) is not valid in an identifier
    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset
186
187


Martin Bauer's avatar
Martin Bauer committed
188
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [(0,), (2, 3)] creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are defined dependent on the loop
    ordering. This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of lists that can be passed to :func:`resolve_field_accesses`. Coordinates that were not
        specified are collected, in ascending order, in an additional last group.

    Raises:
        ValueError: if a specification cannot be parsed, a coordinate does not exist or a coordinate is
                    specified more than once

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
        >>> parse_base_pointer_info([['spatialInner0'], ['spatialInner1']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[0], [1], [2, 3]]
    """
    result = []
    specified_coordinates = set()
    nr_of_coordinates = spatial_dimensions + index_dimensions
    inner_to_outer = list(reversed(loop_order))

    def add_new_element(group, elem):
        if elem >= nr_of_coordinates:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        group.append(elem)
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []

        for element in spec_group:
            if type(element) is int:
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # loop_order is given from outer to inner, so it can be indexed directly.
                    # Indexing the reversed list with -index is wrong, since -0 == 0 selects the innermost loop
                    add_new_element(new_group, loop_order[index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    all_coordinates = set(range(nr_of_coordinates))
    rest = all_coordinates - specified_coordinates
    if rest:
        # sorted, because the iteration order of a set is not specified
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
263
264
def substitute_array_accesses_with_constants(ast_node):
    """Substitutes all instances of Indexed (array accesses) that are not field accesses with constants.

    Benchmarks showed that using an array access as loop bound or in pointer computations cause some compilers to do
    less optimizations.
    This transformation should be after field accesses have been resolved (since they introduce array accesses) and
    before constants are moved before the loops.

    Args:
        ast_node: AST (sub)tree that is modified in-place. Assignments and loop bounds are rewritten, the
                  children of all other nodes are visited recursively.
    """

    def handle_sympy_expression(expr, parent_block):
        """Returns sympy expression where array accesses have been replaced with constants.

        The assignments that define these constants are inserted into parent_block, directly before the
        currently handled ast_node, unless the constant is already defined in parent_block.
        """
        if not isinstance(expr, sp.Expr):
            return expr

        # get all indexed expressions that are not field accesses
        indexed_expressions = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)]
        if len(indexed_expressions) == 0:
            return expr

        # special case: right hand side is a single indexed expression, then nothing has to be done
        if len(indexed_expressions) == 1 and expr == indexed_expressions[0]:
            return expr

        constants_definitions = []
        constant_substitutions = {}
        for indexed_expr in indexed_expressions:
            base, idx = indexed_expr.args
            typed_symbol = base.args[0]
            # copy the type before changing it, the type object of the array symbol must stay untouched
            base_type = deepcopy(get_base_type(typed_symbol.dtype))
            base_type.const = False
            constant_replacing_indexed = TypedSymbol(typed_symbol.name + str(idx), base_type)
            constants_definitions.append(ast.SympyAssignment(constant_replacing_indexed, indexed_expr))
            constant_substitutions[indexed_expr] = constant_replacing_indexed
        # sort by name to get an insertion order that does not depend on the iteration order of atoms()
        constants_definitions.sort(key=lambda e: e.lhs.name)

        already_defined = parent_block.symbols_defined
        for new_assignment in constants_definitions:
            if new_assignment.lhs not in already_defined:
                parent_block.insert_before(new_assignment, ast_node)

        return expr.subs(constant_substitutions)

    if isinstance(ast_node, ast.SympyAssignment):
        ast_node.rhs = handle_sympy_expression(ast_node.rhs, ast_node.parent)
        ast_node.lhs = handle_sympy_expression(ast_node.lhs, ast_node.parent)
    elif isinstance(ast_node, ast.LoopOverCoordinate):
        ast_node.start = handle_sympy_expression(ast_node.start, ast_node.parent)
        ast_node.stop = handle_sympy_expression(ast_node.stop, ast_node.parent)
        ast_node.step = handle_sympy_expression(ast_node.step, ast_node.parent)
        substitute_array_accesses_with_constants(ast_node.body)
    else:
        for a in ast_node.args:
            substitute_array_accesses_with_constants(a)
316

Martin Bauer's avatar
Martin Bauer committed
317

Martin Bauer's avatar
Martin Bauer committed
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Computes the linearized buffer index as function of the loop counter symbols (used for buffer fields).

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        # collect the loops and order them from inner to outer
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))[::-1]
        enclosing_loops = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        # every loop that was found has to be the innermost loop or one of its parents
        assert len(loops) == len(enclosing_loops)
        assert all(found is parent for found, parent in zip(loops, enclosing_loops))

        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loops]
        loop_counters = [loop.loop_counter_symbol for loop in loops]

    buffer_accesses = {access for access in ast_node.atoms(Field.Access) if FieldType.is_buffer(access.field)}
    # every counter is scaled by the number of distinct buffer accesses
    scaled_counters = [counter * len(buffer_accesses) for counter in loop_counters]

    base_buffer_index = scaled_counters[0]
    stride = 1
    for idx, counter in enumerate(scaled_counters[1:]):
        iterations = loop_iterations[idx]
        if isinstance(iterations, float):
            iterations = int(iterations)
        stride *= iterations
        base_buffer_index += counter * stride
    return base_buffer_index


Martin Bauer's avatar
Martin Bauer committed
353
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
    """Substitutes :class:`pystencils.field.Field.Access` nodes of buffer fields by array indexing.

    Accesses to fields that are not buffers are left untouched.

    Args:
        ast_node: the AST root, is modified in-place
        base_buffer_index: linearized index into the buffer, see :func:`get_base_buffer_index`
        read_only_field_names: set of field names which are considered read-only, their data pointer is
                               declared const. The default set is only read, never modified.

    Returns:
        None, the result of the in-place traversal is not returned

    Raises:
        RuntimeError: if a buffer access has more than one index dimension
    """

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field

            dtype = PointerType(buffer.dtype, const=buffer.name in read_only_field_names, restrict=False)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(buffer.name)), dtype)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            # already resolved accesses are kept as they are, their arguments are not visited
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # the expression is rebuilt from the visited arguments without letting sympy re-evaluate it
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
399

400

Martin Bauer's avatar
Martin Bauer committed
401
def resolve_field_accesses(ast_node, read_only_field_names=set(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Args:
        ast_node: the AST root, is modified in-place
        read_only_field_names: set of field names which are considered read-only, their data pointer is
                               declared const. The default set is only read, never modified.
        field_to_base_pointer_info: map of field name to a list of coordinate groups, indicating which
                                    intermediate base pointers should be created,
                                    for details see :func:`parse_base_pointer_info`.
                                    Fields without entry get a single group with all coordinates.
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns
        None, the result of the in-place traversal is not returned
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets/indices contain field accesses themselves, these have to be resolved first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            dtype = PointerType(field.dtype, const=field.name in read_only_field_names, restrict=False)
            field_ptr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbol_name_to_variable_name(field.name)), dtype)

            def create_coordinate_dict(group_param):
                """Maps every coordinate id of the group to the value that is multiplied with its stride."""
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        # spatial coordinate: fixed coordinate symbol or loop counter, absolute accesses use 0
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        coordinates[e] *= field.dtype.item_size
                    else:
                        # index coordinate: element offset for struct types, plain index value otherwise
                        if isinstance(field.dtype, StructType):
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # all groups except the first one become intermediate pointers, each based on the previous one
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the offset of the first group is written directly into the array access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            # already resolved accesses are kept as they are, their arguments are not visited
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # the expression is rebuilt from the visited arguments without letting sympy re-evaluate it
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
508
509


Martin Bauer's avatar
Martin Bauer committed
510
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.astnodes.SympyAssignment` nodes out of loop body if they are iteration independent.

    Every child of every block found below ``ast_node`` is taken out of its block and re-inserted either
    where it was (no movement possible) or into the outermost enclosing block it does not depend on.

    Call this after creating the loop structure with :func:`make_loop_over_domain`
    """
    def find_block_to_move_to(node):
        """
        Walks up the parents of node for as long as node does not depend on them and returns the
        last block seen on the way, i.e. the block the node can be moved to.

        A parent stops the walk if node reads one of the symbols it defines, or - for a conditional -
        one of the symbols in its condition.

        :param node: node whose parent is a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node.parent, ast.Block)

        target_block, anchor = node.parent, node
        previous, current = node, node.parent
        while current:
            if isinstance(current, ast.Block):
                target_block, anchor = current, previous

            if isinstance(current, ast.Conditional):
                blocking_symbols = current.condition_expr.atoms(sp.Symbol)
            else:
                blocking_symbols = current.symbols_defined
            if node.undefined_symbols.intersection(blocking_symbols):
                break
            previous, current = current, current.parent
        return target_block, anchor

    def check_if_assignment_already_in_block(assignment, target_block):
        """Returns the SympyAssignment in target_block with the same lhs as assignment, or None."""
        for candidate in target_block.args:
            if type(candidate) is ast.SympyAssignment and candidate.lhs == assignment.lhs:
                return candidate
        return None

    def get_blocks(node, result_list):
        """Appends node (if it is a Block) and all blocks below it to result_list, parents first."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for sub_node in node.args:
                get_blocks(sub_node, result_list)

    blocks = []
    get_blocks(ast_node, blocks)
    for block in blocks:
        # the block is emptied first; every child is put back explicitly below
        for child in block.take_child_nodes():
            target, insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            if isinstance(child, ast.SympyAssignment):
                duplicate = check_if_assignment_already_in_block(child, target)
            else:
                duplicate = False

            if not duplicate:
                target.insert_before(child, insert_before)
            elif duplicate.rhs == child.rhs:
                # an identical assignment is already in the target block, so this one is dropped
                pass
            else:
                block.append(child)  # don't move in this case - better would be to rename symbol
578
579


Martin Bauer's avatar
Martin Bauer committed
580
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The innermost loop is replaced by a block holding one loop per symbol group. Pure symbols that are
    listed in a group are written to (and later read from) a temporary array indexed by the inner loop
    counter; the allocation/free nodes for these arrays are added to the parent of the outermost loop.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups
                       and which no symbol in a symbol group depends on, are not updated!
    """
    loops = ast_node.atoms(ast.LoopOverCoordinate)
    innermost_loops = [loop for loop in loops if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, \
        "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost_loops[0]
    assert type(inner_loop.body) is ast.Block
    outermost_loops = [loop for loop in loops if loop.is_outermost_loop]
    assert len(outermost_loops) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost_loops[0]

    def temporary_array_access(typed_symbol):
        """Access to the temporary array that replaces typed_symbol, indexed by the inner loop counter."""
        pointer_symbol = TypedSymbol(typed_symbol.name, PointerType(typed_symbol.dtype))
        return IndexedBase(pointer_symbol, shape=(1,))[inner_loop.loop_counter_symbol]

    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict()
    for body_assignment in inner_loop.body.args:
        assignment_map[body_assignment.lhs] = body_assignment

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols
        pending = list(symbol_group)
        symbols_resolved = set()
        while pending:
            current = pending.pop()
            if current in symbols_resolved:
                continue

            if current in assignment_map:  # if there is no assignment inside the loop body it is independent already
                pending.extend(dependency
                               for dependency in assignment_map[current].rhs.atoms(sp.Symbol)
                               if type(dependency) is not Field.Access
                               and dependency not in symbols_with_temporary_array)
            symbols_resolved.add(current)

        for symbol in symbol_group:
            if type(symbol) is not Field.Access:
                assert type(symbol) is TypedSymbol
                symbols_with_temporary_array[symbol] = temporary_array_access(symbol)

        assignment_group = []
        for assignment in inner_loop.body.args:
            lhs = assignment.lhs
            if lhs not in symbols_resolved:
                continue
            new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
            if type(lhs) is not Field.Access and lhs in symbol_group:
                assert type(lhs) is TypedSymbol
                new_lhs = temporary_array_access(lhs)
            else:
                new_lhs = lhs
            assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    for replaced_symbol in symbols_with_temporary_array:
        array_pointer = TypedSymbol(replaced_symbol.name, PointerType(replaced_symbol.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)
646
647


Martin Bauer's avatar
Martin Bauer committed
648
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    The loop is replaced by up to len(cutting_points)+1 pieces that range from
    old_begin to cutting_points[0], ..., cutting_points[-1] to old_end.

    - a piece of length 1 becomes a copy of the loop body in which the loop counter
      is substituted by the start of the piece (no loop is generated)
    - a piece of length 0 is dropped
    - every other piece becomes a new loop over a copy of the body

    Modifies the ast in place

    Returns:
        list of the nodes that replaced the loop (new loop nodes, or bare bodies for pieces of length 1)
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    pieces = []
    piece_start = loop_node.start
    for piece_end in list(cutting_points) + [loop_node.stop]:
        extent = piece_end - piece_start
        if extent == 1:
            unrolled_body = deepcopy(loop_node.body)
            unrolled_body.subs({loop_node.loop_counter_symbol: piece_start})
            pieces.append(unrolled_body)
        elif extent != 0:
            pieces.append(ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                                 piece_start, piece_end, loop_node.step))
        piece_start = piece_end

    loop_node.parent.replace(loop_node, pieces)
    return pieces
678
679


680
681
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool=False) -> None:
    """Removes conditionals that are always true/false.

    The condition of every conditional below ``node`` is simplified with sympy. A conditional whose
    condition simplifies to true is replaced by its true block, one that simplifies to false by its
    false block (or by nothing, if it has none).

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        condition = conditional.condition_expr

        if condition == sp.true:
            replacement = [conditional.true_block]
        elif condition == sp.false:
            replacement = [conditional.false_block] if conditional.false_block else []
        else:
            if loop_counter_simplification:
                try:
                    # noinspection PyUnresolvedReferences
                    from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
                    simplify_loop_counter_dependent_conditional(conditional)
                except ImportError:
                    warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
            continue

        conditional.parent.replace(conditional, replacement)


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child

    Only blocks whose parent is itself a block are removed/replaced; the tree is modified in place.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a copy, since nested blocks may replace themselves inside this block
    for child in list(node.args):
        cleanup_blocks(child)
    if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
719
720


Martin Bauer's avatar
Martin Bauer committed
721
def symbol_name_to_variable_name(symbol_name):
    """Returns symbol_name with every ``^`` replaced by ``_``.

    ``^`` is allowed in sympy symbol names but not in C/C++ variable names.
    """
    return "_".join(symbol_name.split("^"))
724
725


726
727
728
729
730
731
732
733
734
735
736
737
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Test the following conditions:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exactly the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)

    While checking, untyped symbols are replaced by typed symbols and the fields that are read
    and written are recorded.
    """
    # key under which the written offsets of a field are stored
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        """
        Args:
            type_for_symbol: mapping from symbol name to type; the entry ``'_constant'`` is used for numbers
            check_independence_condition: if true, reading a field at a different offset than it was
                                          written at raises a ValueError
        """
        self._type_for_symbol = type_for_symbol

        self.scopes = NestedScopes()
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition

    def process_assignment(self, assignment):
        """Checks a single assignment and returns it as SympyAssignment with typed symbols."""
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs = self.process_expression(assignment.rhs)
        new_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(new_lhs, new_rhs)

    def process_expression(self, rhs, type_constants=True):
        """Registers all reads in an expression and returns the expression with typed symbols.

        Args:
            rhs: expression that is read (right-hand side of an assignment or a condition)
            type_constants: if true, numbers are wrapped in a cast to the ``'_constant'`` type

        Returns:
            the expression where untyped symbols have been replaced by typed symbols
        """
        self._update_accesses_rhs(rhs)
        if isinstance(rhs, Field.Access):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        elif isinstance(rhs, TypedSymbol):
            return rhs
        elif isinstance(rhs, sp.Symbol):
            return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
        elif type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        elif isinstance(rhs, sp.Mul):
            # factors of +1/-1 are left untouched, i.e. they are not wrapped in a cast
            new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs
        elif isinstance(rhs, sp.Indexed):
            return rhs
        elif isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
        else:
            new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs

    @property
    def fields_written(self):
        """Set of all fields for which at least one write has been registered."""
        return {key.field for key, offsets in self._field_writes.items() if len(offsets)}

    def _process_lhs(self, lhs):
        """Registers the write to lhs and returns lhs as typed symbol (field accesses are returned as they are)."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, (Field.Access, TypedSymbol)):
            return lhs
        # The name has to be converted exactly as in process_expression, otherwise a symbol whose
        # name contains '^' would be defined under a different name than it is read with.
        return TypedSymbol(symbol_name_to_variable_name(lhs.name), self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        """Checks the write conditions for lhs and records the write.

        Raises:
            ValueError: if a field is written at two different offsets, if a pure symbol is assigned twice
                        in the current scope, or if a pure symbol is written after it has been read
        """
        if isinstance(lhs, Field.Access):
            fai = self.FieldAndIndex(lhs.field, lhs.index)
            self._field_writes[fai].add(lhs.offsets)
            if len(self._field_writes[fai]) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
        elif isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
            if lhs in self.scopes.free_parameters:
                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        """Checks the read conditions for rhs and records the read.

        Raises:
            ValueError: if the independence check is enabled and a field is read at an offset
                        different from the one it was written at
        """
        if isinstance(rhs, Field.Access) and self.check_independence_condition:
            writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
            for write_offset in writes:
                assert len(writes) == 1
                if write_offset != rhs.offsets:
                    raise ValueError("Violation of loop independence condition. Field "
                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self.scopes.access_symbol(rhs)
816
817
818
819
820


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.data_types.TypedSymbol`.

    Additionally returns sets of all fields which are read/written

    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                         If a string or an object without ``__getitem__`` is passed, it is used as default type
                         and the mapping is created by :func:`typing_from_sympy_inspection`
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    checker = KernelConstraintsCheck(type_for_symbol, check_independence_condition)

    def visit(obj):
        if isinstance(obj, (list, tuple)):
            return [visit(element) for element in obj]

        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            return checker.process_assignment(obj)

        if isinstance(obj, ast.Conditional):
            checker.scopes.push()
            # processing order matters for the checks: false block, then condition, then true block
            typed_false_block = visit(obj.false_block) if obj.false_block is not None else None
            typed_condition = checker.process_expression(obj.condition_expr, type_constants=False)
            typed_true_block = visit(obj.true_block)
            typed_conditional = ast.Conditional(typed_condition,
                                                true_block=typed_true_block, false_block=typed_false_block)
            checker.scopes.pop()
            return typed_conditional

        if isinstance(obj, ast.Block):
            checker.scopes.push()
            typed_block = ast.Block([visit(element) for element in obj.args])
            checker.scopes.pop()
            return typed_block

        if isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            # all other nodes, except loops, are passed through unchanged
            return obj

        raise ValueError("Invalid object in kernel " + str(type(obj)))

    typed_equations = visit(eqs)

    return checker.fields_read, checker.fields_written, typed_equations
863
864


Martin Bauer's avatar
Martin Bauer committed
865
def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary.

    The arguments of a node are processed recursively first. Blocks and loops are updated in place,
    all other nodes are rebuilt from their processed arguments.

    Args:
        node: the head node of the ast

    Returns:
        modified AST
    """
    def cast(zipped_args_types, target_dtype):
        """
        Adds casts to the arguments if their type differs from the target type
        :param zipped_args_types: a zipped list of args and types
        :param target_dtype: The target data type
        :return: args with possible casts
        """
        return [cast_func(argument, target_dtype)
                if data_type.numpy_dtype != target_dtype.numpy_dtype  # ignoring const
                else argument
                for argument, data_type in zipped_args_types]

    def pointer_arithmetic(expr_args):
        """
        Creates a valid pointer arithmetic function
        :param expr_args: Arguments of the add expression
        :return: pointer_arithmetic_func
        """
        pointer = None
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None  # at most one pointer per sum
                pointer = arg

        offsets = []
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                offsets.append(arg)
        offset = sp.Add(*offsets) if offsets else offsets
        return pointer_arithmetic_func(pointer, offset)

    if isinstance(node, (sp.AtomicExpr, cast_func)):
        return node

    args = [insert_casts(arg) for arg in node.args]
    func = node.func

    # TODO indexed, LoopOverCoordinate
    if func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        target = collate_types(types)
        zipped = list(zip(args, types))
        if target.func is PointerType:
            assert func is sp.Add
            return pointer_arithmetic(zipped)
        return func(*cast(zipped, target))

    if func is ast.SympyAssignment:
        lhs = args[0]
        rhs = args[1]
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return func(*args)  # TODO fix, not complete
        return func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))

    if func is ast.ResolvedFieldAccess:
        return node

    if func is ast.Block:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node

    if func is ast.LoopOverCoordinate:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node

    if func is sp.Piecewise:
        # all branch expressions are cast to their common type, the conditions are kept
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        casted_expressions = cast(list(zip(expressions, types)), target)
        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]

    return func(*args)


954
955
956
957
958
959
960
961
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element

    The innermost loop and all loops enclosing it are cut at ``stop - 1``. Afterwards the conditionals
    are simplified, superfluous blocks are removed and constants are moved before the loops.
    """
    innermost_loops = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = innermost_loops[-1]

    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
        cut_loop(loop, [loop.stop - 1])

    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    for cleanup_pass in (cleanup_blocks, move_constants_before_loop, cleanup_blocks):
        cleanup_pass(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
971
972
973
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
974
def typing_from_sympy_inspection(eqs, default_type="double"):
Martin Bauer's avatar
Martin Bauer committed
975
976
977
    """
    Creates a default symbol name to type mapping.
    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
978
979
980
981
982
983

    Args:
        eqs: list of equations
        default_type: the type for non-boolean symbols
    Returns:
        dictionary, mapping symbol name to type
Martin Bauer's avatar
Martin Bauer committed
984
    """
Martin Bauer's avatar
Martin Bauer committed