transformations.py 49 KB
Newer Older
1
import warnings
2
from collections import defaultdict, OrderedDict, namedtuple
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
from types import MappingProxyType
5
6
import pickle
import hashlib
7
8
import sympy as sp
from sympy.logic.boolalg import Boolean
9
from sympy.tensor import IndexedBase
10
11

from pystencils.simp.assignment_collection import AssignmentCollection
12
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.field import Field, FieldType
14
15
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
    cast_func, pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
16
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
17
from pystencils.slicing import normalize_slice
Martin Bauer's avatar
Martin Bauer committed
18
import pystencils.astnodes as ast
19
20


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class NestedScopes:
    """Bookkeeping of symbol visibility with a stack of nested scopes.

    - a symbol that is accessed while it is not defined in any open scope is recorded as "free parameter"
    - free parameters are kept globally, they do not belong to a scope
    - ``push`` opens a new innermost scope, ``pop`` discards the innermost scope

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # stack of scopes, the last entry is the innermost scope
        self._defined = [set()]
        self.free_parameters = set()

    def access_symbol(self, symbol):
        """Registers an access; symbols unknown to all open scopes become free parameters."""
        if self.is_defined(symbol):
            return
        self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Marks the symbol as defined in the innermost scope."""
        innermost_scope = self._defined[-1]
        innermost_scope.add(symbol)

    def is_defined(self, symbol):
        """True if the symbol is defined in any of the currently open scopes."""
        for scope in self._defined:
            if symbol in scope:
                return True
        return False

    def is_defined_locally(self, symbol):
        """True if the symbol is defined in the innermost scope."""
        innermost_scope = self._defined[-1]
        return symbol in innermost_scope

    def push(self):
        """Opens a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Discards the innermost scope; the outermost scope must remain afterwards."""
        self._defined.pop()
        assert self.depth >= 1

    @property
    def depth(self):
        """Number of currently open scopes."""
        return len(self._defined)


Martin Bauer's avatar
Martin Bauer committed
75
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yields all descendants of ``node`` that are instances of ``node_type`` in depth-first pre-order.

    Args:
        node: root of the traversal; every visited node has to provide an ``args`` sequence with its children.
              The root itself is never yielded.
        node_type: type (or tuple of types) of the nodes that are yielded. Matching nodes are yielded and
                   then searched for further matches as well.
        stop_type: optional type (or tuple of types); children of this type that do not match ``node_type``
                   are skipped together with their complete subtree

    Returns:
        generator over the matching nodes
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type is not None and isinstance(arg, stop_type):
            # prune the subtree below the child - the check has to look at the child, not at its parent
            continue

        # the stop type has to be handed down, otherwise pruning would only affect the first level
        yield from filtered_tree_iteration(arg, node_type, stop_type)
83
84


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def generic_visit(term, visitor):
    """Applies ``visitor`` to the expressions contained in ``term`` and rebuilds the surrounding container.

    Args:
        term: an AssignmentCollection, a list, an Assignment, a sympy Matrix or any other object
        visitor: callable that receives a single expression and returns its replacement

    Returns:
        - AssignmentCollection: copy with visited main assignments and subexpressions
        - list: new list with every entry visited
        - Assignment: new Assignment with unchanged left hand side and visited right hand side
        - sympy Matrix: matrix with every entry visited
        - anything else: ``visitor(term)``
    """
    def recurse(sub_term):
        return generic_visit(sub_term, visitor)

    if isinstance(term, AssignmentCollection):
        visited_main_assignments = recurse(term.main_assignments)
        visited_subexpressions = recurse(term.subexpressions)
        return term.copy(visited_main_assignments, visited_subexpressions)

    if isinstance(term, list):
        return [recurse(entry) for entry in term]

    if isinstance(term, Assignment):
        return Assignment(term.lhs, recurse(term.rhs))

    if isinstance(term, sp.Matrix):
        return term.applyfunc(lambda entry: recurse(entry))

    return visitor(term)


