transformations.py 58.8 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
import hashlib
import pickle
3
import warnings
Martin Bauer's avatar
Martin Bauer committed
4
from collections import OrderedDict, defaultdict, namedtuple
5
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
6
from types import MappingProxyType
Martin Bauer's avatar
Martin Bauer committed
7

8
import numpy as np
9
import sympy as sp
10
from sympy.core.numbers import ImaginaryUnit
11
from sympy.logic.boolalg import Boolean
Martin Bauer's avatar
Martin Bauer committed
12
13

import pystencils.astnodes as ast
14
import pystencils.integer_functions
15
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
16
from pystencils.data_types import (
17
18
    PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
    get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
Martin Bauer's avatar
Martin Bauer committed
19
from pystencils.field import AbstractField, Field, FieldType
20
from pystencils.kernelparameters import FieldPointerSymbol
Martin Bauer's avatar
Martin Bauer committed
21
from pystencils.simp.assignment_collection import AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
22
from pystencils.slicing import normalize_slice
23
24


25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before, is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope; the outermost scope can never be removed

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        # symbols that were accessed while not being defined in any open scope
        self.free_parameters = set()
        # stack of scopes, innermost scope last; index 0 is the outermost scope
        self._defined = [set()]

    def access_symbol(self, symbol):
        """Record an access to symbol; it becomes a free parameter if no open scope defines it."""
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        """Define symbol in the innermost scope."""
        self._defined[-1].add(symbol)

    def is_defined(self, symbol):
        """Return True if symbol is defined in any currently open scope."""
        return any(symbol in scope for scope in self._defined)

    def is_defined_locally(self, symbol):
        """Return True if symbol is defined in the innermost scope."""
        return symbol in self._defined[-1]

    def push(self):
        """Open a new, empty innermost scope."""
        self._defined.append(set())

    def pop(self):
        """Close the innermost scope, discarding the symbols defined in it.

        Raises:
            AssertionError: if only the outermost scope is left. The check happens before popping,
                so a failed call leaves the scope stack unchanged.
        """
        assert self.depth > 1, "cannot pop the outermost scope"
        self._defined.pop()

    @property
    def depth(self):
        """Number of currently open scopes (at least 1)."""
        return len(self._defined)


Martin Bauer's avatar
Martin Bauer committed
79
def filtered_tree_iteration(node, node_type, stop_type=None):
    """Yield all descendants of node that are instances of node_type, depth-first in pre-order.

    Matching nodes are yielded and their children are searched as well, so nested matches are found
    (outer ones first). Descendants that are instances of stop_type (if given) and do not match
    node_type are neither yielded nor descended into, at any depth of the tree.

    Args:
        node: root object whose ``args`` are traversed; the root itself is never yielded
        node_type: type (or tuple of types) of the nodes to yield
        stop_type: optional type (or tuple of types) at which the traversal is pruned
    """
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # prune: do not descend into sub-trees rooted at a stop node
            continue
        # propagate stop_type so that pruning also applies further down the tree
        yield from filtered_tree_iteration(arg, node_type, stop_type)
87
88


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def generic_visit(term, visitor):
    """Apply ``visitor`` to every expression leaf of ``term``, rebuilding the surrounding containers.

    Assignment collections, lists, assignments (right-hand side only) and sympy matrices are traversed
    recursively. Every other object is passed to ``visitor``, and its return value replaces it.
    """
    if isinstance(term, AssignmentCollection):
        # main assignments are visited before the subexpressions
        visited_main = generic_visit(term.main_assignments, visitor)
        visited_subexpressions = generic_visit(term.subexpressions, visitor)
        return term.copy(visited_main, visited_subexpressions)

    if isinstance(term, list):
        return [generic_visit(item, visitor) for item in term]

    if isinstance(term, Assignment):
        # the left-hand side is kept as it is
        return Assignment(term.lhs, generic_visit(term.rhs, visitor))

    if isinstance(term, sp.Matrix):
        return term.applyfunc(lambda entry: generic_visit(entry, visitor))

    return visitor(term)


104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def unify_shape_symbols(body, common_shape, fields):
    """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.

    When creating a kernel with variable array sizes, all passed arrays must have the same size.
    This is ensured when the kernel is called. Inside the kernel this means that only one symbol has to be used
    instead of one for each field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should
    also be represented by the same symbol.

    Args:
        body: ast node, for the kernel part where substitutions are made; it is modified in-place via
              ``body.subs`` (the return value of that call is discarded)
        common_shape: shape of the field that was chosen
        fields: all fields whose shapes should be replaced by common_shape. Fields with a fixed shape are
                only checked for matching dimensionality.
    """
    replacements = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if field.has_fixed_shape:
            continue
        for own_size, shared_size in zip(field.spatial_shape, common_shape):
            if own_size != shared_size:
                replacements[own_size] = shared_size

    # avoid touching the AST when there is nothing to replace
    if replacements:
        body.subs(replacements)


Martin Bauer's avatar
Martin Bauer committed
128
def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists.

    Either all fields have a fixed shape, which then has to be identical, or all fields have a
    variable shape. In the variable case the shape whose first entry has the smallest string
    representation is returned.

    Raises:
        ValueError: if fixed- and variable-shaped fields are mixed, or if all fields are fixed-shaped
            (this includes an empty field_set) but there is not exactly one distinct shape
    """
    fixed_shaped = [f for f in field_set if f.has_fixed_shape]
    only_fixed = len(fixed_shaped) == len(field_set)

    if fixed_shaped and not only_fixed:
        var_field_names = ",".join(f.name for f in field_set if not f.has_fixed_shape)
        fixed_field_names = ",".join(f.name for f in fixed_shaped)
        message = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        message += "Variable shaped: %s \nFixed shaped:    %s" % (var_field_names, fixed_field_names)
        raise ValueError(message)

    shape_set = {f.spatial_shape for f in field_set}
    if only_fixed and len(shape_set) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    return min(shape_set, key=lambda candidate: str(candidate[0]))


152
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents; its shape symbols are unified in-place
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
             or a single int that is used for the lower and upper bound of every coordinate.
             If None, the number of ghost layers is the maximum required by any relative field access
             and is assumed to be equal for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info
    """
    # find correct ordering by inspecting participating FieldAccesses;
    # absolute accesses do not take part in the loop setup
    relative_accesses = {access for access in body.atoms(AbstractField.AbstractAccess)
                         if not access.is_absolute_access}

    # exclude accesses to buffers, because buffers are treated separately
    fields = {access.field for access in relative_accesses if not FieldType.is_buffer(access.field)}

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        widest = max([access.required_ghost_layers for access in relative_accesses])
        ghost_layers = [(widest, widest)] * len(loop_order)
    elif isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    current_body = body
    # wrap the body starting with the innermost coordinate
    for coordinate in reversed(loop_order):
        if iteration_slice is None:
            begin = ghost_layers[coordinate][0]
            end = shape[coordinate] - ghost_layers[coordinate][1]
            current_body = ast.Block([ast.LoopOverCoordinate(current_body, coordinate, begin, end, 1)])
            continue

        component = iteration_slice[coordinate]
        if type(component) is slice:
            loop = ast.LoopOverCoordinate(current_body, coordinate, component.start, component.stop, component.step)
            current_body = ast.Block([loop])
        else:
            # non-slice component: no loop is created, the loop counter is assigned the given value instead
            counter = ast.LoopOverCoordinate.get_loop_counter_symbol(coordinate)
            current_body.insert_front(ast.SympyAssignment(counter, sp.sympify(component)))

    return current_body, ghost_layers
208
209


Martin Bauer's avatar
Martin Bauer committed
210
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    The name suffix is ``_<coordinate><value>`` for every integer-valued part (the access offset for
    spatial coordinates, the coordinate value for index coordinates), with ``-`` replaced by ``m``.
    Non-integer parts are pickled and encoded by a truncated md5 hex digest.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to its value, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    suffix_parts = []
    values_to_hash = []

    for coordinate_id, coordinate_value in coordinates.items():
        stride = field.strides[coordinate_id]
        offset += stride * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            # spatial coordinates additionally carry the (neighbor) offset of the access
            access_offset = field_access.offsets[coordinate_id]
            offset += stride * access_offset
            encoded_value = access_offset
        else:
            encoded_value = coordinate_value

        if type(encoded_value) is int:
            suffix_parts.append("_%d%d" % (coordinate_id, encoded_value))
        else:
            values_to_hash.append(encoded_value)

    suffix = "".join(suffix_parts)
    if values_to_hash:
        suffix += hashlib.md5(pickle.dumps(values_to_hash)).hexdigest()[:16]
    suffix = suffix.replace("-", 'm')

    return TypedSymbol(previous_ptr.name + suffix, previous_ptr.dtype), offset
260
261


262
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop coordinate,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of lists that can be passed to :func:`resolve_field_accesses`. Coordinates that are not mentioned
        in the specification are appended, in ascending order, as a last group.

    Raises:
        ValueError: for unknown specifications, coordinates that do not exist, or coordinates specified twice

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[2], [3], [0, 1]]
    """
    total_dimensions = spatial_dimensions + index_dimensions
    outer_to_inner = list(loop_order)
    inner_to_outer = outer_to_inner[::-1]

    result = []
    specified_coordinates = set()

    def add_new_element(group, elem):
        if elem >= total_dimensions:
            raise ValueError("Coordinate %d does not exist" % (elem,))
        group.append(elem)
        if elem in specified_coordinates:
            raise ValueError("Coordinate %d specified two times" % (elem,))
        specified_coordinates.add(elem)

    for spec_group in base_pointer_specification:
        new_group = []
        for element in spec_group:
            if type(element) is int:
                add_new_element(new_group, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(new_group, inner_to_outer[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # counted from the outermost loop: spatialOuter0 is loop_order[0]
                    add_new_element(new_group, outer_to_inner[index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(new_group, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(new_group, spatial_dimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(new_group)

    rest = set(range(total_dimensions)) - specified_coordinates
    if rest:
        # sorted for a deterministic result independent of set iteration order
        result.append(sorted(rest))

    return result


Martin Bauer's avatar
Martin Bauer committed
338
339
340
341
342
343
344
345
346
347
348
349
350
def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.

    With loop counters ``c_0, c_1, ...`` (inner to outer), each multiplied by the number of distinct buffer
    accesses in ``ast_node``, and iteration counts ``n_0, n_1, ...`` the result is
    ``c_0 + c_1 * n_0 + c_2 * n_0 * n_1 + ...``.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        # the pre-order traversal lists outer loops first; flip it so that index 0 is the innermost loop
        loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment))[::-1]
        enclosing_loops = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        # all found loops have to form one nest around the innermost loop
        assert len(loops) == len(enclosing_loops)
        assert all(found is enclosing for found, enclosing in zip(loops, enclosing_loops))

        loop_iterations = [(loop.stop - loop.start) / loop.step for loop in loops]
        loop_counters = [loop.loop_counter_symbol for loop in loops]

    buffer_accesses = {access for access in ast_node.atoms(AbstractField.AbstractAccess)
                       if FieldType.is_buffer(access.field)}
    # presumably one buffer slot per distinct buffer access and cell -- TODO confirm
    scaled_counters = [counter * len(buffer_accesses) for counter in loop_counters]

    base_buffer_index = scaled_counters[0]
    stride = 1
    for position, counter in enumerate(scaled_counters[1:]):
        iterations = loop_iterations[position]
        # iteration counts may be floats (e.g. from the true division above); use them as ints
        stride *= int(iterations) if isinstance(iterations, float) else iterations
        base_buffer_index += counter * stride
    return base_buffer_index


373
374
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=frozenset()):
    """Replaces accesses to buffer fields by linearized buffer accesses.

    Every access to a buffer field becomes an :class:`ast.ResolvedFieldAccess` on the buffer's pointer
    symbol with index ``base_buffer_index``, plus the single index coordinate if the access has one.
    Accesses to non-buffer fields are left untouched.

    Args:
        ast_node: AST root; the expressions of all nested ``SympyAssignment`` nodes are visited
        base_buffer_index: linearized index, e.g. as returned by :func:`get_base_buffer_index`
        read_only_field_names: names of buffers whose pointer symbol is created as const

    Returns:
        None - the AST is modified in-place

    Raises:
        RuntimeError: if a buffer is accessed with more than one index coordinate
    """

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            # objects without sympy-style ``args`` are leaves and are returned unchanged
            if hasattr(expr, 'args'):
                new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            else:
                new_args = []
            # keep the expression structure as it is, do not let sympy re-evaluate it
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, (bool, int, float)):
            # plain Python scalars have no ``args`` to descend into
            return
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
417

418

419
def resolve_field_accesses(ast_node, read_only_field_names=frozenset(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Every field access becomes an :class:`ast.ResolvedFieldAccess` based on the field's pointer symbol.
    Depending on ``field_to_base_pointer_info``, intermediate pointers are computed in new
    ``SympyAssignment`` nodes. These are inserted before the statement containing the access, unless
    the enclosing block already defines them.

    Args:
        ast_node: the AST root, modified in-place
        read_only_field_names: set of field names which are considered read-only; their pointer symbols are const
        field_to_base_pointer_info: mapping of field name to a list of coordinate groups (see
                                    :func:`parse_base_pointer_info`). The first group is used for the final
                                    offset; every further group creates an intermediate base pointer.
                                    Fields without an entry use a single group containing all coordinates.
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns
        None - the AST is transformed in-place
    """
    # work on sorted, ordered copies of the (possibly read-only) mappings
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                # offsets / indices that themselves contain field accesses are resolved first
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [
                    list(
                        range(field.index_dimensions + field.spatial_dimensions))
                ]

            field_ptr = FieldPointerSymbol(
                field.name,
                field.dtype,
                const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                """Map each coordinate id of the group to the value it is indexed with."""
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        # spatial coordinates are scaled by the item size of the field's data type
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            # struct fields: the single index selects a struct member by name
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            if isinstance(accessed_field_name, sp.Symbol):
                                accessed_field_name = accessed_field_name.name
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            # intermediate pointers, built from the last group towards the second one
            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            # the first group forms the offset written directly into the access
            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                accessed_field_name = field_access.index[0]
                if isinstance(accessed_field_name, sp.Symbol):
                    accessed_field_name = accessed_field_name.name
                new_type = field_access.field.dtype.get_element_type(accessed_field_name)
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            if hasattr(expr, 'args'):
                new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            else:
                new_args = []
            # keep the expression structure as it is, do not let sympy re-evaluate it
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            # intermediate pointers needed by the condition are inserted before the conditional itself
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            # plain Python scalars have no ``args`` to descend into
            if isinstance(sub_ast, (bool, int, float)):
                return
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)
548
549


