transformations.py 56.3 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
import hashlib
import pickle
3
import warnings
Martin Bauer's avatar
Martin Bauer committed
4
from collections import OrderedDict, defaultdict, namedtuple
5
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
6
from types import MappingProxyType
Martin Bauer's avatar
Martin Bauer committed
7

8
import numpy as np
9
10
import sympy as sp
from sympy.logic.boolalg import Boolean
Martin Bauer's avatar
Martin Bauer committed
11
12

import pystencils.astnodes as ast
13
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
14
15
16
17
from pystencils.data_types import (
    PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
    get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
18
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
19
from pystencils.simp.assignment_collection import AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
20
from pystencils.slicing import normalize_slice
21
22


23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before, is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope
    - the outermost scope can never be popped

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # symbols that were accessed while no open scope defined them
        self.free_parameters = set()
        # stack of scopes; index 0 is the outermost scope, index -1 the current one
        self._defined = [set()]

    def access_symbol(self, symbol):
        """Record a read of ``symbol``; it becomes a free parameter if no open scope defines it."""
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define ``symbol`` in the current (innermost) scope."""
        self._defined[-1].add(symbol)

    def is_defined(self, symbol):
        """Return True if ``symbol`` is defined in any currently open scope."""
        return any(symbol in scope for scope in self._defined)

    def is_defined_locally(self, symbol):
        """Return True if ``symbol`` is defined in the current (innermost) scope."""
        return symbol in self._defined[-1]

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Close the innermost scope and discard its definitions.

        The check happens before removing anything, so an attempt to pop the
        outermost scope fails without corrupting the scope stack.
        """
        assert self.depth > 1, "cannot pop the outermost scope"
        self._defined.pop()

    @property
    def depth(self):
        """Number of currently open scopes (at least 1)."""
        return len(self._defined)


def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yield all descendants of ``node`` that are instances of ``node_type`` (depth first).

    The tree is walked through the ``args`` attribute of every node. Matching nodes are yielded
    and then descended into as well.

    Args:
        node: root of the (sub)tree; must provide ``args``
        node_type: type (or tuple of types) of the nodes to yield
        stop_type: optional type (or tuple of types); subtrees rooted at a non-matching node of
                   this type are pruned, i.e. neither yielded nor descended into, at every level
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # prune: the stop condition applies to the child being visited, not its parent
            continue

        # propagate stop_type so pruning also works below the top level
        yield from filtered_tree_iteration(arg, node_type, stop_type)


def generic_visit(term, visitor):
    """Rebuild ``term`` with ``visitor`` applied to each leaf expression.

    Containers are traversed recursively and rebuilt:

    - AssignmentCollection: main assignments and subexpressions are visited, then ``copy`` is used
    - list: every element is visited
    - Assignment: only the right-hand side is visited, the left-hand side is kept
    - sympy Matrix: visited element-wise via ``applyfunc``

    Anything else is handed to ``visitor`` and its return value is used instead.
    """
    def recurse(child):
        return generic_visit(child, visitor)

    if isinstance(term, AssignmentCollection):
        return term.copy(recurse(term.main_assignments), recurse(term.subexpressions))
    if isinstance(term, list):
        return list(map(recurse, term))
    if isinstance(term, Assignment):
        return Assignment(term.lhs, recurse(term.rhs))
    if isinstance(term, sp.Matrix):
        return term.applyfunc(recurse)
    return visitor(term)


102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def unify_shape_symbols(body, common_shape, fields):
    """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.

    When creating a kernel with variable array sizes, all passed arrays must have the same size.
    This is ensured when the kernel is called. Inside the kernel this means that only one symbol has to be used
    instead of one for each field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should also
    be represented by the same symbol.

    Args:
        body: ast node for the kernel part where the substitutions are made; modified in-place via ``body.subs``
        common_shape: shape of the field that was chosen
        fields: all fields whose shapes should be replaced by common_shape; fixed-shape fields are skipped

    Raises:
        AssertionError: if a field's number of spatial dimensions differs from ``len(common_shape)``
    """
    substitutions = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if field.has_fixed_shape:
            continue
        for common_shape_component, shape_component in zip(common_shape, field.spatial_shape):
            if shape_component != common_shape_component:
                substitutions[shape_component] = common_shape_component
    if substitutions:
        body.subs(substitutions)


def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
    ValueError is raised.

    Either all fields must have a fixed shape, in which case all shapes have to be identical, or all
    fields must have a variable shape. In the variable case the shapes are not compared. The shape whose
    first component has the smallest string representation is returned, so the choice is deterministic.

    Raises:
        ValueError: if ``field_set`` is empty, if fixed- and variable-shaped fields are mixed, or if
                    fixed-shaped fields differ in shape
    """
    if not field_set:
        raise ValueError("Cannot determine a common shape: no fields given")

    fixed_fields = [f for f in field_set if f.has_fixed_shape]
    variable_fields = [f for f in field_set if not f.has_fixed_shape]

    if fixed_fields and variable_fields:
        fixed_field_names = ",".join(f.name for f in fixed_fields)
        var_field_names = ",".join(f.name for f in variable_fields)
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(msg)

    shape_set = {f.spatial_shape for f in field_set}
    if not variable_fields and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    # min() picks the first minimal element, the same tie-break as sorted(...)[0]
    return min(shape_set, key=lambda shape: str(shape[0]))


def make_loop_over_domain(body,
                          iteration_slice=None,
                          ghost_layers=None,
                          loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents; shape symbols in it are unified in-place
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
             or a single int that is used for both sides of every coordinate.
             If None, the number of ghost layers is determined automatically and assumed to be equal for
             all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info
    """
    # find correct ordering by inspecting participating FieldAccesses
    field_accesses = {fa for fa in body.atoms(AbstractField.AbstractAccess) if not fa.is_absolute_access}

    # exclude accesses to buffers from the field set, because buffers are treated separately
    fields = {fa.field for fa in field_accesses if not FieldType.is_buffer(fa.field)}

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        required_ghost_layers = max(fa.required_ghost_layers for fa in field_accesses)
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    current_body = body
    # build the nest from the innermost loop outwards; each new loop wraps the previous body
    for loop_coordinate in reversed(loop_order):
        if iteration_slice is None:
            begin = ghost_layers[loop_coordinate][0]
            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([new_loop])
        else:
            slice_component = iteration_slice[loop_coordinate]
            if isinstance(slice_component, slice):
                sc = slice_component
                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                current_body = ast.Block([new_loop])
            else:
                # fixed coordinate: no loop is created, the loop counter symbol is assigned the given
                # value at the front of the current body instead
                assignment = ast.SympyAssignment(
                    ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                    sp.sympify(slice_component))
                current_body.insert_front(assignment)

    return current_body, ghost_layers


def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []

    for coordinate_id, coordinate_value in coordinates.items():
        offset += field.strides[coordinate_id] * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinate: the access's neighbor offset contributes as well
            access_offset = field_access.offsets[coordinate_id]
            offset += field.strides[coordinate_id] * access_offset
            if type(access_offset) is int:
                name += "_%d%d" % (coordinate_id, access_offset)
            else:
                list_to_hash.append(access_offset)
        else:
            # index coordinate: only the coordinate value itself is encoded in the name
            if type(coordinate_value) is int:
                name += "_%d%d" % (coordinate_id, coordinate_value)
            else:
                list_to_hash.append(coordinate_value)

    if list_to_hash:
        # non-integer parts cannot be spelled in an identifier; a (truncated) digest of them is
        # appended instead so that different expressions very likely yield different names
        name += hashlib.md5(pickle.dumps(list_to_hash)).hexdigest()[:16]

    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset


def parse_base_pointer_info(base_pointer_specification, loop_order,
                            spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop enclosed by the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of coordinate lists that can be passed to :func:`resolve_field_accesses`; all coordinates not
        mentioned in the specification are appended, sorted, as one additional last group

    Raises:
        ValueError: for a coordinate that does not exist, a coordinate specified twice, or an unknown
                    specification string

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    total_dimensions = spatial_dimensions + index_dimensions
    # loops ordered from inner to outer: [0] is the innermost, [-1] the outermost loop
    inner_to_outer = list(reversed(loop_order))
    specified_coordinates = set()
    result = []

    def add_new_element(group, elem):
        if not 0 <= elem < total_dimensions:
            raise ValueError("Coordinate %d does not exist" % (elem, ))
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem, ))
        group.append(elem)
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # Outer0 -> inner_to_outer[-1] (outermost), Outer1 -> [-2], ...
                    add_new_element(new_group, inner_to_outer[-1 - index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element, ))

        result.append(new_group)

    rest = set(range(total_dimensions)) - specified_coordinates
    if rest:
        result.append(sorted(rest))

    return result