100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def unify_shape_symbols(body, common_shape, fields):
    """Makes all variable-shaped fields use the shape symbols of one common shape.

    In a kernel with variable array sizes all passed arrays must have the same size, which is checked when
    the kernel is called. Inside the kernel it is therefore sufficient to use only one symbol per coordinate
    instead of one symbol per field and coordinate. For example shape_arr1[0] and shape_arr2[0] must be equal,
    so both are represented by the same symbol.

    Args:
        body: ast node of the kernel part where the substitution is done; its ``subs`` method is called
              once with all replacements, if there are any
        common_shape: shape of the field that was chosen as reference
        fields: all fields whose shape components should be replaced by the components of common_shape;
                fields with fixed shape are left untouched
    """
    replacements = {}
    for current_field in fields:
        field_shape = current_field.spatial_shape
        assert len(field_shape) == len(common_shape)
        if current_field.has_fixed_shape:
            continue
        for reference_component, own_component in zip(common_shape, field_shape):
            if own_component != reference_component:
                replacements[own_component] = reference_component

    if not replacements:
        return
    body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
124
def get_common_shape(field_set):
    """Takes a collection of pystencils Fields and returns their common spatial shape if it exists.

    Args:
        field_set: collection of fields; of every field the attributes ``has_fixed_shape``,
                   ``spatial_shape`` and (for error messages) ``name`` are used

    Returns:
        the common spatial shape. If all fields have fixed shape this is their single shared shape.
        If all fields have variable shape, the shape whose first component has the smallest string
        representation is chosen.

    Raises:
        ValueError: if ``field_set`` is empty, if fixed-shaped and variable-shaped fields are mixed,
                    or if fixed-shaped fields have different shapes
    """
    fields = list(field_set)
    if not fields:
        raise ValueError("Cannot determine the common shape of an empty set of fields")

    nr_of_fixed_shaped_fields = sum(1 for f in fields if f.has_fixed_shape)

    if 0 < nr_of_fixed_shaped_fields < len(fields):
        # names are sorted to get an error message that does not depend on the set iteration order
        fixed_field_names = ",".join(sorted(f.name for f in fields if f.has_fixed_shape))
        var_field_names = ",".join(sorted(f.name for f in fields if not f.has_fixed_shape))
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in fields}
    all_fields_have_fixed_shape = nr_of_fixed_shaped_fields == len(fields)
    if all_fields_have_fixed_shape and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # for variable-shaped fields any shape could be used - select one by the name of its first component
    return min(shape_set, key=lambda e: str(e[0]))


