transformations.py 57.5 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
import hashlib
import pickle
3
import warnings
Martin Bauer's avatar
Martin Bauer committed
4
from collections import OrderedDict, defaultdict, namedtuple
5
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
6
from types import MappingProxyType
Martin Bauer's avatar
Martin Bauer committed
7

8
import numpy as np
9
10
import sympy as sp
from sympy.logic.boolalg import Boolean
Martin Bauer's avatar
Martin Bauer committed
11
12

import pystencils.astnodes as ast
13
import pystencils.integer_functions
14
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
15
16
17
18
from pystencils.data_types import (
    PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
    get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
19
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
20
from pystencils.simp.assignment_collection import AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
21
from pystencils.slicing import normalize_slice
22
23


24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before, is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope
    - the outermost scope always exists and cannot be popped

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        self.free_parameters = set()
        # Stack of scopes, innermost scope last; index 0 is the outermost scope.
        self._defined = [set()]

    def access_symbol(self, symbol):
        """Record a use of ``symbol``; it becomes a free parameter if no open scope defines it."""
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define ``symbol`` in the innermost scope."""
        self._defined[-1].add(symbol)

    def is_defined(self, symbol):
        """Return True if ``symbol`` is defined in any currently open scope."""
        return any(symbol in scope for scope in self._defined)

    def is_defined_locally(self, symbol):
        """Return True if ``symbol`` is defined in the innermost scope."""
        return symbol in self._defined[-1]

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Close the innermost scope, discarding the symbols defined in it.

        Raises:
            AssertionError: if only the outermost scope is left
        """
        # Check before popping so that a failed pop leaves the scope stack intact.
        assert self.depth > 1, "cannot pop the outermost scope"
        self._defined.pop()

    @property
    def depth(self):
        """Number of currently open scopes (at least 1)."""
        return len(self._defined)


def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yield all descendants of ``node`` that are instances of ``node_type`` (depth-first, pre-order).

    The traversal continues below matching nodes, so nested matches are yielded as well (outer before inner).
    If ``stop_type`` is given, the traversal does not descend into descendants of that type, unless they
    themselves match ``node_type``, in which case they are yielded and descended into.

    Args:
        node: root of the traversal; only its descendants (via ``.args``) are considered, not ``node`` itself
        node_type: type (or tuple of types) of nodes to yield
        stop_type: optional type (or tuple of types) whose subtrees are skipped

    Yields:
        matching nodes in depth-first pre-order
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # Skip the whole subtree of a stop node (e.g. the sympy expressions of an assignment).
            continue
        # stop_type must be passed on, otherwise it only takes effect on the first level.
        yield from filtered_tree_iteration(arg, node_type, stop_type)


def generic_visit(term, visitor):
    """Rebuild ``term`` with ``visitor`` applied to each of its leaf expressions.

    Containers are traversed and reconstructed:
      - AssignmentCollection: main assignments and subexpressions are visited, then ``term.copy`` is used
      - list: every element is visited
      - Assignment: only the right-hand side is visited, the left-hand side is kept
      - sympy Matrix: every entry is visited
    Anything else is handed to ``visitor`` and its result is returned.
    """
    def recurse(child):
        return generic_visit(child, visitor)

    if isinstance(term, AssignmentCollection):
        visited_main = recurse(term.main_assignments)
        visited_subexpressions = recurse(term.subexpressions)
        return term.copy(visited_main, visited_subexpressions)
    if isinstance(term, list):
        return list(map(recurse, term))
    if isinstance(term, Assignment):
        return Assignment(term.lhs, recurse(term.rhs))
    if isinstance(term, sp.Matrix):
        return term.applyfunc(recurse)
    return visitor(term)


def unify_shape_symbols(body, common_shape, fields):
    """Replace array-size symbols so that equal sizes are represented by one unique symbol.

    When creating a kernel with variable array sizes, all passed arrays must have the same size.
    This is ensured when the kernel is called. Inside the kernel this means that only one symbol has to be used
    instead of one per field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should also be
    represented by the same symbol.

    Args:
        body: ast node for the kernel part where the substitutions are made; modified in-place via ``body.subs``
        common_shape: shape of the field that was chosen
        fields: all fields whose shapes should be replaced by common_shape; fixed-shaped fields are left alone
    """
    replacements = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if field.has_fixed_shape:
            continue
        replacements.update((own, common)
                            for common, own in zip(common_shape, field.spatial_shape)
                            if own != common)
    if replacements:
        body.subs(replacements)


def get_common_shape(field_set):
    """Return the common spatial shape of a collection of pystencils Fields.

    Either all fields have a fixed shape (and then their shapes must all be equal) or all fields have a
    variable shape. Variable shapes are only known to agree at runtime, so one representative shape is
    chosen deterministically: the one whose first component sorts first by its string representation,
    with ties broken by the string of the whole shape.

    Args:
        field_set: collection of fields (must support ``len`` and repeated iteration)

    Returns:
        the chosen spatial shape tuple

    Raises:
        ValueError: if no fields are given, if fixed- and variable-shaped fields are mixed, or if
                    fixed-shaped fields have different shapes
    """
    if len(field_set) == 0:
        raise ValueError("Cannot determine a common shape: no fields given")

    fixed_field_names = [f.name for f in field_set if f.has_fixed_shape]
    nr_of_fixed_shaped_fields = len(fixed_field_names)

    if 0 < nr_of_fixed_shaped_fields < len(field_set):
        var_field_names = [f.name for f in field_set if not f.has_fixed_shape]
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (",".join(var_field_names), ",".join(fixed_field_names))
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if nr_of_fixed_shaped_fields == len(field_set) and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # Sorting by the first component alone is ambiguous when two shapes share it; the set's iteration
    # order would then decide, and that order can change between interpreter runs because string hashing
    # is randomized. The full-shape string makes the choice deterministic.
    return min(shape_set, key=lambda shape: (str(shape[0]), str(shape)))


def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Only relative (non-absolute) field accesses in ``body`` are considered. Buffer fields take no part in
    the loop-order and shape computations, because buffers are treated separately.

    Args:
        body: Block object with inner loop contents; its shape symbols are unified in-place
        iteration_slice: if not None, iteration is done only over this slice of the field. Slice entries
                         become loops; non-slice entries fix the loop counter of that coordinate to a value
        ghost_layers: a sequence of (lower, upper) pairs per coordinate, or a single int used for all
                      coordinates. If None, the maximum ``required_ghost_layers`` of all considered field
                      accesses is used for every coordinate
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info
    """
    relative_accesses = {access for access in body.atoms(AbstractField.AbstractAccess)
                         if not access.is_absolute_access}
    # buffers are resolved separately, so they are excluded from the participating fields
    fields = {access.field for access in relative_accesses if not FieldType.is_buffer(access.field)}

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        layers = max([access.required_ghost_layers for access in relative_accesses])
        ghost_layers = [(layers, layers)] * len(loop_order)
    elif isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    current_body = body
    # Build the loop nest from the inside out: the innermost coordinate is handled first.
    for coordinate in reversed(loop_order):
        if iteration_slice is None:
            start = ghost_layers[coordinate][0]
            stop = shape[coordinate] - ghost_layers[coordinate][1]
            current_body = ast.Block([ast.LoopOverCoordinate(current_body, coordinate, start, stop, 1)])
            continue

        component = iteration_slice[coordinate]
        if type(component) is slice:
            loop = ast.LoopOverCoordinate(current_body, coordinate, component.start, component.stop, component.step)
            current_body = ast.Block([loop])
        else:
            # A single value instead of a slice: no loop, the counter is simply assigned that value.
            counter = ast.LoopOverCoordinate.get_loop_counter_symbol(coordinate)
            current_body.insert_front(ast.SympyAssignment(counter, sp.sympify(component)))

    return current_body, ghost_layers


def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    For spatial coordinates the field access offset is added as well and encoded in the name; for index
    coordinates the coordinate value itself is encoded. Integer parts are written as ``_<coordinate><value>``
    (with '-' replaced by 'm'); non-integer parts are combined into a 16-character md5 hex suffix.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    suffix = ""
    non_integer_parts = []

    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the access offset contributes to the address and to the name
            access_offset = field_access.offsets[coordinate_id]
            offset += stride * access_offset
            name_part = access_offset
        else:
            # index coordinate: the coordinate value itself identifies the pointer
            name_part = coordinate_value

        if type(name_part) is int:
            suffix += "_%d%d" % (coordinate_id, name_part)
        else:
            non_integer_parts.append(name_part)

    if non_integer_parts:
        suffix += hashlib.md5(pickle.dumps(non_integer_parts)).hexdigest()[:16]

    suffix = suffix.replace("-", 'm')
    return TypedSymbol(previous_ptr.name + suffix, previous_ptr.dtype), offset


def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)] creates one base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - <int>: specifying directly the coordinate

    Coordinates not mentioned in any group are appended as one additional, final group in ascending order.

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of tuples that can be passed to :func:`resolve_field_accesses`

    Raises:
        ValueError: if a coordinate does not exist, is specified more than once, or an entry cannot be parsed

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    total_dimensions = spatial_dimensions + index_dimensions
    # loop coordinates from the innermost (position 0) to the outermost (position -1) loop
    inner_to_outer = list(reversed(loop_order))
    specified_coordinates = set()
    result = []

    def add_coordinate(group, coordinate):
        # validate and record one coordinate in the given group
        if coordinate >= total_dimensions:
            raise ValueError("Coordinate %d does not exist" % (coordinate,))
        if coordinate in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (coordinate,))
        group.append(coordinate)
        specified_coordinates.add(coordinate)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_coordinate(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_coordinate(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # the outermost loop is the last entry: spatialOuter0 -> [-1] (not [-0], which is the innermost)
                    add_coordinate(new_group, inner_to_outer[-1 - index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_coordinate(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_coordinate(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))
        result.append(new_group)

    rest = set(range(total_dimensions)) - specified_coordinates
    if rest:
        result.append(sorted(rest))
    return result


def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.

    With loop counters ``c_0, c_1, ...`` (inner to outer), each multiplied by the number of distinct buffer
    accesses in ``ast_node``, and iteration counts ``n_0, n_1, ...``, the result is
    ``c_0 + c_1 * n_0 + c_2 * n_0 * n_1 + ...``.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        # pre-order traversal yields outer loops first; reverse so the innermost loop comes first
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))[::-1]
        innermost_chain = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        # all loops must lie on the single parent chain of the innermost loop (a perfect loop nest)
        assert len(loops) == len(innermost_chain)
        assert all(a is b for a, b in zip(loops, innermost_chain))

        # NOTE(review): for plain Python numbers this is a float, later truncated by int(); when
        # (stop - start) is not a multiple of step the iteration count is underestimated - confirm intended.
        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loops]
        # NOTE(review): the raw counter symbols are used, not (counter - start) / step; check that
        # non-zero starts / non-unit steps are handled by the caller.
        loop_counters = [loop.loop_counter_symbol for loop in loops]

    accesses = ast_node.atoms(AbstractField.AbstractAccess)
    buffer_access_count = len({access for access in accesses if FieldType.is_buffer(access.field)})
    scaled_counters = [counter * buffer_access_count for counter in loop_counters]

    index = scaled_counters[0]
    stride = 1
    for position, counter in enumerate(scaled_counters[1:]):
        iterations = loop_iterations[position]
        if isinstance(iterations, float):
            iterations = int(iterations)
        stride *= iterations
        index += counter * stride
    return index


def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Substitute accesses to buffer fields by linear indexing into the buffer pointer.

    Every access to a buffer field (``FieldType.is_buffer``) found in a :class:`ast.SympyAssignment` or in the
    condition of an :class:`ast.Conditional` is replaced by an :class:`ast.ResolvedFieldAccess` at position
    ``base_buffer_index + index[0]`` (or ``base_buffer_index`` if the access has no index). Accesses to
    non-buffer fields are left untouched. The AST is modified in-place.

    Args:
        ast_node: AST to transform in-place
        base_buffer_index: linearized index expression, e.g. from :func:`get_base_buffer_index`
        read_only_field_names: names of buffers whose pointer symbol is created as const

    Returns:
        None; ``ast_node`` is modified in-place

    Raises:
        RuntimeError: if a buffer access uses more than one index dimension
    """

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            # build a new expression instead of using += so the shared base index object is never modified
            if len(field_access.index) > 0:
                buffer_index = base_buffer_index + field_access.index[0]
            else:
                buffer_index = base_buffer_index

            return ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                           field_access.index)

        if isinstance(expr, ast.ResolvedFieldAccess):
            return expr

        # objects without ``args`` (plain numbers, data types, ...) are returned unchanged
        if hasattr(expr, 'args'):
            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
        else:
            new_args = []
        kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
        return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            # resolve buffer accesses in the condition as well, like resolve_field_accesses does
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, sub_ast.parent, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            if isinstance(sub_ast, (bool, int, float)):
                return
            for child in sub_ast.args:
                visit_node(child)

    return visit_node(ast_node)