def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))
        loops.reverse()  # innermost loop first
        parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        # all loops found must be exactly the chain of loops enclosing the innermost one
        assert len(loops) == len(parents_of_innermost_loop)
        assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))

        # NOTE(review): for numeric bounds this is a float quotient that is truncated by int() below;
        # if (stop - start) is not a multiple of step the iteration count is undercounted - confirm intent
        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loops]
        loop_counters = [loop.loop_counter_symbol for loop in loops]

    field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
    buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
    # each cell holds one entry per buffer access, so every counter is scaled by their number
    loop_counters = [v * len(buffer_accesses) for v in loop_counters]

    base_buffer_index = loop_counters[0]
    stride = 1
    for idx, var in enumerate(loop_counters[1:]):
        cur_stride = loop_iterations[idx]
        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
        base_buffer_index += var * stride
    return base_buffer_index


def resolve_buffer_accesses(ast_node,
                            base_buffer_index,
                            read_only_field_names=frozenset()):
    """Replaces accesses to buffer fields by :class:`ast.ResolvedFieldAccess` nodes, in-place.

    Each access to a buffer field (``FieldType.is_buffer``) becomes an access through the buffer's pointer
    symbol at ``base_buffer_index`` plus the access's index coordinate, if it has one. Accesses to non-buffer
    fields are left untouched.

    Args:
        ast_node: AST root; both sides of every SympyAssignment and the condition of every Conditional
                  below it are rewritten in-place
        base_buffer_index: linearized buffer index, e.g. from :func:`get_base_buffer_index`
        read_only_field_names: names of buffers whose pointer symbols are created as const

    Raises:
        RuntimeError: if a buffer access has more than one index dimension
    """
    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field,
                                             field_access.offsets, field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # keep the expression structure unchanged while rebuilding it
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            # buffer accesses inside the condition itself are resolved too (as in resolve_field_accesses)
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, sub_ast.parent, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)