Martin Bauer's avatar
Martin Bauer committed
550
def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`

    Every :class:`ast.Block` below ``ast_node`` is processed. Each assignment in it is moved up to the block
    chosen by ``find_block_to_move_to``. If the target block already holds an assignment with the same lhs,
    or one with the same rhs, the assignments are merged (or the moved one is renamed) instead of creating a
    duplicate. The AST is modified in place; nothing is returned.
    """
    def find_block_to_move_to(node):
        """Find the block an assignment can be moved to.

        Walks up the parents of ``node``. Every Block passed is remembered as the current candidate. The walk
        stops at the first Conditional, or at the first ancestor whose ``symbols_defined`` intersects the
        symbols read by ``node`` (``node.undefined_symbols``); the last remembered Block is returned.

        :param node: SympyAssignment inside a Block
        :return: (block_to_insert_into, child_of_that_block_to_insert_before)
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                last_block = element
                last_block_child = prev_element

            # assignments are never moved out of a conditional
            # (presumably because they may only be valid when the condition holds)
            if isinstance(element, ast.Conditional):
                break
            # an ancestor defines a symbol this assignment reads -> cannot move further up
            if node.undefined_symbols.intersection(element.symbols_defined):
                break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        """Return the first SympyAssignment directly in ``target_block`` whose rhs (``rhs_or_lhs=True``)
        or lhs (``rhs_or_lhs=False``) equals that of ``assignment``, or None if there is none."""
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        """Append ``node`` (if it is a Block) and all Blocks below it to ``result_list``, parents first."""
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        # the block is emptied and re-filled with the children that are not moved elsewhere
        children = block.take_child_nodes()
        # Every time a symbol can be replaced in the current block because the assignment
        # was found in a parent block, but with a different lhs symbol (same rhs)
        # the outer symbol is inserted here as key.
        substitute_variables = {}
        for child in children:
            # Before traversing the next child, all symbols are substituted first.
            child.subs(substitute_variables)

            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            exists_already = check_if_assignment_already_in_block(child, target, False)
            if not exists_already:
                rhs_identical = check_if_assignment_already_in_block(child, target, True)
                if rhs_identical:
                    # there is already an assignment out there with the same rhs
                    # -> replace all lhs symbols in this block with the lhs of the outer assignment
                    # -> remove the local assignment (do not re-append child to the former block)
                    substitute_variables[child.lhs] = rhs_identical.lhs
                else:
                    target.insert_before(child, child_to_insert_before)
            elif exists_already.rhs == child.rhs:
                # identical assignment already in the target block -> child is dropped; only make sure
                # the existing assignment comes before the node child would have been inserted before
                if target.args.index(exists_already) > target.args.index(child_to_insert_before):
                    assert target.args.count(exists_already) == 1
                    assert target.args.count(child_to_insert_before) == 1
                    target.args.remove(exists_already)
                    target.insert_before(exists_already, child_to_insert_before)
            else:
                # this variable already exists in outer block, but with different rhs
                # -> symbol has to be renamed
                assert isinstance(child.lhs, TypedSymbol)
                new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
                                     child_to_insert_before)
                substitute_variables[child.lhs] = new_symbol
646
647


Martin Bauer's avatar
Martin Bauer committed
648
def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the symbol
                       groups and which no symbol in a symbol group depends on, are not updated!

    Plain (non field-access) symbols of a group are stored in a temporary array indexed by the inner loop
    counter, so that later groups read them from there instead of recomputing them. The allocation of these
    arrays is inserted at the front of, and the matching free appended to, the block containing the outer loop.
    The AST is modified in place.

    NOTE(review): every node of the inner loop body is expected to have an ``lhs`` (i.e. be an assignment);
    other nodes make the group construction fail with an AttributeError.
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    inner_loops = [loop for loop in all_loops if loop.is_innermost_loop]
    assert len(inner_loops) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loops[0]
    assert type(inner_loop.body) is ast.Block
    outer_loops = [loop for loop in all_loops if loop.is_outermost_loop]
    assert len(outer_loops) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loops[0]

    # symbol -> access ``tmp_array[loop_counter]`` that stores its value
    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs'))

    assignment_groups = []
    for symbol_group in symbol_groups:
        # the group is traversed several times below -> materialize it once (iterators would be exhausted)
        symbol_group = list(symbol_group)

        # get all dependent symbols; symbols already stored in a temporary array by a previous group
        # are read from that array and therefore do not have to be recomputed
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
                            new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        for symbol in symbol_group:
            if not isinstance(symbol, AbstractField.AbstractAccess):
                assert type(symbol) is TypedSymbol
                new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                symbols_with_temporary_array[symbol] = sp.IndexedBase(
                    new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]

        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs in symbols_resolved:
                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
                if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
                    assert type(assignment.lhs) is TypedSymbol
                    # registered in the loop above -> write directly into its temporary array
                    new_lhs = symbols_with_temporary_array[assignment.lhs]
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [
        inner_loop.new_loop_with_different_body(ast.Block(group))
        for group in assignment_groups
    ]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)
720
721


Martin Bauer's avatar
Martin Bauer committed
722
def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    The iteration range of ``loop_node`` is split into the consecutive pieces
    ``start .. cutting_points[0]``, ``cutting_points[0] .. cutting_points[1]``, ...,
    ``cutting_points[-1] .. stop`` (end exclusive), i.e. into at most ``len(cutting_points) + 1`` pieces.
    Pieces of length zero are dropped. A piece of length one is not emitted as a loop but as a copy of the
    loop body in which the loop counter is substituted by that single index.

    Modifies the ast in place: ``loop_node`` is replaced in its parent by the new nodes.

    Args:
        loop_node: loop to cut; its step must be 1
        cutting_points: iterable of cut positions (presumably ascending and within the loop range;
                        this is not checked)

    Returns:
        :class:`ast.Block` containing the new loop / body nodes

    Raises:
        NotImplementedError: if the step of ``loop_node`` is not 1
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = ast.Block([])
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        if new_end - new_start == 1:
            # single iteration: inline the body with the counter fixed to that index
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        elif new_end - new_start != 0:
            new_loop = ast.LoopOverCoordinate(
                deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        # empty pieces produce no node
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
    return new_loops
753
754


755
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    The condition of every conditional below ``node`` is replaced by its ``sp.simplify``-ed form. A conditional
    whose condition becomes ``true`` is replaced by its true block; one whose condition becomes ``false`` is
    replaced by its false block (or removed if it has none). The AST is modified in place.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes from
                                     x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default. If it is not available, a warning is issued and the step is skipped.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        if conditional.condition_expr == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
        elif conditional.condition_expr == sp.false:
            conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
        elif loop_counter_simplification:
            # only the import is guarded: an ImportError raised by the simplification itself
            # must not be misreported as a missing ISLpy installation
            try:
                # noinspection PyUnresolvedReferences
                from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
            except ImportError:
                warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
            else:
                simplify_loop_counter_dependent_conditional(conditional)
779
780
781


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: flattens redundant nested blocks below ``node`` in place.

    The tree is walked recursively; assignments are leaves and are not descended into. After its children
    are handled, a :class:`ast.Block` that has at most one child and whose parent is also a Block is replaced
    in that parent by its children. Empty nested blocks therefore vanish and single-child nested blocks are
    unwrapped. Blocks whose parent is not a Block are kept.
    """
    if isinstance(node, ast.SympyAssignment):
        return

    is_block = isinstance(node, ast.Block)
    # a nested block may replace itself by its children inside ``node`` while we recurse,
    # so the children of a Block are iterated from a snapshot
    children = list(node.args) if is_block else node.args
    for child in children:
        cleanup_blocks(child)

    if is_block and len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)
794
795


796
797
798
799
800
801
802
803
804
805
806
807
class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Test the following conditions:

    - SSA Form for pure symbols:
        -  Every pure symbol may occur only once as left-hand-side of an assignment
        -  Every pure symbol that is read, may not be written to later
    - Independence / Parallelization condition:
        - a field that is written may only be read at exact the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)
Martin Bauer's avatar
Martin Bauer committed
808
    """
809
810
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

811
    def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
        # mapping symbol name -> data type; also queried with the special keys '_constant' and '_complex_type'
        self._type_for_symbol = type_for_symbol

        # which checks are enforced
        self.check_independence_condition = check_independence_condition
        self.check_double_write_condition = check_double_write_condition

        # state collected while processing assignments
        self.scopes = NestedScopes()
        self._field_writes = defaultdict(set)  # FieldAndIndex -> set of written offsets
        self.fields_read = set()
819
820
821
822
823
824
825

    def process_assignment(self, assignment):
        """Check and type a single assignment; returns it as a new :class:`ast.SympyAssignment`."""
        # Processing order matters for the checks: the rhs is handled before the lhs,
        # so that cases like ``a = a + 1`` are caught.
        typed_rhs = self.process_expression(assignment.rhs)
        typed_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(typed_lhs, typed_rhs)

826
    def process_expression(self, rhs, type_constants=True):
        """Recursively attach data types to the expression ``rhs`` and record the fields it reads.

        Args:
            rhs: sympy expression to process
            type_constants: if True, numeric constants (numpy scalars, sympy numbers) are wrapped in a
                            ``cast_func`` to their data type

        Returns:
            the typed expression; field accesses, typed symbols and ``sp.Indexed`` objects are returned as they
            are, and the offsets of an interpolator access are processed and written back into it in place
        """
        from pystencils.interpolation_astnodes import InterpolatorAccess

        self._update_accesses_rhs(rhs)

        # The order of the checks below matters: field accesses and TypedSymbols have to be handled
        # before the generic sp.Symbol case, and cast_func before BooleanFunction.
        if isinstance(rhs, AbstractField.AbstractAccess):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs

        if isinstance(rhs, InterpolatorAccess):
            processed_offsets = [self.process_expression(offset, type_constants) for offset in rhs.offsets]
            if processed_offsets:
                rhs.offsets = processed_offsets
            return rhs

        if isinstance(rhs, ImaginaryUnit):
            return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))

        if isinstance(rhs, TypedSymbol):
            return rhs

        if isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])

        if type_constants and isinstance(rhs, np.generic):
            return cast_func(rhs, create_type(rhs.dtype))

        if type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))

        # Very important that this clause comes before BooleanFunction
        if isinstance(rhs, cast_func):
            # constants inside the cast are left untyped (presumably the cast already fixes the type)
            inner = self.process_expression(rhs.args[0], type_constants=False)
            return cast_func(inner, rhs.dtype)

        if isinstance(rhs, sp.boolalg.BooleanFunction) or \
                type(rhs) in pystencils.integer_functions.__dict__.values():
            typed_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
            common_type = collate_types([get_type_of_expression(arg) for arg in typed_args],
                                        forbid_collation_to_float=True)
            # every typed argument whose type differs from the collated type gets an explicit cast
            unified_args = [arg if not hasattr(arg, 'dtype') or arg.dtype == common_type
                            else cast_func(arg, common_type)
                            for arg in typed_args]
            return rhs.func(*unified_args)

        if isinstance(rhs, sp.Mul):
            # factors -1 and 1 are kept untouched
            factors = [arg if arg in (-1, 1) else self.process_expression(arg, type_constants)
                       for arg in rhs.args]
            return rhs.func(*factors) if factors else rhs

        if isinstance(rhs, sp.Indexed):
            return rhs

        if isinstance(rhs, sp.Pow):
            # don't process exponents -> they should remain integers
            base, exponent = rhs.args[0], rhs.args[1]
            return sp.Pow(self.process_expression(base, type_constants), exponent)

        typed_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
        return rhs.func(*typed_args) if typed_args else rhs
883
884
885
886
887
888
889
890

    @property
    def fields_written(self):
        return set(k.field for k, v in self._field_writes.items() if len(v))

    def _process_lhs(self, lhs):
        """Register the write to ``lhs`` and return it as a typed symbol."""
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
            # already carries type information
            return lhs
        return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])

    def _update_accesses_lhs(self, lhs):
        """Record a write to ``lhs`` and enforce the write constraints.

        Raises:
            ValueError: if a field (same field and index) is written at more than one offset while
                        ``check_double_write_condition`` is set, if a pure symbol is assigned twice in the
                        current scope, or if a pure symbol is written after it has been read
        """
        if isinstance(lhs, AbstractField.AbstractAccess):
            written_offsets = self._field_writes[self.FieldAndIndex(lhs.field, lhs.index)]
            written_offsets.add(lhs.offsets)
            if self.check_double_write_condition and len(written_offsets) > 1:
                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
            return

        if not isinstance(lhs, sp.Symbol):
            return

        # pure symbols: SSA form, and no write after a read
        if self.scopes.is_defined_locally(lhs):
            raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
        if lhs in self.scopes.free_parameters:
            raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
        self.scopes.define_symbol(lhs)
910
911

    def _update_accesses_rhs(self, rhs):
912
        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
913
914
            writes = self._field_writes[self.FieldAndIndex(
                rhs.field, rhs.index)]