def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Accesses are resolved in :class:`ast.SympyAssignment` nodes and in the conditions of
    :class:`ast.Conditional` nodes. Intermediate base pointers requested via ``field_to_base_pointer_info``
    are inserted as assignments before the assignment that uses them (only if not yet defined in the
    enclosing block). Accesses to struct-typed fields are wrapped in ``reinterpret_cast_func``.

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only
        field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created
                                    for details see :func:`parse_base_pointer_info`
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        None; ``ast_node`` is modified in-place
    """
    # sort by field name so that the processing does not depend on the mappings' insertion order
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets / indices contain field accesses themselves: resolve those first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                # default: a single group with all coordinates, i.e. no intermediate pointers
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                # map each coordinate id of the group to the value that is multiplied with its stride
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field_access.is_absolute_access:
                            coordinates[e] = 0
                        elif field.name in field_to_fixed_coordinates:
                            coordinates[e] = field_to_fixed_coordinates[field.name][e]
                        else:
                            coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                        coordinates[e] *= field.dtype.item_size
                    elif isinstance(field.dtype, StructType):
                        assert field.index_dimensions == 1
                        accessed_field_name = field_access.index[0]
                        if isinstance(accessed_field_name, sp.Symbol):
                            accessed_field_name = accessed_field_name.name
                        assert isinstance(accessed_field_name, str)
                        coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                    else:
                        coordinates[e] = field_access.index[e - field.spatial_dimensions]
                return coordinates

            last_pointer = field_ptr

            # every group except the first becomes an intermediate pointer based on the previous one
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group is written directly into the access offset
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                accessed_field_name = field_access.index[0]
                if isinstance(accessed_field_name, sp.Symbol):
                    accessed_field_name = accessed_field_name.name
                new_type = field_access.field.dtype.get_element_type(accessed_field_name)
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)

        if isinstance(expr, ast.ResolvedFieldAccess):
            return expr

        # objects without ``args`` (plain numbers, data types, ...) are returned unchanged
        if hasattr(expr, 'args'):
            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
        else:
            new_args = []
        kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
        return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            if isinstance(sub_ast, (bool, int, float)):
                return
            for child in sub_ast.args:
                visit_node(child)

    return visit_node(ast_node)


def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    Every :class:`ast.Block` below (and including) ``ast_node`` is processed. Each assignment is moved up the
    parent chain as far as possible: never out of an :class:`ast.Conditional` and never past a node whose
    ``symbols_defined`` intersect the assignment's ``undefined_symbols``. It is inserted into the target block
    in front of the child that contains its former position, unless the target block already holds

    - an assignment to the same lhs with the same rhs: the local assignment is simply dropped,
    - an assignment to the same lhs with a different rhs: the assignment is moved under a fresh symbol,
      which replaces the old lhs in all following nodes of the current block,
    - (no assignment to the same lhs, but) an assignment with the same rhs: the local assignment is dropped
      and its lhs is replaced by the outer lhs in all following nodes of the current block.

    Args:
        ast_node: AST root; modified in place
    """
    def find_block_to_move_to(node):
        """Traverses parents of node as long as the symbols are independent.

        Args:
            node: SympyAssignment inside a Block

        Returns:
            ``(block_to_insert_to, child_of_block_to_insert_before)`` - the block the assignment can be
            safely moved to and the child of that block in front of which it has to be inserted
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                # remember the outermost block reached so far and its child on the path to `node`
                last_block = element
                last_block_child = prev_element

            if isinstance(element, ast.Conditional):
                break  # never hoist out of a conditional branch
            if node.undefined_symbols.intersection(element.symbols_defined):
                break  # `element` defines a symbol the assignment depends on
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        """Returns the first SympyAssignment in ``target_block`` with the same rhs (``rhs_or_lhs=True``)
        or the same lhs (``rhs_or_lhs=False``) as ``assignment``, or None if there is none."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        """Appends ``node`` and all its descendant blocks (pre-order) to ``result_list``."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        children = block.take_child_nodes()
        # Maps the lhs symbol of every assignment that was removed from this block to the symbol replacing it
        # (the lhs of an outer assignment with identical rhs, or a freshly created symbol after renaming).
        substitute_variables = {}
        for child in children:
            # Before traversing the next child, all symbols are substituted first.
            child.subs(substitute_variables)

            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            # `child` is guaranteed to be a SympyAssignment here (see the check above)
            exists_already = check_if_assignment_already_in_block(child, target, False)
            if not exists_already:
                rhs_identical = check_if_assignment_already_in_block(child, target, True)
                if rhs_identical:
                    # there is already an assignment out there with the same rhs
                    # -> replace all lhs symbols in this block with the lhs of the outer assignment
                    # -> remove the local assignment (do not re-append child to the former block)
                    substitute_variables[child.lhs] = rhs_identical.lhs
                else:
                    target.insert_before(child, child_to_insert_before)
            elif exists_already.rhs == child.rhs:
                # the very same assignment already exists in the target block -> drop the local copy
                pass
            else:
                # this variable already exists in outer block, but with different rhs
                # -> symbol has to be renamed
                assert isinstance(child.lhs, TypedSymbol)
                new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
                substitute_variables[child.lhs] = new_symbol