def resolve_field_accesses(ast_node,
                           read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in-place. Both sides of every SympyAssignment and the condition of every
    Conditional are rewritten. Assignments that define intermediate base pointers are inserted into the
    enclosing block, in front of the statement that needs them, unless the block already defines them.

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only (their pointer symbols are const)
        field_to_base_pointer_info: mapping from field name to a list of coordinate groups indicating which
                                    intermediate base pointers should be created;
                                    for details see :func:`parse_base_pointer_info`
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        None - ``ast_node`` is transformed in-place
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets / indices may themselves contain field accesses - resolve those first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets, new_indices,
                                            field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                # default: a single group with all coordinates, i.e. no intermediate pointers
                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]

            field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field_access.is_absolute_access:
                            coordinates[e] = 0
                        elif field.name in field_to_fixed_coordinates:
                            coordinates[e] = field_to_fixed_coordinates[field.name][e]
                        else:
                            coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # intermediate pointers are built from the last group backwards, each one offsetting the previous
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group's offset is written directly into the resolved access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            # keep the expression structure unchanged while rebuilding it
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)


def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    Every block below ``ast_node`` is processed, outer blocks before the blocks nested in them. Each
    ``SympyAssignment`` child is moved to the outermost block it may be hoisted to. Depending on what that
    target block already contains, the assignment is

    - inserted there, if the target has no assignment with the same lhs or the same rhs,
    - dropped, if the target has an assignment with the same rhs; the local lhs is then substituted by the
      existing lhs in the remaining children of the current block,
    - dropped, if the target already contains the identical assignment (same lhs and rhs),
    - inserted under a fresh symbol, if the target assigns a different rhs to the same lhs; the local lhs is
      then substituted by the fresh symbol in the remaining children of the current block.

    The AST is modified in place.
    """
    def find_block_to_move_to(node):
        """Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to.

        The walk stops at the first ``Conditional`` ancestor, or at the first ancestor whose
        ``symbols_defined`` intersect the symbols read but not defined by the assignment.

        :param node: SympyAssignment inside a Block
        :return: (block_to_insert_to, child_of_block_to_insert_before)
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                last_block = element
                last_block_child = prev_element

            # conditionals are never crossed: hoisting would make the assignment unconditional
            if isinstance(element, ast.Conditional):
                break
            critical_symbols = element.symbols_defined
            if node.undefined_symbols.intersection(critical_symbols):
                break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        """Returns the first SympyAssignment in ``target_block`` with the same rhs (``rhs_or_lhs=True``)
        or the same lhs (``rhs_or_lhs=False``) as ``assignment``, or None if there is none."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        """Appends all Blocks below ``node`` (including ``node`` itself) to ``result_list`` in pre-order."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        # all children are detached; the ones that are not moved elsewhere are re-appended below
        children = block.take_child_nodes()
        # Maps a local lhs symbol of this block to the symbol that replaces it, because the assignment
        # was moved (or found) in a parent block under a different lhs symbol.
        substitute_variables = {}
        for child in children:
            # Before traversing the next child, all symbols are substituted first.
            child.subs(substitute_variables)

            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:  # movement not possible
                target.append(child)
                continue

            # child is a SympyAssignment here, so the lhs lookup can be done unconditionally
            exists_already = check_if_assignment_already_in_block(child, target, False)

            if not exists_already:
                rhs_identical = check_if_assignment_already_in_block(child, target, True)
                if rhs_identical:
                    # there is already an assignment out there with the same rhs
                    # -> replace all lhs symbols in this block with the lhs of the outer assignment
                    # -> remove the local assignment (do not re-append child to the former block)
                    substitute_variables[child.lhs] = rhs_identical.lhs
                else:
                    target.insert_before(child, child_to_insert_before)
            elif exists_already.rhs == child.rhs:
                # the identical assignment is already present in the target block -> drop the local one
                pass
            else:
                # this variable already exists in outer block, but with different rhs
                # -> symbol has to be renamed
                assert isinstance(child.lhs, TypedSymbol)
                new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
                substitute_variables[child.lhs] = new_symbol
