transformations.py 57.3 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1 2
import hashlib
import pickle
3
import warnings
Martin Bauer's avatar
Martin Bauer committed
4
from collections import OrderedDict, defaultdict, namedtuple
5
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
6
from types import MappingProxyType
Martin Bauer's avatar
Martin Bauer committed
7

8
import numpy as np
9 10
import sympy as sp
from sympy.logic.boolalg import Boolean
Martin Bauer's avatar
Martin Bauer committed
11 12

import pystencils.astnodes as ast
13
import pystencils.integer_functions
14
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
15 16 17 18
from pystencils.data_types import (
    PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
    get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
19
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
20
from pystencils.simp.assignment_collection import AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
21
from pystencils.slicing import normalize_slice
22 23


24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before, is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # symbols that were accessed without being defined in any scope
        self.free_parameters = set()
        # stack of scopes; index 0 is the outermost scope and is never removed
        self._defined = [set()]

    def access_symbol(self, symbol):
        """Record an access; symbols not defined in any scope become free parameters."""
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define ``symbol`` in the innermost (current) scope."""
        self._defined[-1].add(symbol)

    def is_defined(self, symbol):
        """Return True if ``symbol`` is defined in the current or any enclosing scope."""
        return any(symbol in scope for scope in self._defined)

    def is_defined_locally(self, symbol):
        """Return True if ``symbol`` is defined in the innermost scope."""
        return symbol in self._defined[-1]

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Close the innermost scope.

        The outermost scope must not be removed. The check happens *before* the scope stack
        is modified, so a failed pop leaves the object in a consistent state.
        """
        assert self.depth > 1, "cannot pop the outermost scope"
        self._defined.pop()

    @property
    def depth(self):
        """Number of currently open scopes (at least 1)."""
        return len(self._defined)


Martin Bauer's avatar
Martin Bauer committed
78
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yields all descendants of ``node`` that are instances of ``node_type`` (depth-first, pre-order).

    Args:
        node: root of the traversal; every visited object has to provide an ``args`` sequence
        node_type: type (or tuple of types) of the nodes that are yielded
        stop_type: optional type (or tuple of types); children of this type that do not match
                   ``node_type`` are neither yielded nor descended into

    Nodes matching ``node_type`` are yielded and additionally searched for nested matches.
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # prune this subtree: the check has to look at the child, not at the current node
            continue

        # hand stop_type down, otherwise pruning would only apply to the first level
        yield from filtered_tree_iteration(arg, node_type, stop_type)
86 87


88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
def generic_visit(term, visitor):
    """Applies ``visitor`` to the expressions contained in ``term`` and rebuilds the container.

    Supported containers are AssignmentCollection (main assignments and subexpressions), lists
    (element-wise), Assignment (only the right hand side is visited, the left hand side is kept)
    and sympy matrices (entry-wise). Anything else is passed to ``visitor`` directly.

    Args:
        term: container or expression to process
        visitor: callable taking a single expression and returning its replacement

    Returns:
        object of the same kind as ``term`` built from the visited parts
    """
    def descend(sub_term):
        return generic_visit(sub_term, visitor)

    if isinstance(term, AssignmentCollection):
        return term.copy(descend(term.main_assignments), descend(term.subexpressions))
    if isinstance(term, list):
        return list(map(descend, term))
    if isinstance(term, Assignment):
        return Assignment(term.lhs, descend(term.rhs))
    if isinstance(term, sp.Matrix):
        return term.applyfunc(descend)
    return visitor(term)


103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
def unify_shape_symbols(body, common_shape, fields):
    """Makes all variable-sized fields use the shape symbols of ``common_shape``.

    In a kernel with variable array sizes all passed arrays have to be of equal size, which is
    checked when the kernel is called. Inside the kernel a single symbol per dimension is
    therefore sufficient: e.g. shape_arr1[0] and shape_arr2[0] are equal and are represented
    by one symbol.

    Args:
        body: ast node of the kernel part where the substitution is applied, modified in-place
        common_shape: shape of the field that was chosen as representative
        fields: all fields whose shape symbols are replaced by those of common_shape
    """
    replacements = {}
    for current_field in fields:
        assert len(current_field.spatial_shape) == len(common_shape)
        if current_field.has_fixed_shape:
            # fixed shapes are plain numbers, there is nothing to unify
            continue
        replacements.update(
            (own_component, common_component)
            for common_component, own_component in zip(common_shape, current_field.spatial_shape)
            if own_component != common_component
        )

    if replacements:
        body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
127
def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists.

    Args:
        field_set: collection of fields; each has to provide ``has_fixed_shape``, ``name``
                   and ``spatial_shape``

    Returns:
        the spatial shape shared by all fields. For variable-sized fields the shapes may consist of
        different symbols; in that case the shape whose first component has the smallest
        string representation is chosen.

    Raises:
        ValueError: if fixed-shaped and variable-shaped fields are mixed, or if the fields are
                    all fixed-shaped but do not have exactly one common shape
    """
    nr_of_fixed_shaped_fields = sum(1 for f in field_set if f.has_fixed_shape)

    if 0 < nr_of_fixed_shaped_fields != len(field_set):
        fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
        var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if nr_of_fixed_shaped_fields == len(field_set) and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # variable-sized fields: pick one representative shape
    return min(shape_set, key=lambda e: str(e[0]))


151
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
             or a single int that is used as lower and upper value for all coordinates.
             If None, the number of ghost layers is determined automatically and assumed to be equal
             for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of (outermost node, ghost_layers): the node is a Block wrapping the outermost loop, or
        ``body`` itself if no loop had to be created; ghost_layers is the list of (lower, upper) pairs
        that was used
    """
    # find correct ordering by inspecting participating FieldAccesses
    field_accesses = body.atoms(AbstractField.AbstractAccess)
    field_accesses = {e for e in field_accesses if not e.is_absolute_access}

    # exclude accesses to buffers from field_list, because buffers are treated separately
    field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)]
    fields = set(field_list)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    current_body = body
    # loops are wrapped around the body from the innermost to the outermost coordinate
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            begin = ghost_layers[loop_coordinate][0]
            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([new_loop])
        else:
            slice_component = iteration_slice[loop_coordinate]
            if type(slice_component) is slice:
                sc = slice_component
                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                current_body = ast.Block([new_loop])
            else:
                # coordinate is fixed to a single value: no loop, the loop counter symbol is assigned instead
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                                                 sp.sympify(slice_component))
                current_body.insert_front(assignment)

    return current_body, ghost_layers
