From bad8ff83ca3af3824f82b6e525271cfa7a5fd47c Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Mon, 19 Feb 2024 10:27:28 +0100
Subject: [PATCH] More refactoring

---
 conftest.py                                   |    3 +-
 src/pystencils/__init__.py                    |    7 +-
 .../backend/kernelcreation/__init__.py        |    2 -
 .../backend/kernelcreation/analysis.py        |    3 +-
 .../backend/kernelcreation/context.py         |    4 +-
 .../backend/kernelcreation/defaults.py        |    2 +-
 .../backend/kernelcreation/freeze.py          |    5 +-
 .../backend/kernelcreation/iteration_space.py |    2 +-
 src/pystencils/config.py                      |    3 +-
 .../datahandling/serial_datahandling.py       |    3 +-
 src/pystencils/display_utils.py               |    5 +-
 src/pystencils/fd/derivation.py               |    2 +-
 src/pystencils/fd/finitedifferences.py        |    4 +-
 src/pystencils/fd/spatial.py                  |    6 +-
 src/pystencils/kernel_decorator.py            |    4 +-
 src/pystencils/kernel_wrapper.py              |    2 +-
 src/pystencils/kernelcreation.py              |    2 +-
 src/pystencils/spatial_coordinates.py         |   13 +-
 src/pystencils/sympyextensions/__init__.py    |    5 +-
 src/pystencils/sympyextensions/assignment.py  |  104 --
 .../sympyextensions/assignment_collection.py  |  474 --------
 src/pystencils/sympyextensions/astnodes.py    | 1056 ++++++++---------
 src/pystencils/sympyextensions/bit_masks.py   |    1 -
 src/pystencils/sympyextensions/math.py        |    2 +-
 .../sympyextensions/simplifications.py        |   15 +-
 .../sympyextensions/simplificationstrategy.py |    2 +-
 src/pystencils/sympyextensions/typed_sympy.py |   91 +-
 27 files changed, 619 insertions(+), 1203 deletions(-)
 delete mode 100644 src/pystencils/sympyextensions/assignment.py
 delete mode 100644 src/pystencils/sympyextensions/assignment_collection.py

diff --git a/conftest.py b/conftest.py
index 2fd7f8e81..ca7c153b4 100644
--- a/conftest.py
+++ b/conftest.py
@@ -10,7 +10,8 @@ from nbconvert import PythonExporter
 
 # Trigger config file reading / creation once - to avoid race conditions when multiple instances are creating it
 # at the same time
-from pystencils.cpu import cpujit
+# TODO: replace with new backend
+# from pystencils.cpu import cpujit
 
 # trigger cython imports - there seems to be a problem when multiple processes try to compile the same cython file
 # at the same time
diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index 161530f88..51556f799 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -2,16 +2,17 @@
 from .enums import Backend, Target
 from . import fd
 from . import stencil as stencil
-from pystencils.sympyextensions.assignmentcollection.assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
 from .field import Field, FieldType, fields
 from .config import CreateKernelConfig
 from .cache import clear_cache
 from .kernel_decorator import kernel, kernel_config
 from .kernelcreation import create_kernel
-from pystencils.sympyextensions.assignmentcollection import AssignmentCollection
 from .slicing import make_slice
 from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered
+from .sympyextensions import Assignment, AssignmentCollection, AddAugmentedAssignment
+from .sympyextensions.astnodes import assignment_from_stencil
+from .sympyextensions.typed_sympy import TypedSymbol
 from .sympyextensions.math import SymbolCreator
 from .datahandling import create_data_handling
 