701
702


Martin Bauer's avatar
Martin Bauer committed
703
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """Splits the innermost loop into several loops to reduce the number of simultaneous load/store streams.

    For every group in ``symbol_groups`` a copy of the innermost loop is created. The copy contains only the
    assignments needed to compute the group's symbols, including their transitive dependencies inside the loop
    body. Non-field symbols of a group are written to temporary arrays indexed by the inner loop counter.
    Right-hand sides read these arrays instead of the symbols. Allocation and deallocation of the arrays are
    inserted at the front and end of the node that contains the outermost loop.

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups
                       and which no symbol in a symbol group depends on, are not updated!
    """
    loops = ast_node.atoms(ast.LoopOverCoordinate)

    innermost_loops = [loop for loop in loops if loop.is_innermost_loop]
    assert len(innermost_loops) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = innermost_loops[0]
    assert type(inner_loop.body) is ast.Block

    outermost_loops = [loop for loop in loops if loop.is_outermost_loop]
    assert len(outermost_loops) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outermost_loops[0]

    # symbol -> indexed access into the temporary array that stores it
    temporary_arrays = OrderedDict()
    assignment_by_lhs = OrderedDict((a.lhs, a) for a in inner_loop.body.args)

    def to_temporary_array_access(symbol):
        # pointer of the symbol's type, indexed with the inner loop counter
        assert type(symbol) is TypedSymbol
        array_pointer = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return sp.IndexedBase(array_pointer, shape=(1, ))[inner_loop.loop_counter_symbol]

    def collect_dependencies(group):
        # transitive closure of the group's symbols over the loop body assignments;
        # field accesses and symbols already stored in a temporary array stop the search
        pending = list(group)
        resolved = set()
        while pending:
            current = pending.pop()
            if current in resolved:
                continue
            if current in assignment_by_lhs:  # without an assignment in the loop body it is independent already
                pending.extend(
                    dependency for dependency in assignment_by_lhs[current].rhs.atoms(sp.Symbol)
                    if not isinstance(dependency, AbstractField.AbstractAccess)
                    and dependency not in temporary_arrays)
            resolved.add(current)
        return resolved

    loop_bodies = []
    for group in symbol_groups:
        required_symbols = collect_dependencies(group)

        for symbol in group:
            if not isinstance(symbol, AbstractField.AbstractAccess):
                temporary_arrays[symbol] = to_temporary_array_access(symbol)

        body = []
        for assignment in inner_loop.body.args:
            if assignment.lhs not in required_symbols:
                continue
            rhs = assignment.rhs.subs(temporary_arrays.items())
            lhs = assignment.lhs
            if not isinstance(lhs, AbstractField.AbstractAccess) and lhs in group:
                lhs = to_temporary_array_access(lhs)
            body.append(ast.SympyAssignment(lhs, rhs))
        loop_bodies.append(body)

    split_loops = [inner_loop.new_loop_with_different_body(ast.Block(body)) for body in loop_bodies]
    inner_loop.parent.replace(inner_loop, ast.Block(split_loops))

    for array_symbol in temporary_arrays:
        array_pointer = TypedSymbol(array_symbol.name, PointerType(array_symbol.dtype))
        allocation = ast.TemporaryMemoryAllocation(array_pointer, inner_loop.stop, inner_loop.start)
        deallocation = ast.TemporaryMemoryFree(allocation)
        outer_loop.parent.insert_front(allocation)
        outer_loop.parent.append(deallocation)