def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    Non-field symbols of a group are stored in temporary arrays indexed by the inner loop counter; every
    subsequent read of such a symbol (in this and all later loops) is redirected to its array. The arrays are
    allocated at the front of the outer loop's parent block and freed at its end.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the
                       symbol_groups and which no symbol in a symbol group depends on, are not updated!
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loop[0]
    assert type(inner_loop.body) is ast.Block
    outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
    assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loop[0]

    def temporary_array_access(symbol):
        """Element ``symbol[loop_counter]`` of the temporary array that replaces the scalar ``symbol``."""
        array_pointer = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return sp.IndexedBase(array_pointer, shape=(1, ))[inner_loop.loop_counter_symbol]

    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols: transitive closure over the rhs of the loop body assignments, stopping
        # at field accesses and at symbols that already live in a temporary array
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
                            new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        for symbol in symbol_group:
            if not isinstance(symbol, AbstractField.AbstractAccess):
                assert type(symbol) is TypedSymbol
                symbols_with_temporary_array[symbol] = temporary_array_access(symbol)

        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs in symbols_resolved:
                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
                if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
                    # group symbols are written into their temporary array instead of a scalar
                    assert type(assignment.lhs) is TypedSymbol
                    new_lhs = temporary_array_access(assignment.lhs)
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [
        inner_loop.new_loop_with_different_body(ast.Block(group))
        for group in assignment_groups
    ]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)