@@ -19,7 +20,7 @@ __all__ = ['Field', 'FieldType', 'fields',
            'TypedSymbol',
            'make_slice',
            'CreateKernelConfig',
-           'create_kernel', 'create_staggered_kernel',
+           'create_kernel',
            'Target', 'Backend',
            'show_code', 'to_dot', 'get_code_obj', 'get_code_str',
            'AssignmentCollection',
diff --git a/src/pystencils/backend/kernelcreation/__init__.py b/src/pystencils/backend/kernelcreation/__init__.py
index 749493059..6ac30e9ee 100644
--- a/src/pystencils/backend/kernelcreation/__init__.py
+++ b/src/pystencils/backend/kernelcreation/__init__.py
@@ -99,7 +99,6 @@ It is furthermore annotated with constraints collected during the translation, a
 """
 
 from .config import CreateKernelConfig
-from .kernelcreation import create_kernel
 
 from .context import KernelCreationContext
 from .analysis import KernelAnalysis
@@ -115,7 +114,6 @@ from .iteration_space import (
 
 __all__ = [
     "CreateKernelConfig",
-    "create_kernel",
     "KernelCreationContext",
     "KernelAnalysis",
     "FreezeExpressions",
diff --git a/src/pystencils/backend/kernelcreation/analysis.py b/src/pystencils/backend/kernelcreation/analysis.py
index 1b498427d..b0267a3d6 100644
--- a/src/pystencils/backend/kernelcreation/analysis.py
+++ b/src/pystencils/backend/kernelcreation/analysis.py
@@ -9,8 +9,7 @@ import sympy as sp
 from .context import KernelCreationContext
 
 from ...field import Field
-from pystencils.sympyextensions.assignmentcollection.assignment import Assignment
-from ...simp import AssignmentCollection
+from ...sympyextensions import Assignment, AssignmentCollection
 
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index d2cddb142..2080efed3 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
 from ...field import Field, FieldType
-from ...typing import TypedSymbol, BasicType, StructType
-
+from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType
 from ..arrays import PsLinearizedArray
 from ..types import PsIntegerType
 from ..types.quick import make_type
 from ..constraints import PsKernelConstraint
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
+
 from .config import CreateKernelConfig
 from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
 
diff --git a/src/pystencils/backend/kernelcreation/defaults.py b/src/pystencils/backend/kernelcreation/defaults.py
index fc0a602a1..b1822dc77 100644
--- a/src/pystencils/backend/kernelcreation/defaults.py
+++ b/src/pystencils/backend/kernelcreation/defaults.py
@@ -20,7 +20,7 @@ from typing import TypeVar, Generic, Callable
 from ..types import PsAbstractType, PsSignedIntegerType, PsStructType
 from ..typed_expressions import PsTypedVariable
 
-from ...typing import TypedSymbol
+from pystencils.sympyextensions.typed_sympy import TypedSymbol
 
 SymbolT = TypeVar("SymbolT")
 
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 06d6f1adc..94b660fd9 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -4,10 +4,9 @@ import sympy as sp
 import pymbolic.primitives as pb
 from pymbolic.interop.sympy import SympyToPymbolicMapper
 
-from pystencils.sympyextensions.assignmentcollection.assignment import Assignment
-from ...simp import AssignmentCollection
+from ...sympyextensions import Assignment, AssignmentCollection
+from ...sympyextensions.typed_sympy import BasicType
 from ...field import Field, FieldType
-from ...typing import BasicType
 
 from .context import KernelCreationContext
 
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index a739893c0..0e5828ef1 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from functools import reduce
 from operator import mul
 
-from ...simp import AssignmentCollection
+from ...sympyextensions import AssignmentCollection
 from ...field import Field, FieldType
 
 from ..typed_expressions import (
diff --git a/src/pystencils/config.py b/src/pystencils/config.py
index 5fff3b799..fdaed6665 100644
--- a/src/pystencils/config.py
+++ b/src/pystencils/config.py
@@ -5,8 +5,7 @@ from types import MappingProxyType
 from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict, Iterable
 
 from pystencils import Target, Backend, Field
-from pystencils.typing.typed_sympy import BasicType
-from pystencils.typing.utilities import collate_types
+from .sympyextensions.typed_sympy import BasicType
 
 import numpy as np
 
diff --git a/src/pystencils/datahandling/serial_datahandling.py b/src/pystencils/datahandling/serial_datahandling.py
index 0f5ddb431..e01db8dbb 100644
--- a/src/pystencils/datahandling/serial_datahandling.py
+++ b/src/pystencils/datahandling/serial_datahandling.py
@@ -9,7 +9,8 @@ from pystencils.datahandling.datahandling_interface import DataHandling
 from pystencils.enums import Target
 from pystencils.field import (Field, FieldType, create_numpy_array_with_layout,
                               layout_string_to_tuple, spatial_layout_string_to_tuple)
-from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler
+# TODO replace with platform
+# from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler
 from pystencils.slicing import normalize_slice, remove_ghost_layers
 from pystencils.utils import DotDict
 
diff --git a/src/pystencils/display_utils.py b/src/pystencils/display_utils.py
index 2fba12da6..bce5d493c 100644
--- a/src/pystencils/display_utils.py
+++ b/src/pystencils/display_utils.py
@@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Union
 
 import sympy as sp
 
-from pystencils.sympyextensions.astnodes import KernelFunction
 from pystencils.enums import Backend
 from pystencils.kernel_wrapper import KernelWrapper
 
@@ -41,7 +40,7 @@ def highlight_cpp(code: str):
     return HTML(highlight(code, CppLexer(), HtmlFormatter()))
 
 
-def get_code_obj(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
+def get_code_obj(ast: Union[KernelWrapper], custom_backend=None):
     """Returns an object to display generated code (C/C++ or CUDA)
 
     Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
@@ -87,7 +86,7 @@ def _isnotebook():
         return False
 
 
-def show_code(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
+def show_code(ast: Union[KernelWrapper], custom_backend=None):
     code = get_code_obj(ast, custom_backend)
 
     if _isnotebook():
diff --git a/src/pystencils/fd/derivation.py b/src/pystencils/fd/derivation.py
index bea0a9674..f739c8f02 100644
--- a/src/pystencils/fd/derivation.py
+++ b/src/pystencils/fd/derivation.py
@@ -6,7 +6,7 @@ import sympy as sp
 
 from pystencils.field import Field
 from pystencils.stencil import direction_string_to_offset
-from pystencils.sympyextensions import multidimensional_sum, prod
+from pystencils.sympyextensions.math import multidimensional_sum, prod
 from pystencils.utils import LinearEquationSystem, fully_contains
 
 
diff --git a/src/pystencils/fd/finitedifferences.py b/src/pystencils/fd/finitedifferences.py
index e8ff2c80c..9c4116ee5 100644
--- a/src/pystencils/fd/finitedifferences.py
+++ b/src/pystencils/fd/finitedifferences.py
@@ -7,8 +7,8 @@ from pystencils.fd import Diff
 from pystencils.fd.derivative import diff_args
 from pystencils.fd.spatial import fd_stencils_standard
 from pystencils.field import Field
-from pystencils.simp.assignment_collection import AssignmentCollection
-from pystencils.sympyextensions import fast_subs
+from pystencils.sympyextensions import AssignmentCollection
+from pystencils.sympyextensions.math import fast_subs
 
 FieldOrFieldAccess = Union[Field, Field.Access]
 
diff --git a/src/pystencils/fd/spatial.py b/src/pystencils/fd/spatial.py
index 17f708ed8..1bec75726 100644
--- a/src/pystencils/fd/spatial.py
+++ b/src/pystencils/fd/spatial.py
@@ -3,10 +3,10 @@ from typing import Tuple
 
 import sympy as sp
 
-from pystencils.sympyextensions.astnodes import LoopOverCoordinate
 from pystencils.fd import Diff
 from pystencils.field import Field
-from pystencils.sympyextensions import generic_visit
+from pystencils.sympyextensions.astnodes import generic_visit
+from pystencils.sympyextensions.typed_sympy import is_loop_counter_symbol
 
 from .derivation import FiniteDifferenceStencilDerivation
 from .derivative import diff_args
@@ -112,7 +112,7 @@ def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
         elif isinstance(e, Field.Access):
             return (e.neighbor(coordinate, sign) + e) / 2
         elif isinstance(e, sp.Symbol):
-            loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e)
+            loop_idx = is_loop_counter_symbol(e)
             return e + sign / 2 if loop_idx == coordinate else e
         else:
             new_args = [staggered_visitor(a, coordinate, sign) for a in e.args]
diff --git a/src/pystencils/kernel_decorator.py b/src/pystencils/kernel_decorator.py
index a8246de07..deb94eec0 100644
--- a/src/pystencils/kernel_decorator.py
+++ b/src/pystencils/kernel_decorator.py
@@ -5,8 +5,8 @@ from typing import Callable, Union, List, Dict, Tuple
 
 import sympy as sp
 
-from pystencils.sympyextensions.assignmentcollection.assignment import Assignment
-from pystencils.sympyextensions import SymbolCreator
+from .sympyextensions import Assignment
+from .sympyextensions.math import SymbolCreator
 from pystencils.config import CreateKernelConfig
 
 __all__ = ['kernel', 'kernel_config']
diff --git a/src/pystencils/kernel_wrapper.py b/src/pystencils/kernel_wrapper.py
index b76ec4e3d..d5dfbecca 100644
--- a/src/pystencils/kernel_wrapper.py
+++ b/src/pystencils/kernel_wrapper.py
@@ -8,7 +8,7 @@ class KernelWrapper:
     Can be called while still providing access to underlying AST.
     """
 
-    def __init__(self, kernel, parameters, ast_node: pystencils.astnodes.KernelFunction):
+    def __init__(self, kernel, parameters, ast_node):
         self.kernel = kernel
         self.parameters = parameters
         self.ast = ast_node
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 52c70fbfb..fb4d6abfb 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -8,7 +8,7 @@ from .backend.kernelcreation.transformations import EraseAnonymousStructTypes
 
 from .enums import Target
 from .config import CreateKernelConfig
-from pystencils.sympyextensions.assignmentcollection import AssignmentCollection
+from .sympyextensions import AssignmentCollection
 
 
 def create_kernel(
diff --git a/src/pystencils/spatial_coordinates.py b/src/pystencils/spatial_coordinates.py
index a8be92c94..cc244b11c 100644
--- a/src/pystencils/spatial_coordinates.py
+++ b/src/pystencils/spatial_coordinates.py
@@ -1,19 +1,14 @@
-
 import sympy
+from pystencils.sympyextensions.typed_sympy import get_loop_counter_symbol
 
-import pystencils
-import pystencils.sympyextensions.astnodes
 
-x_, y_, z_ = tuple(pystencils.sympyextensions.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(3))
+x_, y_, z_ = tuple(get_loop_counter_symbol(i) for i in range(3))
 x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5
 
 
 def x_vector(ndim):
-    return sympy.Matrix(tuple(
-        pystencils.sympyextensions.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(ndim)))
+    return sympy.Matrix(tuple(get_loop_counter_symbol(i) for i in range(ndim)))
 
 
 def x_staggered_vector(ndim):
-    return sympy.Matrix(tuple(
-        pystencils.sympyextensions.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) + 0.5 for i in range(ndim)
-    ))
+    return sympy.Matrix(tuple(get_loop_counter_symbol(i) + 0.5 for i in range(ndim)))
diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py
index d61e705d1..41d43d83c 100644
--- a/src/pystencils/sympyextensions/__init__.py
+++ b/src/pystencils/sympyextensions/__init__.py
@@ -1,5 +1,4 @@
-from .assignment import Assignment, AugmentedAssignment, AddAugmentedAssignment, assignment_from_stencil
-from pystencils.sympyextensions.assignment_collection import AssignmentCollection
+from .astnodes import Assignment, AugmentedAssignment, AddAugmentedAssignment, AssignmentCollection
 from .simplificationstrategy import SimplificationStrategy
 from .simplifications import (sympy_cse, sympy_cse_on_assignment_list, apply_to_all_assignments,
                               apply_on_all_subexpressions, subexpression_substitution_in_existing_subexpressions,
@@ -12,7 +11,7 @@ from .subexpression_insertion import (
     insert_squares, insert_symbol_times_minus_one)
 
 
-__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil',
+__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment',
            'AssignmentCollection', 'SimplificationStrategy',
            'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
            'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
diff --git a/src/pystencils/sympyextensions/assignment.py b/src/pystencils/sympyextensions/assignment.py
deleted file mode 100644
index 591bde9a5..000000000
--- a/src/pystencils/sympyextensions/assignment.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import numpy as np
-import sympy as sp
-from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
-from sympy.printing.latex import LatexPrinter
-
-
-def print_assignment_latex(printer, expr):
-    binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
-    """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
-    printed_lhs = printer.doprint(expr.lhs)
-    printed_rhs = printer.doprint(expr.rhs)
-    return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
-
-
-def assignment_str(assignment):
-    op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←'
-    return fr"{assignment.lhs} {op} {assignment.rhs}"
-
-
-_old_new = sp.codegen.ast.Assignment.__new__
-
-
-# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
-def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
-    if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
-        assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
-        return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
-    return _old_new(cls, lhs, rhs, *args, **kwargs)
-
-
-Assignment.__str__ = assignment_str
-Assignment.__new__ = _Assignment__new__
-LatexPrinter._print_Assignment = print_assignment_latex
-
-AugmentedAssignment.__str__ = assignment_str
-LatexPrinter._print_AugmentedAssignment = print_assignment_latex
-
-sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
-
-
-def assignment_from_stencil(stencil_array, input_field, output_field,
-                            normalization_factor=None, order='visual') -> Assignment:
-    """Creates an assignment
-
-    Args:
-        stencil_array: nested list of numpy array defining the stencil weights
-        input_field: field or field access, defining where the stencil should be applied to
-        output_field: field or field access where the result is written to
-        normalization_factor: optional normalization factor for the stencil
-        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
-               For details see examples
-
-    Returns:
-        Assignment that can be used to create a kernel
-
-    Examples:
-        >>> import pystencils as ps
-        >>> f, g = ps.fields("f, g: [2D]")
-        >>> stencil = [[0, 2, 0],
-        ...            [3, 4, 5],
-        ...            [0, 6, 0]]
-
-        By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
-        >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
-        >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
-        True
-
-        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
-        >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
-        >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
-        True
-
-        You can also pass field accesses to apply the stencil at an already shifted position:
-        >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
-        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
-        True
-    """
-    from pystencils.field import Field
-
-    stencil_array = np.array(stencil_array)
-    if order == 'visual':
-        stencil_array = np.swapaxes(stencil_array, 0, 1)
-        stencil_array = np.flip(stencil_array, axis=1)
-    elif order == 'numpy':
-        pass
-    else:
-        raise ValueError("'order' has to be either 'visual' or 'numpy'")
-
-    if isinstance(input_field, Field):
-        input_field = input_field.center
-    if isinstance(output_field, Field):
-        output_field = output_field.center
-
-    rhs = 0
-    offset = tuple(s // 2 for s in stencil_array.shape)
-
-    for index, factor in np.ndenumerate(stencil_array):
-        shift = tuple(i - o for i, o in zip(index, offset))
-        rhs += factor * input_field.get_shifted(*shift)
-
-    if normalization_factor:
-        rhs *= normalization_factor
-
-    return Assignment(output_field, rhs)
diff --git a/src/pystencils/sympyextensions/assignment_collection.py b/src/pystencils/sympyextensions/assignment_collection.py
deleted file mode 100644
index de8f7b4e5..000000000
--- a/src/pystencils/sympyextensions/assignment_collection.py
+++ /dev/null
@@ -1,474 +0,0 @@
-import itertools
-from copy import copy
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
-
-import sympy as sp
-
-import pystencils
-from .assignment import Assignment
-from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
-from pystencils.sympyextensions.math import count_operations, fast_subs
-
-
-class AssignmentCollection:
-    """
-    A collection of equations with subexpression definitions, also represented as assignments,
-    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
-    These simplification methods can change the subexpressions, but the number and
-    left hand side of the main equations themselves is not altered.
-    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
-    assignment collections to transport information to the simplification system.
-
-    Args:
-        main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
-                          assignment is a field access. Thus the generated equations write on arrays.
-        subexpressions: List of assignments defining subexpressions used in main equations
-        simplification_hints: Dict that is used to annotate the assignment collection with hints that are
-                              used by the simplification system. See documentation of the simplification rules for
-                              potentially required hints and their meaning.
-        subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
-                                        used to get new symbols that are unique for this AssignmentCollection
-
-    """
-
-    __match_args__ = ("main_assignments", "subexpressions")
-
-    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
-
-    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
-                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
-                 simplification_hints: Optional[Dict[str, Any]] = None,
-                 subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
-
-        if subexpressions is None:
-            subexpressions = {}
-
-        if isinstance(main_assignments, Dict):
-            main_assignments = [Assignment(k, v)
-                                for k, v in main_assignments.items()]
-        if isinstance(subexpressions, Dict):
-            subexpressions = [Assignment(k, v)
-                              for k, v in subexpressions.items()]
-
-        main_assignments = list(itertools.chain.from_iterable(
-            [(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
-        subexpressions = list(itertools.chain.from_iterable(
-            [(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
-
-        self.main_assignments = main_assignments
-        self.subexpressions = subexpressions
-
-        if simplification_hints is None:
-            simplification_hints = {}
-
-        self.simplification_hints = simplification_hints
-
-        ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
-        max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
-
-        if subexpression_symbol_generator is None:
-            self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
-        else:
-            self.subexpression_symbol_generator = subexpression_symbol_generator
-
-    def add_simplification_hint(self, key: str, value: Any) -> None:
-        """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
-        assert key not in self.simplification_hints, "This hint already exists"
-        self.simplification_hints[key] = value
-
-    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
-        """Adds a subexpression to current collection.
-
-        Args:
-            rhs: right hand side of new subexpression
-            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
-            topological_sort: sort the subexpressions topologically after insertion, to make sure that
-                              definition of a symbol comes before its usage. If False, subexpression is appended.
-
-        Returns:
-            left hand side symbol (which could have been generated)
-        """
-        if lhs is None:
-            lhs = next(self.subexpression_symbol_generator)
-        eq = Assignment(lhs, rhs)
-        self.subexpressions.append(eq)
-        if topological_sort:
-            self.topological_sort(sort_subexpressions=True,
-                                  sort_main_assignments=False)
-        return lhs
-
-    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
-        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
-        if sort_subexpressions:
-            self.subexpressions = sort_assignments_topologically(self.subexpressions)
-        if sort_main_assignments:
-            self.main_assignments = sort_assignments_topologically(self.main_assignments)
-
-    # ---------------------------------------------- Properties  -------------------------------------------------------
-
-    @property
-    def all_assignments(self) -> List[Assignment]:
-        """Subexpression and main equations as a single list."""
-        return self.subexpressions + self.main_assignments
-
-    @property
-    def rhs_symbols(self) -> Set[sp.Symbol]:
-        """All symbols used in the assignment collection, which occur on the rhs of any assignment."""
-        rhs_symbols = set()
-        for eq in self.all_assignments:
-            if isinstance(eq, Assignment):
-                rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
-            elif isinstance(eq, pystencils.astnodes.Node):
-                rhs_symbols.update(eq.undefined_symbols)
-
-        return rhs_symbols
-
-    @property
-    def free_symbols(self) -> Set[sp.Symbol]:
-        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
-        return self.rhs_symbols - self.bound_symbols
-
-    @property
-    def bound_symbols(self) -> Set[sp.Symbol]:
-        """All symbols which occur on the left hand side of a main assignment or a subexpression."""
-        bound_symbols_set = set(
-            [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
-        )
-
-        assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
-            "Not in SSA form - same symbol assigned multiple times"
-
-        bound_symbols_set = bound_symbols_set.union(*[
-            assignment.symbols_defined for assignment in self.all_assignments
-            if isinstance(assignment, pystencils.astnodes.Node)
-        ])
-
-        return bound_symbols_set
-
-    @property
-    def rhs_fields(self):
-        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
-        return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
-
-    @property
-    def free_fields(self):
-        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
-        return {s.field for s in self.free_symbols if hasattr(s, 'field')}
-
-    @property
-    def bound_fields(self):
-        """All field accessed on the left hand side of a main assignment or a subexpression."""
-        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
-
-    @property
-    def defined_symbols(self) -> Set[sp.Symbol]:
-        """All symbols which occur as left-hand-sides of one of the main equations"""
-        lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
-        return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
-                                if isinstance(assignment, pystencils.astnodes.Node)]))
-
-    @property
-    def operation_count(self):
-        """See :func:`count_operations` """
-        return count_operations(self.all_assignments, only_type=None)
-
-    def atoms(self, *args):
-        return set().union(*[a.atoms(*args) for a in self.all_assignments])
-
-    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
-        """Returns all symbols that depend on one of the passed symbols.
-
-        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
-        'b' is required to compute 'a'.
-        """
-
-        queue = list(symbols)
-
-        def add_symbols_from_expr(expr):
-            dependent_symbols = expr.atoms(sp.Symbol)
-            for ds in dependent_symbols:
-                queue.append(ds)
-
-        handled_symbols = set()
-        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
-
-        while len(queue) > 0:
-            e = queue.pop(0)
-            if e in handled_symbols:
-                continue
-            if e in assignment_dict:
-                add_symbols_from_expr(assignment_dict[e])
-            handled_symbols.add(e)
-
-        return handled_symbols
-
-    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
-        """Returns a python function to evaluate this equation collection.
-
-        Args:
-            symbols: symbol(s) which are the parameter for the created function
-            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
-            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
-
-        Examples:
-              >>> a, b, c, d = sp.symbols("a b c d")
-              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
-              ...                           subexpressions=[Assignment(b, a + b / 2)])
-              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
-              >>> python_function(4)
-              {c: 6, d: 18}
-        """
-        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
-        assignments = assignments.new_without_subexpressions().main_assignments
-        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
-
-        def f(*args, **kwargs):
-            return {s: func(*args, **kwargs) for s, func in lambdas.items()}
-
-        return f
-
-    # ---------------------------- Creating new modified collections ---------------------------------------------------
-
-    def copy(self,
-             main_assignments: Optional[List[Assignment]] = None,
-             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
-        """Returns a copy with optionally replaced main_assignments and/or subexpressions."""
-
-        res = copy(self)
-        res.simplification_hints = self.simplification_hints.copy()
-        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
-
-        if main_assignments is not None:
-            res.main_assignments = main_assignments
-        else:
-            res.main_assignments = self.main_assignments.copy()
-
-        if subexpressions is not None:
-            res.subexpressions = subexpressions
-        else:
-            res.subexpressions = self.subexpressions.copy()
-
-        return res
-
-    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
-                               substitute_on_lhs: bool = True,
-                               sort_topologically: bool = True) -> 'AssignmentCollection':
-        """Returns new object, where terms are substituted according to the passed substitution dict.
-
-        Args:
-            substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
-            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
-            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
-            sort_topologically: if subexpressions are added as substitutions and this parameters is true,
-                                the subexpressions are sorted topologically after insertion
-        Returns:
-            New AssignmentCollection where substitutions have been applied, self is not altered.
-        """
-        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
-        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
-        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
-
-        if add_substitutions_as_subexpressions:
-            transformed_subexpressions = [Assignment(b, a) for a, b in
-                                          substitutions.items()] + transformed_subexpressions
-            if sort_topologically:
-                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
-        return self.copy(transformed_assignments, transformed_subexpressions)
-
-    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
-        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
-        own_definitions = set([e.lhs for e in self.main_assignments])
-        other_definitions = set([e.lhs for e in other.main_assignments])
-        assert len(own_definitions.intersection(other_definitions)) == 0, \
-            "Cannot merge collections, since both define the same symbols"
-
-        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
-        substitution_dict = {}
-
-        processed_other_subexpression_equations = []
-        for other_subexpression_eq in other.subexpressions:
-            if other_subexpression_eq.lhs in own_subexpression_symbols:
-                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
-                    continue  # exact the same subexpression equation exists already
-                else:
-                    # different definition - a new name has to be introduced
-                    new_lhs = next(self.subexpression_symbol_generator)
-                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
-                    processed_other_subexpression_equations.append(new_eq)
-                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
-            else:
-                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
-
-        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
-        return self.copy(self.main_assignments + processed_other_main_assignments,
-                         self.subexpressions + processed_other_subexpression_equations)
-
-    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
-        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.
-
-        Returns:
-            new AssignmentCollection, self is not altered
-        """
-        symbols_to_extract = set(symbols_to_extract)
-        dependent_symbols = self.dependent_symbols(symbols_to_extract)
-        new_assignments = []
-        for eq in self.all_assignments:
-            if eq.lhs in symbols_to_extract:
-                new_assignments.append(eq)
-
-        new_sub_expr = [eq for eq in self.all_assignments
-                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
-        return self.copy(new_assignments, new_sub_expr)
-
-    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
-        """Returns new collection that only contains subexpressions required to compute the main assignments."""
-        all_lhs = [eq.lhs for eq in self.main_assignments]
-        return self.new_filtered(all_lhs)
-
-    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
-        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
-        new_subexpressions = []
-        subs_dict = None
-        for se in self.subexpressions:
-            if se.lhs == symbol:
-                subs_dict = {se.lhs: se.rhs}
-            else:
-                new_subexpressions.append(se)
-        if subs_dict is None:
-            return self
-
-        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
-        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
-        return self.copy(new_eqs, new_subexpressions)
-
-    def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
-        """Returns a new collection where all subexpressions have been inserted."""
-        if subexpressions_to_keep is None:
-            subexpressions_to_keep = set()
-        if len(self.subexpressions) == 0:
-            return self.copy()
-
-        subexpressions_to_keep = set(subexpressions_to_keep)
-
-        kept_subexpressions = []
-        if self.subexpressions[0].lhs in subexpressions_to_keep:
-            substitution_dict = {}
-            kept_subexpressions.append(self.subexpressions[0])
-        else:
-            substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
-
-        subexpression = [e for e in self.subexpressions]
-        for i in range(1, len(subexpression)):
-            subexpression[i] = fast_subs(subexpression[i], substitution_dict)
-            if subexpression[i].lhs in subexpressions_to_keep:
-                kept_subexpressions.append(subexpression[i])
-            else:
-                substitution_dict[subexpression[i].lhs] = subexpression[i].rhs
-
-        new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
-        return self.copy(new_assignment, kept_subexpressions)
-
-    # ----------------------------------------- Display and Printing   -------------------------------------------------
-
-    def _repr_html_(self):
-        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
-
-        def make_html_equation_table(equations):
-            no_border = 'style="border:none"'
-            html_table = '<table style="border:none; width: 100%; ">'
-            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
-            for eq in equations:
-                format_dict = {'eq': sp.latex(eq),
-                               'nb': no_border, }
-                html_table += line.format(**format_dict)
-            html_table += "</table>"
-            return html_table
-
-        result = ""
-        if len(self.subexpressions) > 0:
-            result += "<div>Subexpressions:</div>"
-            result += make_html_equation_table(self.subexpressions)
-        result += "<div>Main Assignments:</div>"
-        result += make_html_equation_table(self.main_assignments)
-        return result
-
-    def __repr__(self):
-        return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
-
-    def __str__(self):
-        result = "Subexpressions:\n"
-        for eq in self.subexpressions:
-            result += f"\t{eq}\n"
-        result += "Main Assignments:\n"
-        for eq in self.main_assignments:
-            result += f"\t{eq}\n"
-        return result
-
-    def __iter__(self):
-        return self.all_assignments.__iter__()
-
-    @property
-    def main_assignments_dict(self):
-        return {a.lhs: a.rhs for a in self.main_assignments}
-
-    @property
-    def subexpressions_dict(self):
-        return {a.lhs: a.rhs for a in self.subexpressions}
-
-    def set_main_assignments_from_dict(self, main_assignments_dict):
-        self.main_assignments = [Assignment(k, v)
-                                 for k, v in main_assignments_dict.items()]
-
-    def set_sub_expressions_from_dict(self, sub_expressions_dict):
-        self.subexpressions = [Assignment(k, v)
-                               for k, v in sub_expressions_dict.items()]
-
-    def find(self, *args, **kwargs):
-        return set.union(
-            *[a.find(*args, **kwargs) for a in self.all_assignments]
-        )
-
-    def match(self, *args, **kwargs):
-        rtn = {}
-        for a in self.all_assignments:
-            partial_result = a.match(*args, **kwargs)
-            if partial_result:
-                rtn.update(partial_result)
-        return rtn
-
-    def subs(self, *args, **kwargs):
-        return AssignmentCollection(
-            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
-            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
-        )
-
-    def replace(self, *args, **kwargs):
-        return AssignmentCollection(
-            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
-            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
-        )
-
-    def __eq__(self, other):
-        return set(self.all_assignments) == set(other.all_assignments)
-
-    def __bool__(self):
-        return bool(self.all_assignments)
-
-
-class SymbolGen:
-    """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
-
-    def __init__(self, symbol="xi", dtype=None, ctr=0):
-        self._ctr = ctr
-        self._symbol = symbol
-        self._dtype = dtype
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        name = f"{self._symbol}_{self._ctr}"
-        self._ctr += 1
-        if self._dtype is not None:
-            return pystencils.TypedSymbol(name, self._dtype)
-        return sp.Symbol(name)
diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py
index 0cd82322c..8483977d8 100644
--- a/src/pystencils/sympyextensions/astnodes.py
+++ b/src/pystencils/sympyextensions/astnodes.py
@@ -1,689 +1,588 @@
-import collections.abc
+from copy import copy
 import itertools
 import uuid
-from typing import Any, List, Optional, Sequence, Set, Union
-
-from .assignment import Assignment
-from pystencils.enums import Target, Backend
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
 
 import sympy as sp
+from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
+from sympy.printing.latex import LatexPrinter
+import numpy as np
 
-from .math import fast_subs
-from .typed_sympy import (create_type, CastFunc,
-                          FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol)
+from .math import count_operations, fast_subs
+from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
+from .typed_sympy import create_type, TypedSymbol
 
 
-NodeOrExpr = Union['Node', sp.Expr]
+def print_assignment_latex(printer, expr):
+    binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
+    """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
+    printed_lhs = printer.doprint(expr.lhs)
+    printed_rhs = printer.doprint(expr.rhs)
+    return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
 
 
-class Node:
-    """Base class for all AST nodes."""
+def assignment_str(assignment):
+    op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←'
+    return fr"{assignment.lhs} {op} {assignment.rhs}"
 
-    def __init__(self, parent: Optional['Node'] = None):
-        self.parent = parent
 
-    @property
-    def args(self) -> List[NodeOrExpr]:
-        """Returns all arguments/children of this node."""
-        raise NotImplementedError()
+_old_new = sp.codegen.ast.Assignment.__new__
 
-    @property
-    def symbols_defined(self) -> Set[sp.Symbol]:
-        """Set of symbols which are defined by this node."""
-        raise NotImplementedError()
 
-    @property
-    def undefined_symbols(self) -> Set[sp.Symbol]:
-        """Symbols which are used but are not defined inside this node."""
-        raise NotImplementedError()
-
-    def subs(self, subs_dict) -> None:
-        """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
-        for i, a in enumerate(self.args):
-            result = a.subs(subs_dict)
-            if isinstance(a, sp.Expr):  # sympy expressions' subs is out-of-place
-                self.args[i] = result
-            else:  # all other should be in-place
-                assert result is None
+# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
+def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
+    if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
+        assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
+        return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
+    return _old_new(cls, lhs, rhs, *args, **kwargs)
 
-    @property
-    def func(self):
-        return self.__class__
-
-    def atoms(self, arg_type) -> Set[Any]:
-        """Returns a set of all descendants recursively, which are an instance of the given type."""
-        result = set()
-        for arg in self.args:
-            if isinstance(arg, arg_type):
-                result.add(arg)
-            result.update(arg.atoms(arg_type))
-        return result
 
+Assignment.__str__ = assignment_str
+Assignment.__new__ = _Assignment__new__
+LatexPrinter._print_Assignment = print_assignment_latex
+
+AugmentedAssignment.__str__ = assignment_str
+LatexPrinter._print_AugmentedAssignment = print_assignment_latex
 
-class Conditional(Node):
-    """Conditional that maps to a 'if' statement in C/C++.
+sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
 
-    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
-    Consider using assignments with sympy.Piecewise in this case.
+
+def assignment_from_stencil(stencil_array, input_field, output_field,
+                            normalization_factor=None, order='visual') -> Assignment:
+    """Creates an assignment
 
     Args:
-        condition_expr: sympy relational expression
-        true_block: block which is run if conditional is true
-        false_block: optional block which is run if conditional is false
+        stencil_array: nested list of numpy array defining the stencil weights
+        input_field: field or field access, defining where the stencil should be applied to
+        output_field: field or field access where the result is written to
+        normalization_factor: optional normalization factor for the stencil
+        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
+               For details see examples
+
+    Returns:
+        Assignment that can be used to create a kernel
+
+    Examples:
+        >>> import pystencils as ps
+        >>> f, g = ps.fields("f, g: [2D]")
+        >>> stencil = [[0, 2, 0],
+        ...            [3, 4, 5],
+        ...            [0, 6, 0]]
+
+        By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
+        >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
+        >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
+        True
+
+        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
+        >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
+        >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
+        True
+
+        You can also pass field accesses to apply the stencil at an already shifted position:
+        >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
+        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
+        True
     """
+    from pystencils.field import Field
 
-    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
-                 false_block: Optional['Block'] = None) -> None:
-        super(Conditional, self).__init__(parent=None)
+    stencil_array = np.array(stencil_array)
+    if order == 'visual':
+        stencil_array = np.swapaxes(stencil_array, 0, 1)
+        stencil_array = np.flip(stencil_array, axis=1)
+    elif order == 'numpy':
+        pass
+    else:
+        raise ValueError("'order' has to be either 'visual' or 'numpy'")
 
-        self.condition_expr = condition_expr
+    if isinstance(input_field, Field):
+        input_field = input_field.center
+    if isinstance(output_field, Field):
+        output_field = output_field.center
 
-        def handle_child(c):
-            if c is None:
-                return None
-            if not isinstance(c, Block):
-                c = Block([c])
-            c.parent = self
-            return c
+    rhs = 0
+    offset = tuple(s // 2 for s in stencil_array.shape)
 
-        self.true_block = handle_child(true_block)
-        self.false_block = handle_child(false_block)
+    for index, factor in np.ndenumerate(stencil_array):
+        shift = tuple(i - o for i, o in zip(index, offset))
+        rhs += factor * input_field.get_shifted(*shift)
 
-    def subs(self, subs_dict):
-        self.true_block.subs(subs_dict)
-        if self.false_block:
-            self.false_block.subs(subs_dict)
-        self.condition_expr = self.condition_expr.subs(subs_dict)
+    if normalization_factor:
+        rhs *= normalization_factor
 
-    @property
-    def args(self):
-        result = [self.condition_expr, self.true_block]
-        if self.false_block:
-            result.append(self.false_block)
-        return result
+    return Assignment(output_field, rhs)
 
-    @property
-    def symbols_defined(self):
-        return set()
-
-    @property
-    def undefined_symbols(self):
-        result = self.true_block.undefined_symbols
-        if self.false_block:
-            result.update(self.false_block.undefined_symbols)
-        if hasattr(self.condition_expr, 'atoms'):
-            result.update(self.condition_expr.atoms(sp.Symbol))
-        return result
 
-    def __str__(self):
-        return self.__repr__()
+class AssignmentCollection:
+    """
+    A collection of equations with subexpression definitions, also represented as assignments,
+    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
+    These simplification methods can change the subexpressions, but the number and
+    left hand side of the main equations themselves is not altered.
+    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
+    assignment collections to transport information to the simplification system.
 
-    def __repr__(self):
-        result = f'if:({self.condition_expr!r}) '
-        if self.true_block:
-            result += f'\n\t{self.true_block}) '
-        if self.false_block:
-            result = 'else: '
-            result += f'\n\t{self.false_block} '
+    Args:
+        main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
+                          assignment is a field access. Thus the generated equations write on arrays.
+        subexpressions: List of assignments defining subexpressions used in main equations
+        simplification_hints: Dict that is used to annotate the assignment collection with hints that are
+                              used by the simplification system. See documentation of the simplification rules for
+                              potentially required hints and their meaning.
+        subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
+                                        used to get new symbols that are unique for this AssignmentCollection
 
-        return result
+    """
 
-    def replace_by_true_block(self):
-        """Replaces the conditional by its True block"""
-        self.parent.replace(self, [self.true_block])
+    __match_args__ = ("main_assignments", "subexpressions")
 
-    def replace_by_false_block(self):
-        """Replaces the conditional by its False block"""
-        self.parent.replace(self, [self.false_block] if self.false_block else [])
+    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
 
+    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
+                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
+                 simplification_hints: Optional[Dict[str, Any]] = None,
+                 subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
 
-class SkipIteration(Node):
-    @property
-    def args(self):
-        return []
+        if subexpressions is None:
+            subexpressions = {}
 
-    @property
-    def symbols_defined(self):
-        return set()
+        if isinstance(main_assignments, Dict):
+            main_assignments = [Assignment(k, v)
+                                for k, v in main_assignments.items()]
+        if isinstance(subexpressions, Dict):
+            subexpressions = [Assignment(k, v)
+                              for k, v in subexpressions.items()]
 
-    @property
-    def undefined_symbols(self):
-        return set()
-
-
-class Block(Node):
-    def __init__(self, nodes: Union[Node, List[Node]]):
-        super(Block, self).__init__()
-        if not isinstance(nodes, list):
-            nodes = [nodes]
-        self._nodes = nodes
-        self.parent = None
-        for n in self._nodes:
-            try:
-                n.parent = self
-            except AttributeError:
-                pass
+        main_assignments = list(itertools.chain.from_iterable(
+            [(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
+        subexpressions = list(itertools.chain.from_iterable(
+            [(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
 
-    @property
-    def args(self):
-        return self._nodes
+        self.main_assignments = main_assignments
+        self.subexpressions = subexpressions
 
-    def subs(self, subs_dict) -> None:
-        for a in self.args:
-            a.subs(subs_dict)
+        if simplification_hints is None:
+            simplification_hints = {}
 
-    def fast_subs(self, subs_dict, skip=None):
-        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
-        return self
+        self.simplification_hints = simplification_hints
 
-    def insert_front(self, node, if_not_exists=False):
-        if if_not_exists and len(self._nodes) > 0 and self._nodes[0] == node:
-            return
-        if isinstance(node, collections.abc.Iterable):
-            node = list(node)
-            for n in node:
-                n.parent = self
+        ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
+        max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
 
-            self._nodes = node + self._nodes
-        else:
-            node.parent = self
-            self._nodes.insert(0, node)
-
-    def insert_before(self, new_node, insert_before, if_not_exists=False):
-        new_node.parent = self
-        assert self._nodes.count(insert_before) == 1
-        idx = self._nodes.index(insert_before)
-
-        if not if_not_exists or self._nodes[idx] != new_node:
-            self._nodes.insert(idx, new_node)
-
-    def insert_after(self, new_node, insert_after, if_not_exists=False):
-        new_node.parent = self
-        assert self._nodes.count(insert_after) == 1
-        idx = self._nodes.index(insert_after) + 1
-
-        if not if_not_exists or not (self._nodes[idx - 1] == new_node
-                                     or (idx < len(self._nodes) and self._nodes[idx] == new_node)):
-            self._nodes.insert(idx, new_node)
-
-    def append(self, node):
-        if isinstance(node, list) or isinstance(node, tuple):
-            for n in node:
-                n.parent = self
-                self._nodes.append(n)
-        else:
-            node.parent = self
-            self._nodes.append(node)
-
-    def take_child_nodes(self):
-        tmp = self._nodes
-        self._nodes = []
-        return tmp
-
-    def replace(self, child, replacements):
-        assert self._nodes.count(child) == 1
-        idx = self._nodes.index(child)
-        del self._nodes[idx]
-        if type(replacements) is list:
-            for e in replacements:
-                e.parent = self
-            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
+        if subexpression_symbol_generator is None:
+            self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
         else:
-            replacements.parent = self
-            self._nodes.insert(idx, replacements)
+            self.subexpression_symbol_generator = subexpression_symbol_generator
 
-    @property
-    def symbols_defined(self):
-        result = set()
-        for a in self.args:
-            if isinstance(a, Assignment):
-                result.update(a.free_symbols)
-            else:
-                result.update(a.symbols_defined)
-        return result
-
-    @property
-    def undefined_symbols(self):
-        result = set()
-        defined_symbols = set()
-        for a in self.args:
-            if isinstance(a, Assignment):
-                result.update(a.free_symbols)
-                defined_symbols.update({a.lhs})
-            else:
-                result.update(a.undefined_symbols)
-                defined_symbols.update(a.symbols_defined)
-        return result - defined_symbols
-
-    def __str__(self):
-        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)
-
-    def __repr__(self):
-        return "Block"
+    def add_simplification_hint(self, key: str, value: Any) -> None:
+        """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
+        assert key not in self.simplification_hints, "This hint already exists"
+        self.simplification_hints[key] = value
 
+    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
+        """Adds a subexpression to current collection.
 
-class PragmaBlock(Block):
-    def __init__(self, pragma_line, nodes):
-        super(PragmaBlock, self).__init__(nodes)
-        self.pragma_line = pragma_line
-        for n in nodes:
-            n.parent = self
+        Args:
+            rhs: right hand side of new subexpression
+            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
+            topological_sort: sort the subexpressions topologically after insertion, to make sure that
+                              definition of a symbol comes before its usage. If False, subexpression is appended.
 
-    def __repr__(self):
-        return self.pragma_line
-
-
-class LoopOverCoordinate(Node):
-    LOOP_COUNTER_NAME_PREFIX = "ctr"
-    BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
-
-    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, custom_loop_ctr=None):
-        super(LoopOverCoordinate, self).__init__(parent=None)
-        self.body = body
-        body.parent = self
-        self.coordinate_to_loop_over = coordinate_to_loop_over
-        self.start = start
-        self.stop = stop
-        self.step = step
-        self.body.parent = self
-        self.prefix_lines = []
-        self.is_block_loop = is_block_loop
-        self.custom_loop_ctr = custom_loop_ctr
-
-    def new_loop_with_different_body(self, new_body):
-        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
-                                    self.step, self.is_block_loop, self.custom_loop_ctr)
-        result.prefix_lines = [prefix_line for prefix_line in self.prefix_lines]
-        return result
-
-    def subs(self, subs_dict):
-        self.body.subs(subs_dict)
-        if hasattr(self.start, "subs"):
-            self.start = self.start.subs(subs_dict)
-        if hasattr(self.stop, "subs"):
-            self.stop = self.stop.subs(subs_dict)
-        if hasattr(self.step, "subs"):
-            self.step = self.step.subs(subs_dict)
-
-    def fast_subs(self, subs_dict, skip=None):
-        self.body = fast_subs(self.body, subs_dict, skip)
-        if isinstance(self.start, sp.Basic):
-            self.start = fast_subs(self.start, subs_dict, skip)
-        if isinstance(self.stop, sp.Basic):
-            self.stop = fast_subs(self.stop, subs_dict, skip)
-        if isinstance(self.step, sp.Basic):
-            self.step = fast_subs(self.step, subs_dict, skip)
-        return self
+        Returns:
+            left hand side symbol (which could have been generated)
+        """
+        if lhs is None:
+            lhs = next(self.subexpression_symbol_generator)
+        eq = Assignment(lhs, rhs)
+        self.subexpressions.append(eq)
+        if topological_sort:
+            self.topological_sort(sort_subexpressions=True,
+                                  sort_main_assignments=False)
+        return lhs
 
-    @property
-    def args(self):
-        result = [self.body]
-        for e in [self.start, self.stop, self.step]:
-            if hasattr(e, "args"):
-                result.append(e)
-        return result
+    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
+        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
+        if sort_subexpressions:
+            self.subexpressions = sort_assignments_topologically(self.subexpressions)
+        if sort_main_assignments:
+            self.main_assignments = sort_assignments_topologically(self.main_assignments)
 
-    def replace(self, child, replacement):
-        if child == self.body:
-            self.body = replacement
-        elif child == self.start:
-            self.start = replacement
-        elif child == self.step:
-            self.step = replacement
-        elif child == self.stop:
-            self.stop = replacement
+    # ---------------------------------------------- Properties  -------------------------------------------------------
 
     @property
-    def symbols_defined(self):
-        return {self.loop_counter_symbol}
+    def all_assignments(self) -> List[Assignment]:
+        """Subexpression and main equations as a single list."""
+        return self.subexpressions + self.main_assignments
 
     @property
-    def undefined_symbols(self):
-        result = self.body.undefined_symbols
-        for possible_symbol in [self.start, self.stop, self.step]:
-            if isinstance(possible_symbol, Node) or isinstance(possible_symbol, sp.Basic):
-                result.update(possible_symbol.atoms(sp.Symbol))
-        return result - {self.loop_counter_symbol}
-
-    @staticmethod
-    def get_loop_counter_name(coordinate_to_loop_over):
-        return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
-
-    @staticmethod
-    def get_block_loop_counter_name(coordinate_to_loop_over):
-        return f"{LoopOverCoordinate.BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
+    def rhs_symbols(self) -> Set[sp.Symbol]:
+        """All symbols used in the assignment collection, which occur on the rhs of any assignment."""
+        rhs_symbols = set()
+        for eq in self.all_assignments:
+            if isinstance(eq, Assignment):
+                rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
+            # TODO rewrite with SymPy AST
+            # elif isinstance(eq, pystencils.astnodes.Node):
+            #     rhs_symbols.update(eq.undefined_symbols)
 
-    @property
-    def loop_counter_name(self):
-        if self.custom_loop_ctr:
-            return self.custom_loop_ctr.name
-        else:
-            if self.is_block_loop:
-                return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
-            else:
-                return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
-
-    @staticmethod
-    def is_loop_counter_symbol(symbol):
-        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
-        if not symbol.name.startswith(prefix):
-            return None
-        if symbol.dtype != create_type('int'):
-            return None
-        coordinate = int(symbol.name[len(prefix) + 1:])
-        return coordinate
-
-    @staticmethod
-    def get_loop_counter_symbol(coordinate_to_loop_over):
-        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
-
-    @staticmethod
-    def get_block_loop_counter_symbol(coordinate_to_loop_over):
-        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
-                           'int',
-                           nonnegative=True)
+        return rhs_symbols
 
     @property
-    def loop_counter_symbol(self):
-        if self.custom_loop_ctr:
-            return self.custom_loop_ctr
-        else:
-            if self.is_block_loop:
-                return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
-            else:
-                return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
+    def free_symbols(self) -> Set[sp.Symbol]:
+        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return self.rhs_symbols - self.bound_symbols
 
     @property
-    def is_outermost_loop(self):
-        return get_next_parent_of_type(self, LoopOverCoordinate) is None
+    def bound_symbols(self) -> Set[sp.Symbol]:
+        """All symbols which occur on the left hand side of a main assignment or a subexpression."""
+        bound_symbols_set = set(
+            [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
+        )
 
-    @property
-    def is_innermost_loop(self):
-        return len(self.atoms(LoopOverCoordinate)) == 0
+        assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
+            "Not in SSA form - same symbol assigned multiple times"
 
-    def __str__(self):
-        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
-                                                                    self.loop_counter_name, self.stop,
-                                                                    self.loop_counter_name, self.step,
-                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))
+        # TODO rewrite with SymPy AST
+        # bound_symbols_set = bound_symbols_set.union(*[
+        #     assignment.symbols_defined for assignment in self.all_assignments
+        #     if isinstance(assignment, pystencils.astnodes.Node)
+        # ])
 
-    def __repr__(self):
-        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
-                                                              self.loop_counter_name, self.stop,
-                                                              self.loop_counter_name, self.step)
-
-
-class SympyAssignment(Node):
-    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
-        super(SympyAssignment, self).__init__(parent=None)
-        self._lhs_symbol = sp.sympify(lhs_symbol)
-        self._rhs = sp.sympify(rhs_expr)
-        self._is_const = is_const
-        self._is_declaration = self.__is_declaration()
-        self._use_auto = use_auto
-
-    def __is_declaration(self):
-        if isinstance(self._lhs_symbol, CastFunc):
-            return False
-        if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
-            return False
-        return True
+        return bound_symbols_set
 
     @property
-    def lhs(self):
-        return self._lhs_symbol
+    def rhs_fields(self):
+        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
 
     @property
-    def rhs(self):
-        return self._rhs
-
-    @lhs.setter
-    def lhs(self, new_value):
-        self._lhs_symbol = new_value
-        self._is_declaration = self.__is_declaration()
-
-    @rhs.setter
-    def rhs(self, new_rhs_expr):
-        self._rhs = new_rhs_expr
-
-    def subs(self, subs_dict):
-        self.lhs = fast_subs(self.lhs, subs_dict)
-        self.rhs = fast_subs(self.rhs, subs_dict)
-
-    def fast_subs(self, subs_dict, skip=None):
-        self.lhs = fast_subs(self.lhs, subs_dict, skip)
-        self.rhs = fast_subs(self.rhs, subs_dict, skip)
-        return self
-
-    def optimize(self, optimizations):
-        try:
-            from sympy.codegen.rewriting import optimize
-            self.rhs = optimize(self.rhs, optimizations)
-        except Exception:
-            pass
+    def free_fields(self):
+        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return {s.field for s in self.free_symbols if hasattr(s, 'field')}
 
     @property
-    def args(self):
-        return [self._lhs_symbol, self.rhs]
+    def bound_fields(self):
+        """All field accessed on the left hand side of a main assignment or a subexpression."""
+        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
 
     @property
-    def symbols_defined(self):
-        if not self._is_declaration:
-            return set()
-        return {self._lhs_symbol}
+    def defined_symbols(self) -> Set[sp.Symbol]:
+        """All symbols which occur as left-hand-sides of one of the main equations"""
+        lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
+        return lhs_set
+        # TODO rewrite with SymPy AST
+        # return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
+        #                         if isinstance(assignment, pystencils.astnodes.Node)]))
 
     @property
-    def undefined_symbols(self):
-        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
-        # Add loop counters if there a field accesses
-        loop_counters = set()
-        for symbol in result:
-            if isinstance(symbol, Field.Access):
-                for i in range(len(symbol.offsets)):
-                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
-        result.update(loop_counters)
-        
-        result.update(self._lhs_symbol.atoms(sp.Symbol))
-        
-        return result
+    def operation_count(self):
+        """See :func:`count_operations` """
+        return count_operations(self.all_assignments, only_type=None)
 
-    @property
-    def is_declaration(self):
-        return self._is_declaration
+    def atoms(self, *args):
+        return set().union(*[a.atoms(*args) for a in self.all_assignments])
 
-    @property
-    def is_const(self):
-        return self._is_const
+    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
+        """Returns all symbols that depend on one of the passed symbols.
 
-    @property
-    def use_auto(self):
-        return self._use_auto
-
-    def replace(self, child, replacement):
-        if child == self.lhs:
-            replacement.parent = self
-            self.lhs = replacement
-        elif child == self.rhs:
-            replacement.parent = self
-            self.rhs = replacement
-        else:
-            raise ValueError(f'{replacement} is not in args of {self.__class__}')
+        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
+        'b' is required to compute 'a'.
+        """
 
-    def __repr__(self):
-        return repr(self.lhs) + " ← " + repr(self.rhs)
+        queue = list(symbols)
 
-    def _repr_html_(self):
-        printed_lhs = sp.latex(self.lhs)
-        printed_rhs = sp.latex(self.rhs)
-        return f"${printed_lhs} \\leftarrow {printed_rhs}$"
+        def add_symbols_from_expr(expr):
+            dependent_symbols = expr.atoms(sp.Symbol)
+            for ds in dependent_symbols:
+                queue.append(ds)
 
-    def __hash__(self):
-        return hash((self.lhs, self.rhs))
+        handled_symbols = set()
+        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
 
-    def __eq__(self, other):
-        return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
-
-
-class ResolvedFieldAccess(sp.Indexed):
-    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
-        if not isinstance(base, sp.IndexedBase):
-            assert isinstance(base, TypedSymbol)
-            base = sp.IndexedBase(base, shape=(1,))
-            assert isinstance(base.label, TypedSymbol)
-        obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
-        obj.field = field
-        obj.offsets = offsets
-        obj.idx_coordinate_values = idx_coordinate_values
-        return obj
-
-    def _eval_subs(self, old, new):
-        return ResolvedFieldAccess(self.args[0],
-                                   self.args[1].subs(old, new),
-                                   self.field, self.offsets, self.idx_coordinate_values)
-
-    def fast_subs(self, substitutions, skip=None):
-        if self in substitutions:
-            return substitutions[self]
-        return ResolvedFieldAccess(self.args[0].subs(substitutions),
-                                   self.args[1].subs(substitutions),
-                                   self.field, self.offsets, self.idx_coordinate_values)
-
-    def _hashable_content(self):
-        super_class_contents = super(ResolvedFieldAccess, self)._hashable_content()
-        return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
+        while len(queue) > 0:
+            e = queue.pop(0)
+            if e in handled_symbols:
+                continue
+            if e in assignment_dict:
+                add_symbols_from_expr(assignment_dict[e])
+            handled_symbols.add(e)
 
-    @property
-    def typed_symbol(self):
-        return self.base.label
+        return handled_symbols
 
-    def __str__(self):
-        top = super(ResolvedFieldAccess, self).__str__()
-        return f"{top} ({self.typed_symbol.dtype})"
+    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
+        """Returns a python function to evaluate this equation collection.
 
-    def __getnewargs__(self):
-        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
+        Args:
+            symbols: symbol(s) which are the parameter for the created function
+            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
+            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
 
-    def __getnewargs_ex__(self):
-        return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values), {}
+        Examples:
+              >>> a, b, c, d = sp.symbols("a b c d")
+              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
+              ...                           subexpressions=[Assignment(b, a + b / 2)])
+              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
+              >>> python_function(4)
+              {c: 6, d: 18}
+        """
+        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
+        assignments = assignments.new_without_subexpressions().main_assignments
+        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
 
+        def f(*args, **kwargs):
+            return {s: func(*args, **kwargs) for s, func in lambdas.items()}
 
-class TemporaryMemoryAllocation(Node):
-    """Node for temporary memory buffer allocation.
+        return f
 
-    Always allocates aligned memory.
+    # ---------------------------- Creating new modified collections ---------------------------------------------------
 
-    Args:
-        typed_symbol: symbol used as pointer (has to be typed)
-        size: number of elements to allocate
-        align_offset: the align_offset's element is aligned
-    """
+    def copy(self,
+             main_assignments: Optional[List[Assignment]] = None,
+             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
+        """Returns a copy with optionally replaced main_assignments and/or subexpressions."""
 
-    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
-        super(TemporaryMemoryAllocation, self).__init__(parent=None)
-        self.symbol = typed_symbol
-        self.size = size
-        self.headers = ['<stdlib.h>']
-        self._align_offset = align_offset
+        res = copy(self)
+        res.simplification_hints = self.simplification_hints.copy()
+        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
 
-    @property
-    def symbols_defined(self):
-        return {self.symbol}
+        if main_assignments is not None:
+            res.main_assignments = main_assignments
+        else:
+            res.main_assignments = self.main_assignments.copy()
 
-    @property
-    def undefined_symbols(self):
-        if isinstance(self.size, sp.Basic):
-            return self.size.atoms(sp.Symbol)
+        if subexpressions is not None:
+            res.subexpressions = subexpressions
         else:
-            return set()
+            res.subexpressions = self.subexpressions.copy()
+
+        return res
+
+    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
+                               substitute_on_lhs: bool = True,
+                               sort_topologically: bool = True) -> 'AssignmentCollection':
+        """Returns new object, where terms are substituted according to the passed substitution dict.
+
+        Args:
+            substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
+            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
+            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
+            sort_topologically: if subexpressions are added as substitutions and this parameters is true,
+                                the subexpressions are sorted topologically after insertion
+        Returns:
+            New AssignmentCollection where substitutions have been applied, self is not altered.
+        """
+        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
+        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
+        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
+
+        if add_substitutions_as_subexpressions:
+            transformed_subexpressions = [Assignment(b, a) for a, b in
+                                          substitutions.items()] + transformed_subexpressions
+            if sort_topologically:
+                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
+        return self.copy(transformed_assignments, transformed_subexpressions)
+
+    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
+        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
+        own_definitions = set([e.lhs for e in self.main_assignments])
+        other_definitions = set([e.lhs for e in other.main_assignments])
+        assert len(own_definitions.intersection(other_definitions)) == 0, \
+            "Cannot merge collections, since both define the same symbols"
+
+        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
+        substitution_dict = {}
+
+        processed_other_subexpression_equations = []
+        for other_subexpression_eq in other.subexpressions:
+            if other_subexpression_eq.lhs in own_subexpression_symbols:
+                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
+                    continue  # exact the same subexpression equation exists already
+                else:
+                    # different definition - a new name has to be introduced
+                    new_lhs = next(self.subexpression_symbol_generator)
+                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
+                    processed_other_subexpression_equations.append(new_eq)
+                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
+            else:
+                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
+
+        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
+        return self.copy(self.main_assignments + processed_other_main_assignments,
+                         self.subexpressions + processed_other_subexpression_equations)
+
+    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
+        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.
+
+        Returns:
+            new AssignmentCollection, self is not altered
+        """
+        symbols_to_extract = set(symbols_to_extract)
+        dependent_symbols = self.dependent_symbols(symbols_to_extract)
+        new_assignments = []
+        for eq in self.all_assignments:
+            if eq.lhs in symbols_to_extract:
+                new_assignments.append(eq)
+
+        new_sub_expr = [eq for eq in self.all_assignments
+                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
+        return self.copy(new_assignments, new_sub_expr)
+
+    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
+        """Returns new collection that only contains subexpressions required to compute the main assignments."""
+        all_lhs = [eq.lhs for eq in self.main_assignments]
+        return self.new_filtered(all_lhs)
+
+    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
+        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
+        new_subexpressions = []
+        subs_dict = None
+        for se in self.subexpressions:
+            if se.lhs == symbol:
+                subs_dict = {se.lhs: se.rhs}
+            else:
+                new_subexpressions.append(se)
+        if subs_dict is None:
+            return self
+
+        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
+        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
+        return self.copy(new_eqs, new_subexpressions)
+
+    def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
+        """Returns a new collection where all subexpressions have been inserted."""
+        if subexpressions_to_keep is None:
+            subexpressions_to_keep = set()
+        if len(self.subexpressions) == 0:
+            return self.copy()
+
+        subexpressions_to_keep = set(subexpressions_to_keep)
+
+        kept_subexpressions = []
+        if self.subexpressions[0].lhs in subexpressions_to_keep:
+            substitution_dict = {}
+            kept_subexpressions.append(self.subexpressions[0])
+        else:
+            substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
 
-    @property
-    def args(self):
-        return [self.symbol]
+        subexpression = [e for e in self.subexpressions]
+        for i in range(1, len(subexpression)):
+            subexpression[i] = fast_subs(subexpression[i], substitution_dict)
+            if subexpression[i].lhs in subexpressions_to_keep:
+                kept_subexpressions.append(subexpression[i])
+            else:
+                substitution_dict[subexpression[i].lhs] = subexpression[i].rhs
 
-    def offset(self, byte_alignment):
-        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment."""
-        np_dtype = self.symbol.dtype.base_type.numpy_dtype
-        assert byte_alignment % np_dtype.itemsize == 0
-        return -self._align_offset % (byte_alignment / np_dtype.itemsize)
+        new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
+        return self.copy(new_assignment, kept_subexpressions)
 
+    # ----------------------------------------- Display and Printing   -------------------------------------------------
 
-class TemporaryMemoryFree(Node):
-    def __init__(self, alloc_node):
-        super(TemporaryMemoryFree, self).__init__(parent=None)
-        self.alloc_node = alloc_node
+    def _repr_html_(self):
+        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
+
+        def make_html_equation_table(equations):
+            no_border = 'style="border:none"'
+            html_table = '<table style="border:none; width: 100%; ">'
+            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
+            for eq in equations:
+                format_dict = {'eq': sp.latex(eq),
+                               'nb': no_border, }
+                html_table += line.format(**format_dict)
+            html_table += "</table>"
+            return html_table
+
+        result = ""
+        if len(self.subexpressions) > 0:
+            result += "<div>Subexpressions:</div>"
+            result += make_html_equation_table(self.subexpressions)
+        result += "<div>Main Assignments:</div>"
+        result += make_html_equation_table(self.main_assignments)
+        return result
 
-    @property
-    def symbol(self):
-        return self.alloc_node.symbol
+    def __repr__(self):
+        return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
 
-    def offset(self, byte_alignment):
-        return self.alloc_node.offset(byte_alignment)
+    def __str__(self):
+        result = "Subexpressions:\n"
+        for eq in self.subexpressions:
+            result += f"\t{eq}\n"
+        result += "Main Assignments:\n"
+        for eq in self.main_assignments:
+            result += f"\t{eq}\n"
+        return result
 
-    @property
-    def symbols_defined(self):
-        return set()
+    def __iter__(self):
+        return self.all_assignments.__iter__()
 
     @property
-    def undefined_symbols(self):
-        return set()
+    def main_assignments_dict(self):
+        return {a.lhs: a.rhs for a in self.main_assignments}
 
     @property
-    def args(self):
-        return []
-
-
-def early_out(condition):
-    return Conditional(vec_all(condition), Block([SkipIteration()]))
+    def subexpressions_dict(self):
+        return {a.lhs: a.rhs for a in self.subexpressions}
 
+    def set_main_assignments_from_dict(self, main_assignments_dict):
+        self.main_assignments = [Assignment(k, v)
+                                 for k, v in main_assignments_dict.items()]
 
-def get_dummy_symbol(dtype='bool'):
-    return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype))
+    def set_sub_expressions_from_dict(self, sub_expressions_dict):
+        self.subexpressions = [Assignment(k, v)
+                               for k, v in sub_expressions_dict.items()]
 
+    def find(self, *args, **kwargs):
+        return set.union(
+            *[a.find(*args, **kwargs) for a in self.all_assignments]
+        )
 
-class SourceCodeComment(Node):
-    def __init__(self, text):
-        self.text = text
+    def match(self, *args, **kwargs):
+        rtn = {}
+        for a in self.all_assignments:
+            partial_result = a.match(*args, **kwargs)
+            if partial_result:
+                rtn.update(partial_result)
+        return rtn
 
-    @property
-    def args(self):
-        return []
+    def subs(self, *args, **kwargs):
+        return AssignmentCollection(
+            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
+            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
+        )
 
-    @property
-    def symbols_defined(self):
-        return set()
+    def replace(self, *args, **kwargs):
+        return AssignmentCollection(
+            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
+            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
+        )
 
-    @property
-    def undefined_symbols(self):
-        return set()
-
-    def __str__(self):
-        return "/* " + self.text + " */"
+    def __eq__(self, other):
+        return set(self.all_assignments) == set(other.all_assignments)
 
-    def __repr__(self):
-        return self.__str__()
+    def __bool__(self):
+        return bool(self.all_assignments)
 
 
-class EmptyLine(Node):
-    def __init__(self):
-        pass
+class SymbolGen:
+    """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
 
-    @property
-    def args(self):
-        return []
+    def __init__(self, symbol="xi", dtype=None, ctr=0):
+        self._ctr = ctr
+        self._symbol = symbol
+        self._dtype = dtype
 
-    @property
-    def symbols_defined(self):
-        return set()
+    def __iter__(self):
+        return self
 
-    @property
-    def undefined_symbols(self):
-        return set()
+    def __next__(self):
+        name = f"{self._symbol}_{self._ctr}"
+        self._ctr += 1
+        if self._dtype is not None:
+            return TypedSymbol(name, self._dtype)
+        return sp.Symbol(name)
 
-    def __str__(self):
-        return ""
 
-    def __repr__(self):
-        return self.__str__()
+def get_dummy_symbol(dtype='bool'):
+    return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype))
 
 
 class ConditionalFieldAccess(sp.Function):
@@ -712,3 +611,18 @@ class ConditionalFieldAccess(sp.Function):
 
     def __getnewargs_ex__(self):
         return (self.access, self.outofbounds_condition, self.outofbounds_value), {}
+
+
+def generic_visit(term, visitor):
+    if isinstance(term, AssignmentCollection):
+        new_main_assignments = generic_visit(term.main_assignments, visitor)
+        new_subexpressions = generic_visit(term.subexpressions, visitor)
+        return term.copy(new_main_assignments, new_subexpressions)
+    elif isinstance(term, list):
+        return [generic_visit(e, visitor) for e in term]
+    elif isinstance(term, Assignment):
+        return Assignment(term.lhs, generic_visit(term.rhs, visitor))
+    elif isinstance(term, sp.Matrix):
+        return term.applyfunc(lambda e: generic_visit(e, visitor))
+    else:
+        return visitor(term)
diff --git a/src/pystencils/sympyextensions/bit_masks.py b/src/pystencils/sympyextensions/bit_masks.py
index f8b6b7ef0..57f2ab5fb 100644
--- a/src/pystencils/sympyextensions/bit_masks.py
+++ b/src/pystencils/sympyextensions/bit_masks.py
@@ -1,5 +1,4 @@
 import sympy as sp
-# from pystencils.typing import get_type_of_expression
 
 
 # noinspection PyPep8Naming
diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py
index 4a6f01a5c..00974809f 100644
--- a/src/pystencils/sympyextensions/math.py
+++ b/src/pystencils/sympyextensions/math.py
@@ -10,7 +10,7 @@ from sympy import PolynomialError
 from sympy.functions import Abs
 from sympy.core.numbers import Zero
 
-from .assignment import Assignment
+from .astnodes import Assignment
 from pystencils.functions import DivFunc
 from .typed_sympy import CastFunc, PointerType, VectorType, FieldPointerSymbol
 
diff --git a/src/pystencils/sympyextensions/simplifications.py b/src/pystencils/sympyextensions/simplifications.py
index c16e42f85..cdcad81e7 100644
--- a/src/pystencils/sympyextensions/simplifications.py
+++ b/src/pystencils/sympyextensions/simplifications.py
@@ -4,20 +4,21 @@ from collections import defaultdict
 
 import sympy as sp
 
-from .assignment import Assignment
-from pystencils.sympyextensions.astnodes import Node
+from .astnodes import Assignment
 from .math import subs_additive, is_constant, recursive_collect
 from .typed_sympy import TypedSymbol
 
 
-def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
+# TODO rewrite with SymPy AST
+# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
+def sort_assignments_topologically(assignments: Sequence[Union[Assignment]]) -> List[Union[Assignment]]:
     """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
     edges = []
     for c1, e1 in enumerate(assignments):
         if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
             symbols = [e1.lhs]
-        elif isinstance(e1, Node):
-            symbols = e1.symbols_defined
+        # elif isinstance(e1, Node):
+        #     symbols = e1.symbols_defined
         else:
             raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
 
@@ -25,8 +26,8 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
             for c2, e2 in enumerate(assignments):
                 if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
                     edges.append((c1, c2))
-                elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
-                    edges.append((c1, c2))
+                # elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
+                #     edges.append((c1, c2))
     return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
 
 
diff --git a/src/pystencils/sympyextensions/simplificationstrategy.py b/src/pystencils/sympyextensions/simplificationstrategy.py
index cd2fa8faa..b76d12711 100644
--- a/src/pystencils/sympyextensions/simplificationstrategy.py
+++ b/src/pystencils/sympyextensions/simplificationstrategy.py
@@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Sequence
 
 import sympy as sp
 
-from .assignment_collection import AssignmentCollection
+from .astnodes import AssignmentCollection
 
 
 class SimplificationStrategy:
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index 458fbae20..1f20624ce 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -1,5 +1,6 @@
 from abc import abstractmethod
-from typing import Union
+from itertools import groupby
+from typing import Sequence, Union
 
 import numpy as np
 import sympy as sp
@@ -330,6 +331,72 @@ class StructType(AbstractType):
         return hash((self.numpy_dtype, self.const))
 
 
+def result_type(*args: np.dtype):
+    """Returns the type of the result if the np.dtype arguments would be collated.
+    We can't use numpy functionality, because numpy casts don't behave exactly like C casts"""
+    s = sorted(args, key=lambda x: x.itemsize)
+
+    def kind_to_value(kind: str) -> int:
+        if kind == 'f':
+            return 3
+        elif kind == 'i':
+            return 2
+        elif kind == 'u':
+            return 1
+        elif kind == 'b':
+            return 0
+        else:
+            raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')
+    s = sorted(s, key=lambda x: kind_to_value(x.kind))
+    return s[-1]
+
+
+def all_equal(iterable):
+    """
+    Returns ``True`` if all the elements are equal to each other.
+    Copied from: more-itertools 8.12.0
+    """
+    g = groupby(iterable)
+    return next(g, True) and not next(g, False)
+
+
+def collate_types(types: Sequence[Union[BasicType, VectorType]]):
+    """
+    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
+    Uses the collation rules from numpy.
+    """
+    # Pointer arithmetic case i.e. pointer + [int, uint] is allowed
+    if any(isinstance(t, PointerType) for t in types):
+        pointer_type = None
+        for t in types:
+            if isinstance(t, PointerType):
+                if pointer_type is not None:
+                    raise ValueError(f'Cannot collate the combination of two pointer types "{pointer_type}" and "{t}"')
+                pointer_type = t
+            elif isinstance(t, BasicType):
+                if not (t.is_int() or t.is_uint()):
+                    raise ValueError("Invalid pointer arithmetic")
+            else:
+                raise ValueError("Invalid pointer arithmetic")
+        return pointer_type
+
+    # # peel of vector types, if at least one vector type occurred the result will also be the vector type
+    vector_type = [t for t in types if isinstance(t, VectorType)]
+    if not all_equal(t.width for t in vector_type):
+        raise ValueError("Collation failed because of vector types with different width")
+
+    types = [t.base_type if isinstance(t, VectorType) else t for t in types]
+
+    # now we should have a list of basic types - struct types are not yet supported
+    assert all(type(t) is BasicType for t in types)
+
+    result_numpy_type = result_type(*(t.numpy_dtype for t in types))
+    result = BasicType(result_numpy_type)
+    if vector_type:
+        result = VectorType(result, vector_type[0].width)
+    return result
+
+
 def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
     """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
 
@@ -419,6 +486,9 @@ class TypedSymbol(sp.Symbol):
 
 SHAPE_DTYPE = BasicType('int64', const=True)
 STRIDE_DTYPE = BasicType('int64', const=True)
+LOOP_COUNTER_DTYPE = BasicType('int64', const=True)
+LOOP_COUNTER_NAME_PREFIX = "ctr"
+BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
 
 
 class FieldStrideSymbol(TypedSymbol):
@@ -501,6 +571,25 @@ class FieldPointerSymbol(TypedSymbol):
     __xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__))
 
 
+def get_loop_counter_symbol(coordinate_to_loop_over):
+    return TypedSymbol(f"{LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}",
+                       LOOP_COUNTER_DTYPE, nonnegative=True)
+
+
+def get_block_loop_counter_symbol(coordinate_to_loop_over):
+    return TypedSymbol(f"{BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}",
+                       LOOP_COUNTER_DTYPE, nonnegative=True)
+
+
+def is_loop_counter_symbol(symbol):
+    if not symbol.name.startswith(LOOP_COUNTER_NAME_PREFIX):
+        return None
+    if symbol.dtype != LOOP_COUNTER_DTYPE:
+        return None
+    coordinate = int(symbol.name[len(LOOP_COUNTER_NAME_PREFIX) + 1:])
+    return coordinate
+
+
 class CastFunc(sp.Function):
     """
     CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
-- 
GitLab