783
784


Martin Bauer's avatar
Martin Bauer committed
785
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into at most len(cutting_points)+1 pieces. The pieces range from old_begin to
    cutting_points[0], cutting_points[0] to cutting_points[1], ..., and cutting_points[-1] to old_end.

    - Pieces whose length is exactly zero are dropped.
    - Pieces of length one are not emitted as a loop. A copy of the loop body, with the loop counter replaced
      by the piece's start value, is used instead.
    - The cutting points are used in the given order; they are not sorted.

    Modifies the ast in place: ``loop_node`` is replaced in its parent by the new nodes.

    Args:
        loop_node: the loop to cut; only loops with a step of 1 are supported
        cutting_points: sequence of loop counter values at which the loop is cut

    Returns:
        ``ast.Block`` holding the new loop nodes (and the inlined single-iteration bodies)

    Raises:
        NotImplementedError: if the step of ``loop_node`` is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = ast.Block([])
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        if new_end - new_start == 1:
            # single iteration: inline a copy of the body with the loop counter fixed to new_start
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        elif new_end - new_start == 0:
            pass  # empty piece -> nothing to emit
        else:
            new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                                              new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
    return new_loops
816
817


818
819
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals whose condition simplifies to a constant true/false.

    Every condition below ``node`` is replaced by its ``sp.simplify``-ed form.

    - A conditional that is always true is replaced by its true block.
    - A conditional that is always false is replaced by its false block, or removed if it has none.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default. If islpy is missing, a warning is emitted instead.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        condition = conditional.condition_expr

        if condition == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
            continue

        if condition == sp.false:
            remaining = [conditional.false_block] if conditional.false_block else []
            conditional.parent.replace(conditional, remaining)
            continue

        if not loop_counter_simplification:
            continue

        try:
            # noinspection PyUnresolvedReferences
            from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            simplify_loop_counter_dependent_conditional(conditional)
        except ImportError:
            warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