def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into (at most) ``len(cutting_points) + 1`` new pieces that range from
    ``old_begin`` to ``cutting_points[0]``, ``cutting_points[0]`` to ``cutting_points[1]``, ...,
    ``cutting_points[-1]`` to ``old_end``. A piece covering exactly one iteration is emitted as a copy of the
    loop body with the loop counter substituted by that iteration's value; pieces covering zero iterations
    are omitted.

    Modifies the ast in place: ``loop_node`` is replaced in its parent by the new pieces.

    Args:
        loop_node: loop to cut (an :class:`ast.LoopOverCoordinate`); only a step of 1 is supported
        cutting_points: loop counter values at which the loop is cut - presumably ascending and inside the
                        loop range, this is not checked

    Returns:
        :class:`ast.Block` containing the new loop nodes (and inlined single-iteration bodies)

    Raises:
        NotImplementedError: if the step of ``loop_node`` is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = ast.Block([])
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        iterations = new_end - new_start
        if iterations == 1:
            # single iteration: inline a copy of the body with the loop counter fixed to `new_start`
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        elif iterations != 0:
            new_loop = ast.LoopOverCoordinate(
                deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
    return new_loops


def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    The condition of every conditional below ``node`` is simplified with :func:`sympy.simplify`. If it
    becomes ``sympy.true`` the conditional is replaced by its true block; if it becomes ``sympy.false`` it is
    replaced by its false block, or removed entirely when it has none.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default. On an ImportError a warning is issued and the analysis is skipped.
    """
    for cond in node.atoms(ast.Conditional):
        cond.condition_expr = sp.simplify(cond.condition_expr)

        if cond.condition_expr == sp.true:
            # condition always holds -> only the true branch survives
            cond.parent.replace(cond, [cond.true_block])
            continue

        if cond.condition_expr == sp.false:
            # condition never holds -> only the (optional) false branch survives
            remaining = [cond.false_block] if cond.false_block else []
            cond.parent.replace(cond, remaining)
            continue

        if not loop_counter_simplification:
            continue

        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(cond)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child.

    Works bottom-up and in place. A block is only inlined into its parent when that parent is itself an
    :class:`ast.Block`; all other blocks are kept even if they are empty or have a single child.
    Assignments are not descended into.
    """
    if isinstance(node, (bool, int, float)):
        # plain Python constants can appear among node args; they have no ``args`` and contain no blocks
        return
    if isinstance(node, ast.SympyAssignment):
        return

    # iterate over a copy: inlining a child block replaces that child inside `node.args`
    for child in list(node.args):
        cleanup_blocks(child)

    if isinstance(node, ast.Block) and len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)