207 208


Martin Bauer's avatar
Martin Bauer committed
209
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []
    for coordinate_id, coordinate_value in coordinates.items():
        offset += field.strides[coordinate_id] * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the relative offset of the access contributes as well and names the pointer
            offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
            name_component = field_access.offsets[coordinate_id]
        else:
            # index coordinate: the coordinate value itself names the pointer
            name_component = coordinate_value

        if type(name_component) is int:
            name += "_%d%d" % (coordinate_id, name_component)
        else:
            # non-integer components cannot be written into the name directly, they are hashed below
            list_to_hash.append(name_component)

    if list_to_hash:
        # md5 only serves to build a short name suffix here, it is not used for security purposes
        name += hashlib.md5(pickle.dumps(list_to_hash)).hexdigest()[:16]

    # a minus sign is not allowed in identifiers
    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset
259 260


261
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of lists of coordinate ids that can be passed to :func:`resolve_field_accesses`;
        coordinates that were not mentioned are collected (sorted) in an additional last group

    Raises:
        ValueError: if a specification cannot be parsed, or a coordinate does not exist or is given twice
        IndexError: if the number in "spatialInner<int>" / "spatialOuter<int>" exceeds the number of loops

    Examples:
        >>> parse_base_pointer_info([['spatialInner0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[0], [3], [1, 2]]
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    result = []
    specified_coordinates = set()
    nr_of_coordinates = spatial_dimensions + index_dimensions
    outer_to_inner = list(loop_order)
    inner_to_outer = outer_to_inner[::-1]

    def add_new_element(group, elem):
        if elem < 0 or elem >= nr_of_coordinates:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        group.append(elem)
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []

        for element in spec_group:
            if type(element) is int:
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # index 0 has to select the outermost loop; counting from the end of the
                    # inner-to-outer list with [-index] would map 0 to the innermost loop instead
                    add_new_element(new_group, outer_to_inner[index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    rest = set(range(nr_of_coordinates)) - specified_coordinates
    if rest:
        # sorted to get an order that does not depend on set iteration order
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
337 338 339 340 341 342 343 344 345 346 347 348 349
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        # the traversal finds outer loops first, the index computation needs inner to outer
        loops.reverse()
        parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        # all loops that were found have to form one single nest around the innermost loop
        assert len(loops) == len(parents_of_innermost_loop)
        assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))

        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loops]
        loop_counters = [loop.loop_counter_symbol for loop in loops]

    field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
    buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
    # every loop counter is scaled by the number of distinct buffer accesses
    loop_counters = [v * len(buffer_accesses) for v in loop_counters]

    base_buffer_index = loop_counters[0]
    stride = 1
    # the stride of a loop counter is the product of the iteration counts of all loops further inside
    for idx, var in enumerate(loop_counters[1:]):
        cur_stride = loop_iterations[idx]
        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
        base_buffer_index += var * stride
    return base_buffer_index


372 373
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
    """Substitutes accesses to buffer fields by :class:`pystencils.astnodes.ResolvedFieldAccess` nodes.

    Accesses to fields that are not buffers are left untouched. The AST is modified in place.

    Args:
        ast_node: the AST root
        base_buffer_index: linearized buffer index, see :func:`get_base_buffer_index`
        read_only_field_names: set of field names which are considered read-only (their pointer is const);
                               the default set is only read, never modified

    Returns:
        None

    Raises:
        RuntimeError: if a buffer is accessed with more than one index dimension
    """

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # rebuild these expression types without automatic evaluation
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
416

417

418
def resolve_field_accesses(ast_node, read_only_field_names=set(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in place: assignments of intermediate base pointers are inserted into the
    enclosing blocks and field accesses are replaced by resolved accesses.

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only;
                               the default set is only read, never modified
        field_to_base_pointer_info: map of field name to a list of coordinate groups indicating which
                                    intermediate base pointers should be created,
                                    for details see :func:`parse_base_pointer_info`.
                                    Fields without entry get all coordinates resolved in a single step
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        None
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets / indices contain field accesses themselves, these are resolved first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [
                    list(
                        range(field.index_dimensions + field.spatial_dimensions))
                ]

            field_ptr = FieldPointerSymbol(
                field.name,
                field.dtype,
                const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field_access.is_absolute_access:
                            # absolute accesses do not depend on the current cell
                            coordinates[e] = 0
                        elif field.name in field_to_fixed_coordinates:
                            coordinates[e] = field_to_fixed_coordinates[field.name][e]
                        else:
                            coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            # struct members are addressed by name, the coordinate is the offset of the member
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # all groups but the first one become intermediate pointers, defined before the current assignment
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is written directly into the resolved access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            if hasattr(expr, 'args'):
                new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            else:
                new_args = []
            # rebuild these expression types without automatic evaluation
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            if isinstance(sub_ast, (bool, int, float)):
                return
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
542 543


Martin Bauer's avatar
Martin Bauer committed
544
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.astnodes.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    Every block below ``ast_node`` is emptied and its children are re-inserted one by one: nodes that are not
    assignments are re-appended unchanged, while assignments are hoisted into the outermost enclosing block
    that does not define any symbol the assignment depends on.

    Args:
        ast_node: root of the AST to transform, modified in place
    """
    def find_block_to_move_to(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        The traversal stops at a conditional, or at the first parent that defines a symbol the
        assignment depends on.

        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                # remember the outermost block seen so far and the child of it we came from
                last_block = element
                last_block_child = prev_element

            if isinstance(element, ast.Conditional):
                # never move an assignment out of a conditional
                break

            critical_symbols = element.symbols_defined
            if node.undefined_symbols.intersection(critical_symbols):
                break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        """Returns the first SympyAssignment directly in target_block that has the same rhs
        (rhs_or_lhs=True) or the same lhs (rhs_or_lhs=False) as assignment, or None if there is none."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        """Appends node (if it is a Block) and all blocks below it to result_list, parents first."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        children = block.take_child_nodes()
        # Every time a symbol can be replaced in the current block because the assignment
        # was found in a parent block, but with a different lhs symbol (same rhs)
        # the outer symbol is inserted here as key.
        substitute_variables = {}
        for child in children:
            # Before traversing the next child, all symbols are substituted first.
            child.subs(substitute_variables)

            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            # child is guaranteed to be a SympyAssignment here, so the lhs lookup is always meaningful
            exists_already = check_if_assignment_already_in_block(child, target, False)

            if not exists_already:
                rhs_identical = check_if_assignment_already_in_block(child, target, True)
                if rhs_identical:
                    # there is already an assignment out there with the same rhs
                    # -> replace all lhs symbols in this block with the lhs of the outer assignment
                    # -> remove the local assignment (do not re-append child to the former block)
                    substitute_variables[child.lhs] = rhs_identical.lhs
                else:
                    target.insert_before(child, child_to_insert_before)
            elif exists_already.rhs == child.rhs:
                # identical assignment (same lhs and rhs) is already in the target block
                # -> the local copy is dropped
                pass
            else:
                # this variable already exists in outer block, but with different rhs
                # -> symbol has to be renamed
                assert isinstance(child.lhs, TypedSymbol)
                new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
                substitute_variables[child.lhs] = new_symbol
635 636


Martin Bauer's avatar
Martin Bauer committed
637
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The single innermost loop is replaced by a block holding one loop per symbol group. Non-field symbols
    of a group are written to temporary arrays indexed by the inner loop counter; the allocation and
    free nodes for these arrays are placed around the outermost loop.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups
                       and which no symbol in a symbol group depends on, are not updated!
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)

    innermost_candidates = [loop for loop in all_loops if loop.is_innermost_loop]
    assert len(innermost_candidates) == 1, \
        "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost_candidates[0]
    assert type(inner_loop.body) is ast.Block

    outermost_candidates = [loop for loop in all_loops if loop.is_outermost_loop]
    assert len(outermost_candidates) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost_candidates[0]

    def is_field_access(symbol):
        return isinstance(symbol, AbstractField.AbstractAccess)

    def temporary_array_access(symbol):
        # pointer-typed symbol with the same name, accessed at the inner loop counter
        assert type(symbol) is TypedSymbol
        pointer_symbol = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return sp.IndexedBase(pointer_symbol, shape=(1, ))[inner_loop.loop_counter_symbol]

    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # collect the group's symbols together with everything they depend on
        pending = list(symbol_group)
        symbols_resolved = set()
        while pending:
            current = pending.pop()
            if current in symbols_resolved:
                continue

            # a symbol without an assignment inside the loop body is independent already
            if current in assignment_map:
                dependencies = assignment_map[current].rhs.atoms(sp.Symbol)
                pending.extend(dep for dep in dependencies
                               if not is_field_access(dep) and dep not in symbols_with_temporary_array)
            symbols_resolved.add(current)

        # plain symbols of this group are stored in temporary arrays from now on
        for symbol in symbol_group:
            if is_field_access(symbol):
                continue
            symbols_with_temporary_array[symbol] = temporary_array_access(symbol)

        assignment_group = []
        for assignment in inner_loop.body.args:
            lhs = assignment.lhs
            if lhs not in symbols_resolved:
                continue
            new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
            if not is_field_access(lhs) and lhs in symbol_group:
                new_lhs = temporary_array_access(lhs)
            else:
                new_lhs = lhs
            assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = []
    for group in assignment_groups:
        new_loops.append(inner_loop.new_loop_with_different_body(ast.Block(group)))
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    # allocate each temporary array before, and free it after, the outermost loop
    enclosing = outer_loop.parent
    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        enclosing.insert_front(alloc_node)
        enclosing.append(free_node)
712 713


Martin Bauer's avatar
Martin Bauer committed
714
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into up to len(cutting_points)+1 pieces that cover the ranges
    old_begin to cutting_points[0], ..., cutting_points[-1] to old_end.
    A range of length one is emitted as a copy of the loop body with the loop counter replaced by
    the start of the range, an empty range is dropped, every other range becomes a new loop.

    Modifies the ast in place

    Args:
        loop_node: the loop to cut, has to have a step of 1
        cutting_points: loop counter values at which the loop is cut

    Returns:
        block containing the new nodes, which replaces ``loop_node`` in its parent

    Raises:
        NotImplementedError: if the step of the loop is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    pieces = ast.Block([])
    segment_start = loop_node.start
    # the original stop value closes the last segment
    segment_stops = list(cutting_points) + [loop_node.stop]
    for segment_stop in segment_stops:
        extent = segment_stop - segment_start
        if extent == 0:
            # empty range, nothing to emit
            pass
        elif extent == 1:
            # single iteration: emit the body directly with the counter fixed to this value
            unrolled_body = deepcopy(loop_node.body)
            unrolled_body.subs({loop_node.loop_counter_symbol: segment_start})
            pieces.append(unrolled_body)
        else:
            pieces.append(ast.LoopOverCoordinate(deepcopy(loop_node.body),
                                                 loop_node.coordinate_to_loop_over,
                                                 segment_start, segment_stop, loop_node.step))
        segment_start = segment_stop

    loop_node.parent.replace(loop_node, pieces)
    return pieces
745 746


747
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    The condition of every conditional below ``node`` is simplified with sympy. A conditional whose
    condition becomes true is replaced by its true block, one whose condition becomes false is replaced
    by its false block, or removed if it has none.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        condition = conditional.condition_expr

        if condition == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
            continue

        if condition == sp.false:
            if conditional.false_block:
                replacement = [conditional.false_block]
            else:
                replacement = []
            conditional.parent.replace(conditional, replacement)
            continue

        if not loop_counter_simplification:
            continue

        # imported lazily, the import fails when the optional ISL dependency is missing
        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(conditional)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
771 772 773


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child

    Only blocks that are direct children of another block are removed or unwrapped.

    Args:
        node: ast node, the tree below it is cleaned up in place
    """
    if isinstance(node, ast.SympyAssignment):
        # assignments are not descended into
        return

    if not isinstance(node, ast.Block):
        for child in node.args:
            cleanup_blocks(child)
        return

    # iterate over a copy, since children may replace themselves inside this block
    for child in list(node.args):
        cleanup_blocks(child)

    has_at_most_one_child = len(node.args) <= 1
    if has_at_most_one_child and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
786 787


788 789 790 791 792 793 794 795 796 797 798 799
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Test the following conditions:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exact the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)
Martin Bauer's avatar
Martin Bauer committed
800
    """
801 802 803 804 805
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        self._type_for_symbol = type_for_symbol

806
        self.scopes = NestedScopes()
807 808 809 810 811 812 813 814 815 816
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition

    def process_assignment(self, assignment):
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs = self.process_expression(assignment.rhs)
        new_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(new_lhs, new_rhs)