847
848
849


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: recursively flattens blocks that have at most one child.

    A Block with zero or one child is replaced by its children, but only when its parent is a Block as well.
    An empty block therefore disappears; a single-child block is replaced by that child.
    SympyAssignments are leaves and are not descended into.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    is_block = isinstance(node, ast.Block)
    # children of a block may replace themselves inside it while being cleaned up -> iterate over a snapshot
    children = list(node.args) if is_block else node.args
    for child in children:
        cleanup_blocks(child)

    if is_block and len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
862
863


864
865
866
867
868
869
870
871
872
873
874
875
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid, and adds types to its symbols.

    The following conditions are verified:

    - SSA form for pure symbols:
        - every pure symbol may occur only once as left-hand side of an assignment
        - every pure symbol that is read must not be written to later
    - Independence / parallelization condition (only if ``check_independence_condition`` is set):
        - a field that is written may only be read at exactly the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition):
        # maps symbol names (and the special key '_constant') to types
        self._type_for_symbol = type_for_symbol
        self.check_independence_condition = check_independence_condition
        self.scopes = NestedScopes()
        # FieldAndIndex -> set of offsets at which that field component is written
        self._field_writes = defaultdict(set)
        self.fields_read = set()

    def process_assignment(self, assignment):
        """Returns a typed copy of ``assignment`` after checking the constraints for it."""
        # the rhs has to be handled before the lhs, otherwise e.g. ``a = a + 1`` would not be detected
        typed_rhs = self.process_expression(assignment.rhs)
        typed_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(typed_lhs, typed_rhs)

    def process_expression(self, rhs, type_constants=True):
        """Records all reads in ``rhs`` and returns it with untyped symbols replaced by TypedSymbols.

        If ``type_constants`` is set, numeric constants are wrapped into ``cast_func`` with their type.
        """
        self._update_accesses_rhs(rhs)

        if isinstance(rhs, AbstractField.AbstractAccess):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        if isinstance(rhs, TypedSymbol):
            return rhs
        if isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
        if type_constants:
            if isinstance(rhs, np.generic):
                return cast_func(rhs, create_type(rhs.dtype))
            if isinstance(rhs, sp.Number):
                return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        if isinstance(rhs, sp.Mul):
            # factors +1 / -1 are left untouched
            factors = [arg if arg in (-1, 1) else self.process_expression(arg, type_constants) for arg in rhs.args]
            return rhs.func(*factors) if factors else rhs
        if isinstance(rhs, sp.Indexed):
            return rhs
        if isinstance(rhs, cast_func):
            # constants inside an explicit cast are not wrapped into another cast
            return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype)
        if isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])

        typed_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
        return rhs.func(*typed_args) if typed_args else rhs

    @property
    def fields_written(self):
        """Set of fields that are written at least once."""
        return {key.field for key, offsets in self._field_writes.items() if offsets}

    def _process_lhs(self, lhs):
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
            return lhs
        return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        if isinstance(lhs, AbstractField.AbstractAccess):
            written_offsets = self._field_writes[self.FieldAndIndex(lhs.field, lhs.index)]
            written_offsets.add(lhs.offsets)
            if len(written_offsets) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
            return

        if isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
            if lhs in self.scopes.free_parameters:
                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        checks_field_read = isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition
        if not checks_field_read:
            # note: with the independence check disabled, field accesses are treated as plain symbol reads
            if isinstance(rhs, sp.Symbol):
                self.scopes.access_symbol(rhs)
            return

        written_offsets = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
        for write_offset in written_offsets:
            assert len(written_offsets) == 1
            if write_offset != rhs.offsets:
                raise ValueError("Violation of loop independence condition. Field "
                                 "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
        self.fields_read.add(rhs.field)
978
979
980
981
982


def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

Martin Bauer's avatar
Martin Bauer committed
983
984
    Additionally returns sets of all fields which are read/written

985
986
987
988
989
990
991
992
993
    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
         list of equations where symbols have been replaced by typed symbols
Martin Bauer's avatar
Martin Bauer committed
994
    """
995
996