class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Test the following conditions:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exact the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)

    Besides checking, :meth:`process_assignment` / :meth:`process_expression` return rewritten expressions in
    which untyped :class:`sympy.Symbol` s are replaced by :class:`TypedSymbol` s and (by default) numeric
    constants are wrapped in :class:`cast_func`. Read fields are collected in ``fields_read``; written fields
    are reported by :attr:`fields_written`.
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        """
        Args:
            type_for_symbol: maps symbol names to types; the key ``'_constant'`` gives the type used for
                             numeric sympy constants
            check_independence_condition: if True, every read of a field access is compared against the
                                          offsets at which the same field/index was written before
        """
        self._type_for_symbol = type_for_symbol

        self.scopes = NestedScopes()
        # FieldAndIndex -> set of offsets at which this field/index is written
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition

    def process_assignment(self, assignment):
        """Checks a single assignment and returns it as :class:`ast.SympyAssignment` with typed symbols."""
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs = self.process_expression(assignment.rhs)
        new_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(new_lhs, new_rhs)

    def process_expression(self, rhs, type_constants=True):
        """Records all reads in ``rhs`` (recursively) and returns a copy with typed symbols.

        Args:
            rhs: expression to process
            type_constants: if True, numeric constants are wrapped in :class:`cast_func`; this is switched off
                            for the argument of an existing :class:`cast_func`
        """
        from pystencils.interpolation_astnodes import InterpolatorAccess

        self._update_accesses_rhs(rhs)
        if isinstance(rhs, AbstractField.AbstractAccess):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        elif isinstance(rhs, InterpolatorAccess):
            new_args = [self.process_expression(arg, type_constants) for arg in rhs.offsets]
            if new_args:
                rhs.offsets = new_args
            return rhs
        elif isinstance(rhs, TypedSymbol):
            return rhs
        elif isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
        elif type_constants and isinstance(rhs, np.generic):
            return cast_func(rhs, create_type(rhs.dtype))
        elif type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        # Very important that this clause comes before BooleanFunction
        elif isinstance(rhs, cast_func):
            return cast_func(
                self.process_expression(rhs.args[0], type_constants=False),
                rhs.dtype)
        # `sympy.logic.boolalg` is the canonical module path (imported at the top of this file); the top-level
        # alias `sympy.boolalg` is not guaranteed to exist in every SymPy version
        elif isinstance(rhs, sp.logic.boolalg.BooleanFunction) or \
                type(rhs) in pystencils.integer_functions.__dict__.values():
            # all arguments are cast to one collated type (collation to float is forbidden)
            new_args = [self.process_expression(a, type_constants) for a in rhs.args]
            types_of_expressions = [get_type_of_expression(a) for a in new_args]
            arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True)
            new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
                        else cast_func(a, arg_type)
                        for a in new_args]
            return rhs.func(*new_args)
        elif isinstance(rhs, sp.Mul):
            # factors -1 and 1 are kept as they are (not typed)
            new_args = [
                self.process_expression(arg, type_constants)
                if arg not in (-1, 1) else arg for arg in rhs.args
            ]
            return rhs.func(*new_args) if new_args else rhs
        elif isinstance(rhs, sp.Indexed):
            return rhs
        else:
            if isinstance(rhs, sp.Pow):
                # don't process exponents -> they should remain integers
                return sp.Pow(
                    self.process_expression(rhs.args[0], type_constants),
                    rhs.args[1])
            else:
                new_args = [
                    self.process_expression(arg, type_constants)
                    for arg in rhs.args
                ]
                return rhs.func(*new_args) if new_args else rhs

    @property
    def fields_written(self):
        """Set of fields that have been written at least once."""
        return {k.field for k, v in self._field_writes.items() if v}

    def _process_lhs(self, lhs):
        """Records the write to ``lhs`` and returns it as a typed symbol (field accesses are returned as-is)."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
            return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
        else:
            return lhs

    def _update_accesses_lhs(self, lhs):
        """Records a write; raises ValueError if a pure symbol violates the SSA conditions."""
        if isinstance(lhs, AbstractField.AbstractAccess):
            fai = self.FieldAndIndex(lhs.field, lhs.index)
            self._field_writes[fai].add(lhs.offsets)
            # NOTE: writing the same field/index at several offsets is not rejected here (the corresponding
            # check is disabled). With the independence check enabled, any later read of such a field raises
            # a ValueError in `_update_accesses_rhs`.
        elif isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError(
                    "Assignments not in SSA form, multiple assignments to {}".
                    format(lhs.name))
            if lhs in self.scopes.free_parameters:
                raise ValueError(
                    "Symbol {} is written, after it has been read".format(
                        lhs.name))
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        """Records a read; raises ValueError if it violates the loop independence condition."""
        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
            writes = self._field_writes[self.FieldAndIndex(
                rhs.field, rhs.index)]
            # With several recorded write offsets at least one differs from the read offset, so such a read
            # is reported as a violation (previously this surfaced as a bare AssertionError).
            for write_offset in writes:
                if write_offset != rhs.offsets:
                    raise ValueError("Violation of loop independence condition. Field "
                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self.scopes.access_symbol(rhs)


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

Martin Bauer's avatar
Martin Bauer committed
926
927
    Additionally returns sets of all fields which are read/written

928
929
930
931
932
933
934
935
936
    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols
Martin Bauer's avatar
Martin Bauer committed
937
    """
938
    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):