Martin Bauer's avatar
Martin Bauer committed
148
149
150
151
def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        function_name: name of generated C function
        iteration_slice: if not None, iteration is done only over this slice of the field. Coordinates where
                         the normalized slice has a single value instead of a ``slice`` get no loop, the loop
                         counter symbol is assigned to that value at the front of the current body instead.
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers, or a
                      single integer that is used as lower and upper value for all coordinates.
                      If None, the maximum of the ghost layers required by the (relative) field accesses is
                      used for all dimensions.
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout),
                    if None it is determined from the accessed fields

    Returns:
        :class:`pystencils.astnodes.KernelFunction` (created with ``backend='cpu'``) that contains the
        nested loops
    """
    # find correct ordering by inspecting participating FieldAccesses, absolute accesses do not take part
    field_accesses = {e for e in body.atoms(Field.Access) if not e.is_absolute_access}

    # exclude accesses to buffers from the fields, because buffers are treated separately
    fields = {e.field for e in field_accesses if not FieldType.is_buffer(e.field)}

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        required_ghost_layers = max(fa.required_ghost_layers for fa in field_accesses)
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    # loops are built from the inside out, i.e. the innermost loop is created first
    current_body = body
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            begin = ghost_layers[loop_coordinate][0]
            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([new_loop])
        else:
            slice_component = iteration_slice[loop_coordinate]
            if isinstance(slice_component, slice):
                sc = slice_component
                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                current_body = ast.Block([new_loop])
            else:
                # coordinate is fixed to a single value: no loop, only define the loop counter symbol
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                                                 sp.sympify(slice_component))
                current_body.insert_front(assignment)

    ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
    return ast_node
206
207


Martin Bauer's avatar
Martin Bauer committed
208
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []
    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the relative offset of the access is added and is the part encoded in the name
            access_offset = field_access.offsets[coordinate_id]
            offset += stride * access_offset
            name_component = access_offset
        else:
            # index coordinate: the coordinate value itself is encoded in the name
            name_component = coordinate_value

        # only plain Python ints are written literally into the name, everything else goes into the hash
        if type(name_component) is int:
            name += "_%d%d" % (coordinate_id, name_component)
        else:
            list_to_hash.append(name_component)

    if list_to_hash:
        # md5 is only used to derive a short name suffix from the non-integer components
        name += hashlib.md5(pickle.dumps(list_to_hash)).hexdigest()[:16]

    # a minus sign is not allowed in identifiers, negative values are marked with 'm'
    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset
258
259


Martin Bauer's avatar
Martin Bauer committed
260
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of coordinate groups (lists) that can be passed to :func:`resolve_field_accesses`.
        Coordinates that are not mentioned in the specification are appended as a last, sorted group.

    Raises:
        ValueError: if a specification can not be parsed, if a coordinate does not exist or
                    if a coordinate is specified more than once

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
        >>> parse_base_pointer_info([['spatialInner0'], ['spatialInner1']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[0], [1], [2, 3]]
    """
    nr_of_coordinates = spatial_dimensions + index_dimensions
    outer_to_inner = list(loop_order)
    inner_to_outer = outer_to_inner[::-1]

    result = []
    specified_coordinates = set()

    def add_new_element(group, elem):
        if elem >= nr_of_coordinates:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        group.append(elem)
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if isinstance(element, int):
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # count from the outermost loop; indexing the reversed list with -index would map
                    # 'spatialOuter0' to the innermost loop and shift all other entries by one
                    add_new_element(new_group, outer_to_inner[index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    rest = set(range(nr_of_coordinates)) - specified_coordinates
    if rest:
        # sorted, such that the result does not depend on the iteration order of the set
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Computes the linearized index into a buffer field as expression of the loop counter symbols.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    both_given = loop_counters is not None and loop_iterations is not None
    if not both_given:
        # determine counters and iteration counts from the loop nodes found in the ast
        loop_nodes = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        loop_nodes.reverse()
        enclosing_loops = list(parents_of_type(loop_nodes[0], ast.LoopOverCoordinate, include_current=True))
        # the loops found have to be exactly the loop nest around the first entry of the reversed list
        assert len(loop_nodes) == len(enclosing_loops)
        assert all(found is parent for found, parent in zip(loop_nodes, enclosing_loops))

        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loop_nodes]
        loop_counters = [loop.loop_counter_symbol for loop in loop_nodes]

    # every cell occupies one buffer entry per distinct buffer access
    distinct_buffer_accesses = {access for access in ast_node.atoms(Field.Access)
                                if FieldType.is_buffer(access.field)}
    nr_of_buffer_accesses = len(distinct_buffer_accesses)
    scaled_counters = [counter * nr_of_buffer_accesses for counter in loop_counters]

    linear_index = scaled_counters[0]
    accumulated_stride = 1
    for position in range(len(scaled_counters) - 1):
        extent = loop_iterations[position]
        if isinstance(extent, float):
            extent = int(extent)
        accumulated_stride *= extent
        linear_index += scaled_counters[position + 1] * accumulated_stride
    return linear_index


Martin Bauer's avatar
Martin Bauer committed
370
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitutes :class:`pystencils.field.Field.Access` nodes of buffer fields by array indexing.

    Accesses to fields that are not buffers are left untouched. The tree is modified in-place: left and right
    hand sides of all :class:`pystencils.astnodes.SympyAssignment` nodes below ``ast_node`` are rewritten.

    Args:
        ast_node: the AST root
        base_buffer_index: linearized buffer index that is used for all buffer accesses;
                           the value of the (single) index coordinate of an access is added to it
        read_only_field_names: collection of field names, pointers to these buffers are created as const

    Returns:
        ``ast_node``, after it has been transformed in-place

    Raises:
        RuntimeError: if a buffer is accessed with more than one index coordinate
    """
    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            buffer_index = base_buffer_index
            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                # create a new value, the shared base index must never be changed
                buffer_index = buffer_index + cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild these expression types unevaluated
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    visit_node(ast_node)
    return ast_node
414

415

Martin Bauer's avatar
Martin Bauer committed
416
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The tree is modified in-place: left and right hand sides of all
    :class:`pystencils.astnodes.SympyAssignment` nodes below ``ast_node`` are rewritten and assignments
    for intermediate base pointers are inserted into the enclosing block before the assignment that uses them.

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only
        field_to_base_pointer_info: map of field name to a list of coordinate groups indicating which
                                    intermediate base pointers should be created,
                                    for details see :func:`parse_base_pointer_info`.
                                    For fields without entry all coordinates are put into a single group.
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        ``ast_node``, after it has been transformed in-place
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, Field.Access):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets/indices contain field accesses themselves, these have to be resolved first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        # spatial coordinate: zero for absolute accesses, otherwise the fixed coordinate
                        # of this field if one was passed, otherwise the loop counter
                        if field_access.is_absolute_access:
                            coordinate = 0
                        elif field.name in field_to_fixed_coordinates:
                            coordinate = field_to_fixed_coordinates[field.name][e]
                        else:
                            coordinate = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                        coordinates[e] = coordinate * field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            # struct members are addressed by name, the coordinate is the member offset
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # all groups but the first one get their own intermediate pointer, starting with the last group
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the offset of the first group is written directly into the access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild these expression types unevaluated
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    visit_node(ast_node)
    return ast_node
522
523


Martin Bauer's avatar
Martin Bauer committed
524
def move_constants_before_loop(ast_node):
    """Hoists nodes out of the blocks they live in as far as their symbol dependencies allow.

    Every :class:`pystencils.ast.Block` below ``ast_node`` is emptied and each former child is re-inserted
    into the outermost enclosing block that can be reached without passing a node that defines (or, for
    conditionals, tests) one of the symbols the child depends on.

    Call this after creating the loop structure with :func:`make_loop_over_domain`
    """
    def locate_destination(node):
        """Climbs the parent chain of ``node`` while none of its undefined symbols is touched.

        :param node: node whose parent is a Block
        :return: ``(destination_block, anchor)`` where ``anchor`` is the child of ``destination_block``
                 that was passed on the way up, i.e. the node to insert in front of
        """
        assert isinstance(node.parent, ast.Block)

        destination, anchor = node.parent, node
        previous, current = node, node.parent
        while current:
            if isinstance(current, ast.Block):
                # remember the outermost block reached so far and the branch we came from
                destination, anchor = current, previous

            if isinstance(current, ast.Conditional):
                blocking_symbols = current.condition_expr.atoms(sp.Symbol)
            else:
                blocking_symbols = current.symbols_defined
            if node.undefined_symbols.intersection(blocking_symbols):
                break
            previous, current = current, current.parent
        return destination, anchor

    def find_assignment_with_same_lhs(assignment, target_block):
        """Returns the first SympyAssignment in ``target_block`` writing the same lhs, or None."""
        candidates = (arg for arg in target_block.args
                      if type(arg) is ast.SympyAssignment and arg.lhs == assignment.lhs)
        return next(candidates, None)

    def collect_blocks(node):
        """Returns all Block nodes at or below ``node`` in pre-order."""
        found = [node] if isinstance(node, ast.Block) else []
        if isinstance(node, ast.Node):
            for arg in node.args:
                found.extend(collect_blocks(arg))
        return found

    # the block list is gathered completely before any node is moved
    for block in collect_blocks(ast_node):
        for child in block.take_child_nodes():
            target, insert_before = locate_destination(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            if isinstance(child, ast.SympyAssignment):
                duplicate = find_assignment_with_same_lhs(child, target)
            else:
                duplicate = False

            if not duplicate:
                target.insert_before(child, insert_before)
            elif not (duplicate.rhs == child.rhs):
                # same lhs but different rhs already present in the target:
                # don't move in this case - better would be to rename symbol
                block.append(child)
            # otherwise an identical assignment already exists in the target and the child is dropped
592
593


Martin Bauer's avatar
Martin Bauer committed
594
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The single innermost loop is replaced by a block holding one loop per symbol group. Non-field symbols
    listed in a group are written to (and later read from) an indexed temporary array; for each of them a
    temporary memory allocation/free pair is added around the outermost loop.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups
                       and which no symbol in a symbol group depends on, are not updated!
    """
    loops = ast_node.atoms(ast.LoopOverCoordinate)

    innermost = [loop for loop in loops if loop.is_innermost_loop]
    assert len(innermost) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost[0]
    assert type(inner_loop.body) is ast.Block

    outermost = [loop for loop in loops if loop.is_outermost_loop]
    assert len(outermost) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost[0]

    def temporary_array_access(symbol):
        """Builds the indexed access ``symbol[loop_counter]`` on a pointer-typed twin of ``symbol``."""
        pointer_symbol = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return IndexedBase(pointer_symbol, shape=(1,))[inner_loop.loop_counter_symbol]

    # shared across all groups: symbols of earlier groups stay substituted in later ones
    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # collect the group's symbols together with everything they (transitively) depend on
        pending = list(symbol_group)
        symbols_resolved = set()
        while pending:
            current = pending.pop()
            if current in symbols_resolved:
                continue

            if current in assignment_map:  # if there is no assignment inside the loop body it is independent already
                pending.extend(dependency for dependency in assignment_map[current].rhs.atoms(sp.Symbol)
                               if type(dependency) is not Field.Access
                               and dependency not in symbols_with_temporary_array)
            symbols_resolved.add(current)

        for symbol in symbol_group:
            if type(symbol) is Field.Access:
                continue
            assert type(symbol) is TypedSymbol
            symbols_with_temporary_array[symbol] = temporary_array_access(symbol)

        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs not in symbols_resolved:
                continue
            new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
            new_lhs = assignment.lhs
            if type(new_lhs) is not Field.Access and new_lhs in symbol_group:
                assert type(new_lhs) is TypedSymbol
                new_lhs = temporary_array_access(new_lhs)
            assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    enclosing = outer_loop.parent
    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        enclosing.insert_front(alloc_node)
        enclosing.append(ast.TemporaryMemoryFree(alloc_node))
660
661


Martin Bauer's avatar
Martin Bauer committed
662
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    The iteration range is divided into the consecutive pieces
    ``start .. cutting_points[0]``, ``cutting_points[0] .. cutting_points[1]``, ...,
    ``cutting_points[-1] .. stop``. For every piece:

    - extent 0: nothing is generated
    - extent 1: a copy of the loop body with the loop counter substituted by the piece's start
    - otherwise: a new loop over that piece with a copy of the body

    Modifies the ast in place: ``loop_node`` is replaced in its parent by the generated nodes.

    Returns:
        list of the generated nodes (loops and/or substituted bodies)

    Raises:
        NotImplementedError: if the loop step is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    boundaries = [loop_node.start] + list(cutting_points) + [loop_node.stop]

    pieces = []
    for lower, upper in zip(boundaries, boundaries[1:]):
        extent = upper - lower
        if extent == 0:
            continue
        body_copy = deepcopy(loop_node.body)
        if extent == 1:
            # single iteration: no loop needed, fix the counter to its only value
            body_copy.subs({loop_node.loop_counter_symbol: lower})
            pieces.append(body_copy)
        else:
            pieces.append(ast.LoopOverCoordinate(body_copy, loop_node.coordinate_to_loop_over,
                                                 lower, upper, loop_node.step))

    loop_node.parent.replace(loop_node, pieces)
    return pieces
692
693


Martin Bauer's avatar
Martin Bauer committed
694
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    The condition of every conditional below ``node`` is run through ``sympy.simplify`` and stored back.
    A condition that simplifies to true is replaced by its true block, one that simplifies to false by its
    false block (or by nothing if there is none).

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default. If the import fails, a warning is issued instead.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        simplified = conditional.condition_expr

        if simplified == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
            continue

        if simplified == sp.false:
            replacement = [conditional.false_block] if conditional.false_block else []
            conditional.parent.replace(conditional, replacement)
            continue

        if not loop_counter_simplification:
            continue

        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(conditional)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child

    Works recursively and in place. A block is only dissolved when its parent is a block as well.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a snapshot: children that are blocks may replace themselves in this node
    for child in list(node.args):
        cleanup_blocks(child)

    if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
733
734


735
736
737
738
739
740
741
742
743
744
745
746
class KernelConstraintsCheck:
    """Validates the assignments passed to create_kernel and types their symbols on the way.

    The following conditions are enforced:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exact the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)

    Violations are reported by raising ``ValueError``.
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        # mapping: symbol name -> type, the key '_constant' is used for numeric constants
        self._type_for_symbol = type_for_symbol

        self.scopes = NestedScopes()
        # FieldAndIndex -> set of offsets the field was written at
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition

    def process_assignment(self, assignment):
        """Checks one assignment and returns it as SympyAssignment with typed symbols."""
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        typed_rhs = self.process_expression(assignment.rhs)
        typed_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(typed_lhs, typed_rhs)

    def process_expression(self, rhs, type_constants=True):
        """Registers the read accesses of ``rhs`` and returns the expression with typed symbols.

        Args:
            rhs: expression to process
            type_constants: if True, numbers are wrapped in a cast to the '_constant' type
        """
        self._update_accesses_rhs(rhs)

        if isinstance(rhs, Field.Access):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        if isinstance(rhs, TypedSymbol):
            return rhs
        if isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
        if type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        if isinstance(rhs, sp.Mul):
            # factors of +1/-1 are passed through untouched
            new_args = [arg if arg in (-1, 1) else self.process_expression(arg, type_constants)
                        for arg in rhs.args]
            return rhs.func(*new_args) if new_args else rhs
        if isinstance(rhs, sp.Indexed):
            return rhs
        if isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            base, exponent = rhs.args[0], rhs.args[1]
            return sp.Pow(self.process_expression(base, type_constants), exponent)

        new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
        return rhs.func(*new_args) if new_args else rhs

    @property
    def fields_written(self):
        """Set of all fields with at least one recorded write."""
        return {key.field for key, offsets in self._field_writes.items() if len(offsets)}

    def _process_lhs(self, lhs):
        """Registers the write access to ``lhs`` and returns it as typed symbol."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, (Field.Access, TypedSymbol)):
            return lhs
        return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        """Records a write and raises ValueError if it breaks the single-write / SSA rules."""
        if isinstance(lhs, Field.Access):
            written_offsets = self._field_writes[self.FieldAndIndex(lhs.field, lhs.index)]
            written_offsets.add(lhs.offsets)
            if len(written_offsets) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
        elif isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
            if lhs in self.scopes.free_parameters:
                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        """Records a read and raises ValueError if it violates the loop independence condition."""
        if isinstance(rhs, Field.Access) and self.check_independence_condition:
            written_offsets = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
            for write_offset in written_offsets:
                assert len(written_offsets) == 1
                if write_offset != rhs.offsets:
                    raise ValueError("Violation of loop independence condition. Field "
                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self.scopes.access_symbol(rhs)
825
826
827
828
829


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

    Additionally returns sets of all fields which are read/written

    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                         A plain string (or any object without ``__getitem__``) is forwarded as default type
                         to :func:`typing_from_sympy_inspection` to build the mapping.
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols

    Raises:
        ValueError: if an object of unsupported type is encountered
    """
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)

    def visit(obj):
        if isinstance(obj, (list, tuple)):
            return [visit(element) for element in obj]

        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            return check.process_assignment(obj)

        if isinstance(obj, ast.Conditional):
            check.scopes.push()
            # processing order: false block, condition, true block
            typed_false_block = None if obj.false_block is None else visit(obj.false_block)
            typed_condition = check.process_expression(obj.condition_expr, type_constants=False)
            typed_true_block = visit(obj.true_block)
            typed_conditional = ast.Conditional(typed_condition, true_block=typed_true_block,
                                                false_block=typed_false_block)
            check.scopes.pop()
            return typed_conditional

        if isinstance(obj, ast.Block):
            check.scopes.push()
            typed_block = ast.Block([visit(element) for element in obj.args])
            check.scopes.pop()
            return typed_block

        if isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj

        raise ValueError("Invalid object in kernel " + str(type(obj)))

    typed_equations = visit(eqs)

    return check.fields_read, check.fields_written, typed_equations
872
873


Martin Bauer's avatar
Martin Bauer committed
874
def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary.

    Children are processed first. Blocks and loops are updated in place and returned, other nodes are
    rebuilt from their processed arguments.

    Args:
        node: the head node of the ast

    Returns:
        modified AST
    """
    def cast(zipped_args_types, target_dtype):
        """
        Adds casts to the arguments if their type differs from the target type
        :param zipped_args_types: a zipped list of args and types
        :param target_dtype: The target data type
        :return: args with possible casts
        """
        # comparison on numpy_dtype, i.e. ignoring const
        return [cast_func(argument, target_dtype) if data_type.numpy_dtype != target_dtype.numpy_dtype else argument
                for argument, data_type in zipped_args_types]

    def pointer_arithmetic(expr_args):
        """
        Creates a valid pointer arithmetic function
        :param expr_args: Arguments of the add expression
        :return: pointer_arithmetic_func
        """
        pointer = None
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None  # at most one pointer operand allowed
                pointer = arg

        offsets = []
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                offsets.append(arg)
        offset_sum = sp.Add(*offsets) if len(offsets) > 0 else offsets
        return pointer_arithmetic_func(pointer, offset_sum)

    if isinstance(node, (sp.AtomicExpr, cast_func)):
        return node

    args = [insert_casts(arg) for arg in node.args]
    func = node.func

    # TODO indexed, LoopOverCoordinate
    if func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        target = collate_types(types)
        zipped = list(zip(args, types))
        if target.func is PointerType:
            assert func is sp.Add
            return pointer_arithmetic(zipped)
        return func(*cast(zipped, target))

    if func is ast.SympyAssignment:
        lhs, rhs = args[0], args[1]
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return func(*args)  # TODO fix, not complete
        return func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))

    if func is ast.ResolvedFieldAccess:
        return node

    if func is ast.Block or func is ast.LoopOverCoordinate:
        # containers are kept, only their children are swapped for the processed versions
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node

    if func is sp.Piecewise:
        # bring the expressions of all branches to a common type, conditions stay as they are
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        casted_expressions = cast(list(zip(expressions, types)), target)
        args = [pair.func(expr, pair.cond) for (pair, expr) in zip(args, casted_expressions)]

    return func(*args)


963
964
965
966
967
968
969
970
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element

    The innermost loop and all loops enclosing it are cut at ``stop - 1``. Afterwards conditionals are
    simplified (including the loop counter dependent analysis), blocks are cleaned up and constants are
    moved in front of the loops.
    """
    innermost_loops = [loop for loop in function_node.atoms(ast.LoopOverCoordinate) if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = innermost_loops[0]

    # the generator is consumed lazily while the loops are being cut
    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
        cut_loop(loop, [loop.stop - 1])

    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    for post_pass in (cleanup_blocks, move_constants_before_loop, cleanup_blocks):
        post_pass(function_node.body)


Martin Bauer's avatar
Martin Bauer committed
980
981
982
# --------------------------------------- Helper Functions -------------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
983
def typing_from_sympy_inspection(eqs, default_type="double"):
    """
    Creates a default symbol name to type mapping.
    A symbol is typed 'bool' if a sympy Boolean that is not a plain symbol is assigned to it, every other
    symbol name maps to the default type (usually 'double').

    Args:
        eqs: list of equations; conditionals are inspected recursively, other non-assignment nodes are skipped
        default_type: the type for non-boolean symbols
    Returns:
        dictionary, mapping symbol name to type
    """
    symbol_types = defaultdict(lambda: default_type)
    for eq in eqs:
        if isinstance(eq, ast.Conditional):
            symbol_types.update(typing_from_sympy_inspection(eq.true_block.args))
            if eq.false_block:
                symbol_types.update(typing_from_sympy_inspection(eq.false_block.args))
            continue

        if isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
            continue

        # problematic case here is when rhs is a symbol: then it is impossible to decide here without
        # further information what type the left hand side is - default fallback is the dict value then
        rhs = eq.rhs
        if isinstance(rhs, Boolean) and not isinstance(rhs, sp.Symbol):
            symbol_types[eq.lhs.name] = "bool"
    return symbol_types


Martin Bauer's avatar
Martin Bauer committed
1010
def get_next_parent_of_type(node, parent_type):
    """Returns the closest ancestor of ``node`` that is an instance of ``parent_type``.

    The ``parent`` links are followed upwards, starting at ``node.parent``; ``node`` itself is never
    returned. If the chain ends (``parent`` is None) before a match is found, None is returned.
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parent_type):
        ancestor = ancestor.parent
    return ancestor


1024
def parents_of_type(node, parent_type, include_current=False):
    """Generator for all parent nodes of given type

    Yields the matching ancestors from the innermost to the outermost one. With ``include_current`` the
    traversal starts at ``node`` itself instead of its parent. The next ``parent`` link is only read when
    the generator is advanced.
    """
    current = node.parent if not include_current else node
    while True:
        if current is None:
            return
        if isinstance(current, parent_type):
            yield current
        current = current.parent


Martin Bauer's avatar
Martin Bauer committed
1033
def get_optimal_loop_ordering(fields):
Martin Bauer's avatar
Martin Bauer committed
1034
1035
1036
    """
    Dete