From 942c7d965b05c83b8f16fd699253c694b5089007 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Wed, 24 Nov 2021 19:38:22 +0100
Subject: [PATCH] Created typing as its own module

---
 pystencils/__init__.py                        |   2 +-
 pystencils/alignedarray.py                    |   4 +-
 pystencils/astnodes.py                        |   6 +-
 pystencils/backends/cbackend.py               |  36 +-
 pystencils/bit_masks.py                       |   2 +-
 pystencils/boundaries/boundaryconditions.py   |   2 +-
 pystencils/boundaries/boundaryhandling.py     |   2 +-
 pystencils/boundaries/inkernel.py             |   2 +-
 pystencils/cache.py                           |   2 +-
 pystencils/cpu/cpujit.py                      |   6 +-
 pystencils/cpu/kernelcreation.py              |   4 +-
 pystencils/cpu/vectorization.py               |  38 +-
 pystencils/data_types.py                      | 927 ------------------
 pystencils/field.py                           |  26 +-
 pystencils/gpucuda/cudajit.py                 |   2 +-
 pystencils/gpucuda/indexing.py                |   2 +-
 pystencils/gpucuda/kernelcreation.py          |   4 +-
 pystencils/integer_functions.py               |   7 +-
 pystencils/integer_set_analysis.py            |   2 +-
 .../kerncraft_coupling/generate_benchmark.py  |   2 +-
 .../kerncraft_coupling/kerncraft_interface.py |   2 +-
 pystencils/kernel_contrains_check.py          | 150 +++
 pystencils/rng.py                             |   6 +-
 pystencils/sympyextensions.py                 |   4 +-
 pystencils/transformations.py                 | 394 +-------
 pystencils/typing/__init__.py                 |   4 +
 pystencils/typing/cast_functions.py           | 120 +++
 .../typed_sympy.py}                           | 114 ++-
 pystencils/typing/types.py                    | 297 ++++++
 pystencils/typing/utilities.py                | 494 ++++++++++
 pystencils_tests/test_abs.py                  |   4 +-
 pystencils_tests/test_address_of.py           |  10 +-
 pystencils_tests/test_complex_numbers.py      |   2 +-
 pystencils_tests/test_cuda_known_functions.py |   2 +-
 pystencils_tests/test_field.py                |   2 +-
 .../test_floor_ceil_int_optimization.py       |   2 +-
 pystencils_tests/test_global_definitions.py   |   2 +-
 pystencils_tests/test_kernel_data_type.py     |   2 +-
 ...st_match_subs_for_assignment_collection.py |   4 +-
 pystencils_tests/test_pickle_support.py       |   2 +-
 pystencils_tests/test_random.py               |   2 +-
 pystencils_tests/test_sum_prod.py             |   2 +-
 pystencils_tests/test_transformations.py      |   2 +-
 pystencils_tests/test_type_interference.py    |   6 +-
 pystencils_tests/test_types.py                |  24 +-
 45 files changed, 1266 insertions(+), 1465 deletions(-)
 delete mode 100644 pystencils/data_types.py
 create mode 100644 pystencils/kernel_contrains_check.py
 create mode 100644 pystencils/typing/__init__.py
 create mode 100644 pystencils/typing/cast_functions.py
 rename pystencils/{kernelparameters.py => typing/typed_sympy.py} (52%)
 create mode 100644 pystencils/typing/types.py
 create mode 100644 pystencils/typing/utilities.py
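
Illustration (not part of the patch): the hunks below move the type machinery from pystencils.data_types into the new pystencils.typing package and rename the lower-case, function-style classes to CamelCase (cast_func -> CastFunc, vector_memory_access -> VectorMemoryAccess, reinterpret_cast_func -> ReinterpretCastFunc). A minimal sketch of how downstream imports change, assuming the renamed classes keep the old constructors; the symbol names are made up for the example:

    # old imports (pre-patch):
    #   from pystencils.data_types import TypedSymbol, cast_func, create_type
    # new imports (post-patch), as used throughout the hunks below:
    from pystencils.typing import TypedSymbol, CastFunc, create_type

    ctr = TypedSymbol("ctr_0", "int64")   # symbol carrying an explicit C type
    val = CastFunc(ctr, "double")         # cast node consumed by the code backends
    assert val.dtype == create_type("double")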

diff --git a/pystencils/__init__.py b/pystencils/__init__.py
index a10acb8f6..4d97202bd 100644
--- a/pystencils/__init__.py
+++ b/pystencils/__init__.py
@@ -3,7 +3,7 @@ from .enums import Backend, Target
 from . import fd
 from . import stencil as stencil
 from .assignment import Assignment, assignment_from_stencil
-from .data_types import TypedSymbol
+from pystencils.typing.typed_sympy import TypedSymbol
 from .datahandling import create_data_handling
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
 from .field import Field, FieldType, fields
diff --git a/pystencils/alignedarray.py b/pystencils/alignedarray.py
index da20a778e..26c3aa5ba 100644
--- a/pystencils/alignedarray.py
+++ b/pystencils/alignedarray.py
@@ -1,5 +1,5 @@
 import numpy as np
-from pystencils.data_types import BasicType
+from pystencils.typing import numpy_name_to_c
 
 
 def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, order='C', align_inner_coordinate=True):
@@ -21,7 +21,7 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
         from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_cacheline_size,
                                                                get_vector_instruction_set)
 
-        type_name = BasicType.numpy_name_to_c(np.dtype(dtype).name)
+        type_name = numpy_name_to_c(np.dtype(dtype).name)
         instruction_sets = get_supported_instruction_sets()
         if instruction_sets is None:
             byte_alignment = 64
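
Illustration (not part of the patch): BasicType.numpy_name_to_c becomes the free function numpy_name_to_c in pystencils.typing, which is why aligned_empty no longer needs the class. Assuming the new function keeps the mapping of the static method removed from data_types.py further below, it behaves like this:

    import numpy as np
    from pystencils.typing import numpy_name_to_c  # introduced by this patch

    assert numpy_name_to_c(np.dtype(np.float64).name) == 'double'
    assert numpy_name_to_c(np.dtype(np.float32).name) == 'float'
    assert numpy_name_to_c(np.dtype(np.uint8).name) == 'uint8_t'
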
diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 689a18b02..b9f13ae64 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
 import sympy as sp
 
 import pystencils
-from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
+from pystencils.typing import TypedSymbol, CastFunc, create_type, get_next_parent_of_type
 from pystencils.enums import Target, Backend
 from pystencils.field import Field
 from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
@@ -542,7 +542,6 @@ class LoopOverCoordinate(Node):
 
     @property
     def is_outermost_loop(self):
-        from pystencils.transformations import get_next_parent_of_type
         return get_next_parent_of_type(self, LoopOverCoordinate) is None
 
     @property
@@ -571,7 +570,7 @@ class SympyAssignment(Node):
         self.use_auto = use_auto
 
     def __is_declaration(self):
-        if isinstance(self._lhs_symbol, cast_func):
+        if isinstance(self._lhs_symbol, CastFunc):
             return False
         if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
             return False
@@ -616,7 +615,6 @@ class SympyAssignment(Node):
             if isinstance(symbol, Field.Access):
                 for i in range(len(symbol.offsets)):
                     loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
-        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
         result.update(loop_counters)
         result.update(self._lhs_symbol.atoms(sp.Symbol))
         return result
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index a17ee7269..fa3079e32 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -11,9 +11,9 @@ from sympy.logic.boolalg import BooleanFalse, BooleanTrue
 
 from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
 from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
-from pystencils.data_types import (
-    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
-    reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol)
+from pystencils.typing import (
+    PointerType, VectorType, address_of, CastFunc, create_type, get_type_of_expression,
+    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
 from pystencils.enums import Backend
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 from pystencils.integer_functions import (
@@ -276,7 +276,7 @@ class CBackend:
         else:
             lhs_type = get_type_of_expression(node.lhs)
             printed_mask = ""
-            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
+            if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
                 arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                 instr = 'storeU'
                 if aligned:
@@ -289,12 +289,12 @@ class CBackend:
                                 self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
                                 '{1}', '{2}', **self._kwargs), **self._kwargs)
                     printed_mask = self.sympy_printer.doprint(mask)
-                    if data_type.base_type.base_name == 'double':
+                    if data_type.base_type.c_name == 'double':
                         if self._vector_instruction_set['double'] == '__m256d':
                             printed_mask = f"_mm256_castpd_si256({printed_mask})"
                         elif self._vector_instruction_set['double'] == '__m128d':
                             printed_mask = f"_mm_castpd_si128({printed_mask})"
-                    elif data_type.base_type.base_name == 'float':
+                    elif data_type.base_type.c_name == 'float':
                         if self._vector_instruction_set['float'] == '__m256':
                             printed_mask = f"_mm256_castps_si256({printed_mask})"
                         elif self._vector_instruction_set['float'] == '__m128':
@@ -302,7 +302,7 @@ class CBackend:
 
                 rhs_type = get_type_of_expression(node.rhs)
                 if type(rhs_type) is not VectorType:
-                    rhs = cast_func(node.rhs, VectorType(rhs_type))
+                    rhs = CastFunc(node.rhs, VectorType(rhs_type))
                 else:
                     rhs = node.rhs
 
@@ -322,7 +322,7 @@ class CBackend:
                     if stride == 1:
                         offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
                     size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
-                    element_size = 8 if data_type.base_type.base_name == 'double' else 4
+                    element_size = 8 if data_type.base_type.c_name == 'double' else 4
                     size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
                     pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
                         self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
@@ -483,13 +483,13 @@ class CustomSympyPrinter(CCodePrinter):
         }
         if hasattr(expr, 'to_c'):
             return expr.to_c(self._print)
-        if isinstance(expr, reinterpret_cast_func):
+        if isinstance(expr, ReinterpretCastFunc):
             arg, data_type = expr.args
             return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
         elif isinstance(expr, address_of):
             assert len(expr.args) == 1, "address_of must only have one argument"
             return f"&({self._print(expr.args[0])})"
-        elif isinstance(expr, cast_func):
+        elif isinstance(expr, CastFunc):
             arg, data_type = expr.args
             if isinstance(arg, sp.Number) and arg.is_finite:
                 return self._typed_number(arg, data_type)
@@ -648,22 +648,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             return None
 
     def _print_Abs(self, expr):
-        if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
+        if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess):
             return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
         return super()._print_Abs(expr)
 
     def _print_Function(self, expr):
-        if isinstance(expr, vector_memory_access):
+        if isinstance(expr, VectorMemoryAccess):
             arg, data_type, aligned, _, mask, stride = expr.args
             if stride != 1:
                 return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format(f"& {self._print(arg)}", **self._kwargs)
-        elif isinstance(expr, cast_func):
+        elif isinstance(expr, CastFunc):
             arg, data_type = expr.args
             if type(data_type) is VectorType:
                 # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func
-                assert not isinstance(arg, vector_memory_access)
+                assert not isinstance(arg, VectorMemoryAccess)
                 if isinstance(arg, sp.Tuple):
                     is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
                     is_integer = get_type_of_expression(arg[0]) == create_type("int")
@@ -747,12 +747,12 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
         suffix = ""
-        if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
+        if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
                 or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
-            dtype = set([e.dtype for e in args if type(e) is cast_func])
+            dtype = set([e.dtype for e in args if type(e) is CastFunc])
             assert len(dtype) == 1
             dtype = dtype.pop()
-            args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
+            args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
                     for e in args]
             suffix = "int"
 
@@ -880,7 +880,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         result = self._print(expr.args[-1][0])
         for true_expr, condition in reversed(expr.args[:-1]):
-            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
+            if isinstance(condition, CastFunc) and get_type_of_expression(condition.args[0]) == create_type("bool"):
                 if not KERNCRAFT_NO_TERNARY_MODE:
                     result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                            result, **self._kwargs)
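
Illustration (not part of the patch): besides the CamelCase renames, the cbackend hunks above switch from data_type.base_type.base_name to data_type.base_type.c_name. Assuming c_name returns the same C type string the old base_name property did, the mask-cast branch keys on it like this:

    from pystencils.typing import VectorType, create_type

    vec_double = VectorType(create_type('float64'), width=4)
    # the printer chooses the _mm*_castpd_si*/_mm*_castps_si* intrinsic by the element type's C name
    assert vec_double.base_type.c_name == 'double'   # previously: vec_double.base_type.base_name
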
diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py
index 0fab63b25..73c18688c 100644
--- a/pystencils/bit_masks.py
+++ b/pystencils/bit_masks.py
@@ -1,5 +1,5 @@
 import sympy as sp
-from pystencils.data_types import get_type_of_expression
+from pystencils.typing import get_type_of_expression
 
 
 # noinspection PyPep8Naming
diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py
index dc01224d0..c53d248ac 100644
--- a/pystencils/boundaries/boundaryconditions.py
+++ b/pystencils/boundaries/boundaryconditions.py
@@ -2,7 +2,7 @@ from typing import Any, List, Tuple
 
 from pystencils import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
-from pystencils.data_types import create_type
+from pystencils.typing import create_type
 
 
 class Boundary:
diff --git a/pystencils/boundaries/boundaryhandling.py b/pystencils/boundaries/boundaryhandling.py
index 5705d3d53..4ad3ab3ff 100644
--- a/pystencils/boundaries/boundaryhandling.py
+++ b/pystencils/boundaries/boundaryhandling.py
@@ -7,7 +7,7 @@ from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.boundaries.createindexlist import (
     create_boundary_index_array, numpy_data_type_for_boundary_object)
 from pystencils.cache import memorycache
-from pystencils.data_types import TypedSymbol, create_type
+from pystencils.typing import TypedSymbol, create_type
 from pystencils.datahandling.pycuda import PyCudaArrayHandler
 from pystencils.field import Field
 from pystencils.kernelparameters import FieldPointerSymbol
diff --git a/pystencils/boundaries/inkernel.py b/pystencils/boundaries/inkernel.py
index 1d78814db..479f30d22 100644
--- a/pystencils/boundaries/inkernel.py
+++ b/pystencils/boundaries/inkernel.py
@@ -1,7 +1,7 @@
 import sympy as sp
 
 from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE
-from pystencils.data_types import TypedSymbol, create_type
+from pystencils.typing import TypedSymbol, create_type
 from pystencils.field import Field
 from pystencils.integer_functions import bitwise_and
 
diff --git a/pystencils/cache.py b/pystencils/cache.py
index f29678920..15274ccb8 100644
--- a/pystencils/cache.py
+++ b/pystencils/cache.py
@@ -5,7 +5,7 @@ from itertools import chain
 
 try:
     from functools import lru_cache as memorycache
-except ImportError:
+except ImportError:  # TODO fallback for Python < 3.2 (functools.lru_cache is stdlib since 3.2); likely removable
     from backports.functools_lru_cache import lru_cache as memorycache
 
 from joblib import Memory
diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 240cddd49..2861d671f 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -60,7 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
 from pystencils import FieldType
 from pystencils.astnodes import LoopOverCoordinate
 from pystencils.backends.cbackend import generate_c, get_headers, CFunction
-from pystencils.data_types import cast_func, VectorType, vector_memory_access
+from pystencils.typing import CastFunc, VectorType, VectorMemoryAccess
 from pystencils.include import get_pystencils_include_path
 from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.utils import atomic_file_write, recursive_dict_update
@@ -388,7 +388,7 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec
                 aligned = False
                 if ast_node.assignments:
                     aligned = any([a.lhs.args[2] for a in ast_node.assignments
-                                   if hasattr(a, 'lhs') and isinstance(a.lhs, cast_func)
+                                   if hasattr(a, 'lhs') and isinstance(a.lhs, CastFunc)
                                    and hasattr(a.lhs, 'dtype') and isinstance(a.lhs.dtype, VectorType)])
 
                 if ast_node.instruction_set and aligned:
@@ -398,7 +398,7 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec
                         for loop in ast_node.atoms(LoopOverCoordinate):
                             has_openmp = has_openmp or any(['#pragma omp' in p for p in loop.prefix_lines])
                             has_nontemporal = has_nontemporal or any([a.args[0].field == field and a.args[3] for a in
-                                                                      loop.atoms(vector_memory_access)])
+                                                                      loop.atoms(VectorMemoryAccess)])
                         if has_openmp and has_nontemporal:
                             byte_width = ast_node.instruction_set['cachelineSize']
                     offset = max(max(ast_node.ghost_layers)) * item_size
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index 865beefa9..7b2719fd7 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -8,10 +8,10 @@ from pystencils.assignment import Assignment
 from pystencils.enums import Target, Backend
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.cpujit import make_python_function
-from pystencils.data_types import StructType, TypedSymbol, create_type
+from pystencils.typing import StructType, TypedSymbol, create_type, add_types
 from pystencils.field import Field, FieldType
 from pystencils.transformations import (
-    add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain,
+    filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering, make_loop_over_domain,
     move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses,
     resolve_field_accesses, split_inner_loop)
 
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index c0511aa16..a161d5879 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -7,8 +7,8 @@ from sympy.logic.boolalg import BooleanFunction
 
 import pystencils.astnodes as ast
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
-from pystencils.data_types import (
-    PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access)
+from pystencils.typing import (
+    PointerType, TypedSymbol, VectorType, CastFunc, collate_types, get_type_of_expression, VectorMemoryAccess)
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 from pystencils.field import Field
 from pystencils.integer_functions import modulo_ceil, modulo_floor
@@ -180,8 +180,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
                 nontemporal = False
                 if hasattr(indexed, 'field'):
                     nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
-                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
-                                                              stride if strided else 1)
+                substitutions[indexed] = VectorMemoryAccess(indexed, vec_type, use_aligned_access, nontemporal, True,
+                                                            stride if strided else 1)
                 if nontemporal:
                     # insert NontemporalFence after the outermost loop
                     parent = loop_node.parent
@@ -197,12 +197,12 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
         loop_node.step = vector_width
         loop_node.subs(substitutions)
         vector_int_width = ast_node.instruction_set['intwidth']
-        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
-            + cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
-                        VectorType(loop_counter_symbol.dtype, vector_int_width))
+        vector_loop_counter = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
+                              + CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
+                                         VectorType(loop_counter_symbol.dtype, vector_int_width))
 
         fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
-                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
+                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, VectorMemoryAccess))
 
         mask_conditionals(loop_node)
 
@@ -232,8 +232,8 @@ def mask_conditionals(loop_body):
                 node.condition_expr = vec_any(node.condition_expr)
         elif isinstance(node, ast.SympyAssignment):
             if mask is not True:
-                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
-                     for ma in node.atoms(vector_memory_access)}
+                s = {ma: VectorMemoryAccess(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
+                     for ma in node.atoms(VectorMemoryAccess)}
                 node.subs(s)
         else:
             for arg in node.args:
@@ -248,13 +248,13 @@ def insert_vector_casts(ast_node, default_float_type='double'):
     handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)
 
     def visit_expr(expr, default_type='double'):
-        if isinstance(expr, vector_memory_access):
-            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
-        elif isinstance(expr, cast_func):
+        if isinstance(expr, VectorMemoryAccess):
+            return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
+        elif isinstance(expr, CastFunc):
             return expr
         elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
             new_arg = visit_expr(expr.args[0], default_type)
-            base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \
+            base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \
                 else get_type_of_expression(expr.args[0])
             pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)),
                               (new_arg, True))
@@ -263,7 +263,7 @@ def insert_vector_casts(ast_node, default_float_type='double'):
             if expr.func is sp.Mul and expr.args[0] == -1:
                 # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                 dtype = int
-                for arg in expr.atoms(vector_memory_access):
+                for arg in expr.atoms(VectorMemoryAccess):
                     if arg.dtype.base_type.is_float():
                         dtype = arg.dtype.base_type.numpy_dtype.type
                 for arg in expr.atoms(TypedSymbol):
@@ -280,7 +280,7 @@ def insert_vector_casts(ast_node, default_float_type='double'):
             else:
                 target_type = collate_types(arg_types)
                 casted_args = [
-                    cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a
+                    CastFunc(a, target_type) if t != target_type and not isinstance(a, VectorMemoryAccess) else a
                     for a, t in zip(new_args, arg_types)]
                 return expr.func(*casted_args)
         elif expr.func is sp.Pow:
@@ -299,10 +299,10 @@ def insert_vector_casts(ast_node, default_float_type='double'):
             if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
                 condition_target_type = VectorType(condition_target_type, width=result_target_type.width)
 
-            casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
+            casted_results = [CastFunc(a, result_target_type) if t != result_target_type else a
                               for a, t in zip(new_results, types_of_results)]
 
-            casted_conditions = [cast_func(a, condition_target_type)
+            casted_conditions = [CastFunc(a, condition_target_type)
                                  if t != condition_target_type and a is not True else a
                                  for a, t in zip(new_conditions, types_of_conditions)]
 
@@ -326,7 +326,7 @@ def insert_vector_casts(ast_node, default_float_type='double'):
                         new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
                         substitution_dict[assignment.lhs] = new_lhs
                         assignment.lhs = new_lhs
-                elif isinstance(assignment.lhs, vector_memory_access):
+                elif isinstance(assignment.lhs, VectorMemoryAccess):
                     assignment.lhs = visit_expr(assignment.lhs, default_type)
             elif isinstance(arg, ast.Conditional):
                 arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict,
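
Illustration (not part of the patch): vector_memory_access is renamed to VectorMemoryAccess but, as the hunks above show, keeps the six-argument layout documented in the removed data_types.py (expression, vector type, aligned, nontemporal, mask, stride). A sketch of constructing one by hand, assuming the constructor is unchanged; sp.Symbol('src') stands in for a resolved field access:

    import sympy as sp
    from pystencils.typing import VectorMemoryAccess, VectorType, create_type

    vec_type = VectorType(create_type('float64'), width=4)
    acc = VectorMemoryAccess(sp.Symbol('src'), vec_type, True, False, True, 1)
    expr, dtype, aligned, nontemporal, mask, stride = acc.args
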
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
deleted file mode 100644
index 9bf8375bf..000000000
--- a/pystencils/data_types.py
+++ /dev/null
@@ -1,927 +0,0 @@
-import ctypes
-from collections import defaultdict
-from functools import partial
-from typing import Tuple
-
-import numpy as np
-import sympy as sp
-import sympy.codegen.ast
-from sympy.core.cache import cacheit
-from sympy.logic.boolalg import Boolean, BooleanFunction
-
-import pystencils
-from pystencils.cache import memorycache, memorycache_if_hashable
-from pystencils.utils import all_equal
-
-try:
-    import llvmlite.ir as ir
-except ImportError as e:
-    ir = None
-    _ir_importerror = e
-
-
-def typed_symbols(names, dtype, *args):
-    symbols = sp.symbols(names, *args)
-    if isinstance(symbols, Tuple):
-        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
-    else:
-        return TypedSymbol(str(symbols), dtype)
-
-
-def type_all_numbers(expr, dtype):
-    # TODO: move to pystnecils_walberla
-    substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
-    return expr.subs(substitutions)
-
-
-def matrix_symbols(names, dtype, rows, cols):
-    # TODO: check if needed. (lbmpy, walberla)
-    if isinstance(names, str):
-        names = names.replace(' ', '').split(',')
-
-    matrices = []
-    for n in names:
-        symbols = typed_symbols(f"{n}:{rows * cols}", dtype)
-        matrices.append(sp.Matrix(rows, cols, lambda i, j: symbols[i * cols + j]))
-
-    return tuple(matrices)
-
-
-def assumptions_from_dtype(dtype):
-    # TODO: type hints and if dtype is correct type form Numpy
-    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
-
-    Args:
-        dtype (BasicType, np.dtype): a Numpy data type
-    Returns:
-        A dict of SymPy assumptions
-    """
-    if hasattr(dtype, 'numpy_dtype'):
-        dtype = dtype.numpy_dtype
-
-    assumptions = dict()
-
-    try:
-        if np.issubdtype(dtype, np.integer):
-            assumptions.update({'integer': True})
-
-        if np.issubdtype(dtype, np.unsignedinteger):
-            assumptions.update({'negative': False})
-
-        if np.issubdtype(dtype, np.integer) or \
-                np.issubdtype(dtype, np.floating):
-            assumptions.update({'real': True})
-    except Exception:
-        pass
-
-    return assumptions
-
-
-# noinspection PyPep8Naming
-class address_of(sp.Function):
-    # DONE: ask Martin
-    # TODO: documentation
-    # TODO: move function to `functions.py`
-    # this is '&' in C
-    is_Atom = True
-
-    def __new__(cls, arg):
-        obj = sp.Function.__new__(cls, arg)
-        return obj
-
-    @property
-    def canonical(self):
-        if hasattr(self.args[0], 'canonical'):
-            return self.args[0].canonical
-        else:
-            raise NotImplementedError()
-
-    @property
-    def is_commutative(self):
-        return self.args[0].is_commutative
-
-    @property
-    def dtype(self):
-        if hasattr(self.args[0], 'dtype'):
-            return PointerType(self.args[0].dtype, restrict=True)
-        else:
-            return PointerType('void', restrict=True)
-
-
-# noinspection PyPep8Naming
-class cast_func(sp.Function):
-    # TODO: documentation
-    # TODO: move function to `functions.py`
-    is_Atom = True
-
-    def __new__(cls, *args, **kwargs):
-        if len(args) != 2:
-            pass
-        expr, dtype, *other_args = args
-        if not isinstance(dtype, Type):
-            dtype = create_type(dtype)
-        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
-        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
-        # to problems when for example comparing cast_func's for equality
-        #
-        # lhs = bitwise_and(a, cast_func(1, 'int'))
-        # rhs = cast_func(0, 'int')
-        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
-        # -> thus a separate class boolean_cast_func is introduced
-        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)):
-            cls = boolean_cast_func
-
-        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
-
-    @property
-    def canonical(self):
-        if hasattr(self.args[0], 'canonical'):
-            return self.args[0].canonical
-        else:
-            raise NotImplementedError()
-
-    @property
-    def is_commutative(self):
-        return self.args[0].is_commutative
-
-    def _eval_evalf(self, *args, **kwargs):
-        return self.args[0].evalf()
-
-    @property
-    def dtype(self):
-        return self.args[1]
-
-    @property
-    def is_integer(self):
-        """
-        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
-
-        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
-        """
-        if hasattr(self.dtype, 'numpy_dtype'):
-            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
-        else:
-            return super().is_integer
-
-    @property
-    def is_negative(self):
-        """
-        See :func:`.TypedSymbol.is_integer`
-        """
-        if hasattr(self.dtype, 'numpy_dtype'):
-            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
-                return False
-
-        return super().is_negative
-
-    @property
-    def is_nonnegative(self):
-        """
-        See :func:`.TypedSymbol.is_integer`
-        """
-        if self.is_negative is False:
-            return True
-        else:
-            return super().is_nonnegative
-
-    @property
-    def is_real(self):
-        """
-        See :func:`.TypedSymbol.is_integer`
-        """
-        if hasattr(self.dtype, 'numpy_dtype'):
-            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
-                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
-                super().is_real
-        else:
-            return super().is_real
-
-
-# noinspection PyPep8Naming
-class boolean_cast_func(cast_func, Boolean):
-    # TODO: documentation
-    # TODO: move function to `functions.py`
-    pass
-
-
-# noinspection PyPep8Naming
-class vector_memory_access(cast_func):
-    # TODO: documentation
-    # TODO: move function to `functions.py`
-    # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
-    nargs = (6,)
-
-
-# noinspection PyPep8Naming
-class reinterpret_cast_func(cast_func):
-    # TODO: documentation
-    # TODO: move function to `functions.py`
-    pass
-
-
-# noinspection PyPep8Naming
-class pointer_arithmetic_func(sp.Function, Boolean):
-    # TODO: documentation
-    # TODO: move function to `functions.py`
-    @property
-    def canonical(self):
-        if hasattr(self.args[0], 'canonical'):
-            return self.args[0].canonical
-        else:
-            raise NotImplementedError()
-
-
-class TypedSymbol(sp.Symbol):
-    def __new__(cls, *args, **kwds):
-        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
-        return obj
-
-    def __new_stage2__(cls, name, dtype, **kwargs):
-        assumptions = assumptions_from_dtype(dtype)
-        assumptions.update(kwargs)
-        obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
-        try:
-            obj._dtype = create_type(dtype)
-        except (TypeError, ValueError):
-            # on error keep the string
-            obj._dtype = dtype
-        return obj
-
-    __xnew__ = staticmethod(__new_stage2__)
-    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
-
-    @property
-    def dtype(self):
-        return self._dtype
-
-    def _hashable_content(self):
-        return super()._hashable_content(), hash(self._dtype)
-
-    def __getnewargs__(self):
-        return self.name, self.dtype
-
-    def __getnewargs_ex__(self):
-        return (self.name, self.dtype), self.assumptions0
-
-    @property
-    def canonical(self):
-        return self
-
-    @property
-    def reversed(self):
-        return self
-
-    @property
-    def headers(self):
-        headers = []
-        try:
-            if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
-                headers.append('"cuda_complex.hpp"')
-        except Exception:
-            pass
-        try:
-            if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
-                headers.append('"cuda_complex.hpp"')
-        except Exception:
-            pass
-
-        return headers
-
-
-def create_type(specification):
-    # TODO: HERE
-    # TODO: type hint -> np.type
-    """Creates a subclass of Type according to a string or an object of subclass Type.
-
-    Args:
-        specification: Type object, or a string
-
-    Returns:
-        Type object, or a new Type object parsed from the string
-    """
-    if isinstance(specification, Type):
-        return specification
-    else:
-        numpy_dtype = np.dtype(specification)
-        if numpy_dtype.fields is None:
-            return BasicType(numpy_dtype, const=False)
-        else:
-            return StructType(numpy_dtype, const=False)
-
-
-@memorycache(maxsize=64)
-def create_composite_type_from_string(specification):
-    # TODO: can be removed after llvm removla and fix of kernelparameters
-    """Creates a new Type object from a c-like string specification.
-
-    Args:
-        specification: Specification string
-
-    Returns:
-        Type object
-    """
-    specification = specification.lower().split()
-    parts = []
-    current = []
-    for s in specification:
-        if s == '*':
-            parts.append(current)
-            current = [s]
-        else:
-            current.append(s)
-    if len(current) > 0:
-        parts.append(current)
-        # Parse native part
-    base_part = parts.pop(0)
-    const = False
-    if 'const' in base_part:
-        const = True
-        base_part.remove('const')
-    assert len(base_part) == 1
-    if base_part[0][-1] == "*":
-        base_part[0] = base_part[0][:-1]
-        parts.append('*')
-    current_type = BasicType(np.dtype(base_part[0]), const)
-    # Parse pointer parts
-    for part in parts:
-        restrict = False
-        const = False
-        if 'restrict' in part:
-            restrict = True
-            part.remove('restrict')
-        if 'const' in part:
-            const = True
-            part.remove("const")
-        assert len(part) == 1 and part[0] == '*'
-        current_type = PointerType(current_type, const, restrict)
-    return current_type
-
-
-def get_base_type(data_type):
-    # TODO: WTF is this?? DOCS!!!
-    # TODO: Can be removed after removal of kerncraft and fix in FieldPointer Symbol
-    while data_type.base_type is not None:
-        data_type = data_type.base_type
-    return data_type
-
-
-def to_ctypes(data_type):
-    # TODO: can be removed with llvm
-    """
-    Transforms a given Type into ctypes
-    :param data_type: Subclass of Type
-    :return: ctypes type object
-    """
-    if isinstance(data_type, PointerType):
-        return ctypes.POINTER(to_ctypes(data_type.base_type))
-    elif isinstance(data_type, StructType):
-        return ctypes.POINTER(ctypes.c_uint8)
-    else:
-        return to_ctypes.map[data_type.numpy_dtype]
-
-# TODO: can be removed with llvm
-to_ctypes.map = {
-    np.dtype(np.int8): ctypes.c_int8,
-    np.dtype(np.int16): ctypes.c_int16,
-    np.dtype(np.int32): ctypes.c_int32,
-    np.dtype(np.int64): ctypes.c_int64,
-
-    np.dtype(np.uint8): ctypes.c_uint8,
-    np.dtype(np.uint16): ctypes.c_uint16,
-    np.dtype(np.uint32): ctypes.c_uint32,
-    np.dtype(np.uint64): ctypes.c_uint64,
-
-    np.dtype(np.float32): ctypes.c_float,
-    np.dtype(np.float64): ctypes.c_double,
-}
-
-
-def ctypes_from_llvm(data_type):
-    # TODO can be removed with LLVM
-    if not ir:
-        raise _ir_importerror
-    if isinstance(data_type, ir.PointerType):
-        ctype = ctypes_from_llvm(data_type.pointee)
-        if ctype is None:
-            return ctypes.c_void_p
-        else:
-            return ctypes.POINTER(ctype)
-    elif isinstance(data_type, ir.IntType):
-        if data_type.width == 8:
-            return ctypes.c_int8
-        elif data_type.width == 16:
-            return ctypes.c_int16
-        elif data_type.width == 32:
-            return ctypes.c_int32
-        elif data_type.width == 64:
-            return ctypes.c_int64
-        else:
-            raise ValueError("Int width %d is not supported" % data_type.width)
-    elif isinstance(data_type, ir.FloatType):
-        return ctypes.c_float
-    elif isinstance(data_type, ir.DoubleType):
-        return ctypes.c_double
-    elif isinstance(data_type, ir.VoidType):
-        return None  # Void type is not supported by ctypes
-    else:
-        raise NotImplementedError(f'Data type {type(data_type)} of {data_type} is not supported yet')
-
-
-def to_llvm_type(data_type, nvvm_target=False):
-    # TODO: can be removed with LLVM
-    """
-    Transforms a given type into ctypes
-    :param data_type: Subclass of Type
-    :return: llvmlite type object
-    """
-    if not ir:
-        raise _ir_importerror
-    if isinstance(data_type, PointerType):
-        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
-    else:
-        return to_llvm_type.map[data_type.numpy_dtype]
-
-
-# TODO: can be removed with LLVM
-if ir:
-    to_llvm_type.map = {
-        np.dtype(np.int8): ir.IntType(8),
-        np.dtype(np.int16): ir.IntType(16),
-        np.dtype(np.int32): ir.IntType(32),
-        np.dtype(np.int64): ir.IntType(64),
-
-        np.dtype(np.uint8): ir.IntType(8),
-        np.dtype(np.uint16): ir.IntType(16),
-        np.dtype(np.uint32): ir.IntType(32),
-        np.dtype(np.uint64): ir.IntType(64),
-
-        np.dtype(np.float32): ir.FloatType(),
-        np.dtype(np.float64): ir.DoubleType(),
-    }
-
-
-def peel_off_type(dtype, type_to_peel_off):
-    # TODO: WTF is this??? DOCS!!!
-    # TODO: used only once.... can be a lambda there
-    while type(dtype) is type_to_peel_off:
-        dtype = dtype.base_type
-    return dtype
-
-
-############################# This is basically our type system ########################################################
-def collate_types(types,
-                  forbid_collation_to_complex=False,  # TODO: type system shouldn't need this!!!
-                  forbid_collation_to_float=False,    # TODO: type system shouldn't need this!!!
-                  default_float_type='float64',       # TODO: AST leaves should be typed. Expressions should be able to find out correct type
-                  default_int_type='int64'):          # TODO: AST leaves should be typed. Expressions should be able to find out correct type
-    """
-    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
-    Uses the collation rules from numpy.
-    """
-    if forbid_collation_to_complex:
-        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
-        if not types:
-            return create_type(default_float_type)
-
-    if forbid_collation_to_float:
-        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
-        if not types:
-            return create_type(default_int_type)
-
-    # Pointer arithmetic case i.e. pointer + integer is allowed
-    if any(type(t) is PointerType for t in types):
-        pointer_type = None
-        for t in types:
-            if type(t) is PointerType:
-                if pointer_type is not None:
-                    raise ValueError("Cannot collate the combination of two pointer types")
-                pointer_type = t
-            elif type(t) is BasicType:
-                if not (t.is_int() or t.is_uint()):
-                    raise ValueError("Invalid pointer arithmetic")
-            else:
-                raise ValueError("Invalid pointer arithmetic")
-        return pointer_type
-
-    # peel of vector types, if at least one vector type occurred the result will also be the vector type
-    vector_type = [t for t in types if type(t) is VectorType]
-    if not all_equal(t.width for t in vector_type):
-        raise ValueError("Collation failed because of vector types with different width")
-    types = [peel_off_type(t, VectorType) for t in types]
-
-    # now we should have a list of basic types - struct types are not yet supported
-    assert all(type(t) is BasicType for t in types)
-
-    if any(t.is_float() for t in types):
-        types = tuple(t for t in types if t.is_float())
-    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
-    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
-    result = BasicType(result_numpy_type)
-    if vector_type:
-        result = VectorType(result, vector_type[0].width)
-    return result
-
-
-@memorycache_if_hashable(maxsize=2048)
-def get_type_of_expression(expr,
-                           default_float_type='double', # TODO: we shouldn't need to have default. AST leaves should have a type
-                           default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type
-                           symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type
-    from pystencils.astnodes import ResolvedFieldAccess
-    from pystencils.cpu.vectorization import vec_all, vec_any
-
-    if default_float_type == 'float':
-        default_float_type = 'float32'
-
-    if not symbol_type_dict:
-        symbol_type_dict = defaultdict(lambda: create_type('double'))
-
-    get_type = partial(get_type_of_expression,
-                       default_float_type=default_float_type,
-                       default_int_type=default_int_type,
-                       symbol_type_dict=symbol_type_dict)
-
-    expr = sp.sympify(expr)
-    if isinstance(expr, sp.Integer):
-        return create_type(default_int_type)
-    elif expr.is_real is False:
-        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
-    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
-        return create_type(default_float_type)
-    elif isinstance(expr, ResolvedFieldAccess):
-        return expr.field.dtype
-    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
-        return expr.field.dtype
-    elif isinstance(expr, TypedSymbol):
-        return expr.dtype
-    elif isinstance(expr, sp.Symbol):
-        if symbol_type_dict:
-            return symbol_type_dict[expr.name]
-        else:
-            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
-    elif isinstance(expr, cast_func):
-        return expr.args[1]
-    elif isinstance(expr, (vec_any, vec_all)):
-        return create_type("bool")
-    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
-        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
-        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
-        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
-            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
-        return collated_result_type
-    elif isinstance(expr, sp.Indexed):
-        typed_symbol = expr.base.label
-        return typed_symbol.dtype.base_type
-    elif isinstance(expr, (Boolean, BooleanFunction)):
-        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
-        result = create_type("bool")
-        vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
-        if vec_args:
-            result = VectorType(result, width=vec_args[0].width)
-        return result
-    elif isinstance(expr, sp.Pow):
-        base_type = get_type(expr.args[0])
-        if expr.exp.is_integer:
-            return base_type
-        else:
-            return collate_types([create_type(default_float_type), base_type])
-    elif isinstance(expr, (sp.Sum, sp.Product)):
-        return get_type(expr.args[0])
-    elif isinstance(expr, sp.Expr):
-        expr: sp.Expr
-        if expr.args:
-            types = tuple(get_type(a) for a in expr.args)
-            # collate_types checks numpy_dtype in the special cases
-            if any(not hasattr(t, 'numpy_dtype') for t in types):
-                forbid_collation_to_complex = False
-                forbid_collation_to_float = False
-            else:
-                forbid_collation_to_complex = expr.is_real is True
-                forbid_collation_to_float = expr.is_integer is True
-            return collate_types(
-                types,
-                forbid_collation_to_complex=forbid_collation_to_complex,
-                forbid_collation_to_float=forbid_collation_to_float,
-                default_float_type=default_float_type,
-                default_int_type=default_int_type)
-        else:
-            if expr.is_integer:
-                return create_type(default_int_type)
-            else:
-                return create_type(default_float_type)
-
-    raise NotImplementedError("Could not determine type for", expr, type(expr))
-############################# End This is basically our type system ##################################################
-
-
-sympy_version = sp.__version__.split('.')
-if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
-    # __setstate__ would bypass the contructor, so we remove it
-    sp.Number.__getstate__ = sp.Basic.__getstate__
-    del sp.Basic.__getstate__
-
-    class FunctorWithStoredKwargs:
-        def __init__(self, func, **kwargs):
-            self.func = func
-            self.kwargs = kwargs
-
-        def __call__(self, *args):
-            return self.func(*args, **self.kwargs)
-
-    # __reduce_ex__ would strip kwargs, so we override it
-    def basic_reduce_ex(self, protocol):
-        if hasattr(self, '__getnewargs_ex__'):
-            args, kwargs = self.__getnewargs_ex__()
-        else:
-            args, kwargs = self.__getnewargs__(), {}
-        if hasattr(self, '__getstate__'):
-            state = self.__getstate__()
-        else:
-            state = None
-        return FunctorWithStoredKwargs(type(self), **kwargs), args, state
-    sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
-    sp.Basic.__reduce_ex__ = basic_reduce_ex
-
-
-class Type(sp.Atom):
-    # TODO: why is our type system dependent on sympy???
-    # DONE: ask Martin
-    # TODO: inherits from sp.Atom because of cast function (and maybe others)
-    def __new__(cls, *args, **kwargs):
-        return sp.Basic.__new__(cls)
-
-    def _sympystr(self, *args, **kwargs):
-        return str(self)
-
-
-class BasicType(Type):
-    # TODO: check if Type inheritance is needed
-    # TODO: should be a sensible interface to np.dtype
-    # TODO: read numpy docs (Jan)
-    @staticmethod
-    def numpy_name_to_c(name):
-        # TODO: this should be a free function
-        # TODO: also check if numpy has this functionality
-        # TODO: docs!!!
-        # TODO: is this C?
-        if name == 'float64':
-            return 'double'
-        elif name == 'float32':
-            return 'float'
-        elif name == 'complex64':
-            return 'ComplexFloat'
-        elif name == 'complex128':
-            return 'ComplexDouble'
-        elif name.startswith('int'):
-            width = int(name[len("int"):])
-            return f"int{width}_t"
-        elif name.startswith('uint'):
-            width = int(name[len("uint"):])
-            return f"uint{width}_t"
-        elif name == 'bool':
-            return 'bool'
-        else:
-            raise NotImplementedError(f"Can map numpy to C name for {name}")
-
-    def __init__(self, dtype, const=False):
-        # TODO: type hints
-        self.const = const
-        if isinstance(dtype, Type):
-            self._dtype = dtype.numpy_dtype  # TODO: wtf?
-        else:
-            self._dtype = np.dtype(dtype)
-        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
-        assert self._dtype.hasobject is False
-        assert self._dtype.subdtype is None
-
-    def __getnewargs__(self):
-        return self.numpy_dtype, self.const
-
-    def __getnewargs_ex__(self):
-        return (self.numpy_dtype, self.const), {}
-
-    @property
-    def base_type(self): # TODO: what is base_type?
-        return None
-
-    @property
-    def numpy_dtype(self):
-        return self._dtype
-
-    @property
-    def sympy_dtype(self):
-        return getattr(sympy.codegen.ast, str(self.numpy_dtype))
-
-    @property
-    def item_size(self):  # TODO: what is this?
-        return 1
-
-    def is_int(self):
-        return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint']
-
-    def is_float(self):
-        return self.numpy_dtype in np.sctypes['float']
-
-    def is_uint(self):
-        return self.numpy_dtype in np.sctypes['uint']
-
-    def is_complex(self):
-        return self.numpy_dtype in np.sctypes['complex']
-
-    def is_other(self):
-        return self.numpy_dtype in np.sctypes['others']
-
-    @property
-    def base_name(self):  # TODO: name of the function is highly confusing
-        return BasicType.numpy_name_to_c(str(self._dtype))
-
-    def __str__(self):
-        result = BasicType.numpy_name_to_c(str(self._dtype))
-        if self.const:
-            result += " const"
-        return result
-
-    def __repr__(self):
-        return str(self)
-
-    def __eq__(self, other):
-        if not isinstance(other, BasicType):
-            return False
-        else:
-            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
-
-    def __hash__(self):
-        return hash(str(self))
-
-
-class VectorType(Type):
-    # TODO: check with rest
-    instruction_set = None
-
-    def __init__(self, base_type, width=4):
-        self._base_type = base_type
-        self.width = width
-
-    @property
-    def base_type(self):
-        return self._base_type
-
-    @property
-    def item_size(self):
-        return self.width * self.base_type.item_size
-
-    def __eq__(self, other):
-        if not isinstance(other, VectorType):
-            return False
-        else:
-            return (self.base_type, self.width) == (other.base_type, other.width)
-
-    def __str__(self):
-        if self.instruction_set is None:
-            return f"{self.base_type}[{self.width}]"
-        else:
-            if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
-                return self.instruction_set['int']
-            elif self.base_type == create_type("float64"):
-                return self.instruction_set['double']
-            elif self.base_type == create_type("float32"):
-                return self.instruction_set['float']
-            elif self.base_type == create_type("bool"):
-                return self.instruction_set['bool']
-            else:
-                raise NotImplementedError()
-
-    def __hash__(self):
-        return hash((self.base_type, self.width))
-
-    def __getnewargs__(self):
-        return self._base_type, self.width
-
-    def __getnewargs_ex__(self):
-        return (self._base_type, self.width), {}
-
-
-class PointerType(Type):
-    # TODO: rename to FieldType
-    def __init__(self, base_type, const=False, restrict=True):
-        self._base_type = base_type
-        self.const = const
-        self.restrict = restrict
-
-    def __getnewargs__(self):
-        return self.base_type, self.const, self.restrict
-
-    def __getnewargs_ex__(self):
-        return (self.base_type, self.const, self.restrict), {}
-
-    @property
-    def alias(self):
-        return not self.restrict
-
-    @property
-    def base_type(self):
-        return self._base_type
-
-    @property
-    def item_size(self):
-        return self.base_type.item_size
-
-    def __eq__(self, other):
-        if not isinstance(other, PointerType):
-            return False
-        else:
-            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
-
-    def __str__(self):
-        components = [str(self.base_type), '*']
-        if self.restrict:
-            components.append('RESTRICT')
-        if self.const:
-            components.append("const")
-        return " ".join(components)
-
-    def __repr__(self):
-        return str(self)
-
-    def __hash__(self):
-        return hash((self._base_type, self.const, self.restrict))
-
-
-class StructType:
-    # TODO: Docs. This is a struct. A list of types (with C offsets)
-    def __init__(self, numpy_type, const=False):
-        self.const = const
-        self._dtype = np.dtype(numpy_type)
-
-    def __getnewargs__(self):
-        return self.numpy_dtype, self.const
-
-    def __getnewargs_ex__(self):
-        return (self.numpy_dtype, self.const), {}
-
-    @property
-    def base_type(self):
-        return None
-
-    @property
-    def numpy_dtype(self):
-        return self._dtype
-
-    @property
-    def item_size(self):
-        return self.numpy_dtype.itemsize
-
-    def get_element_offset(self, element_name):
-        return self.numpy_dtype.fields[element_name][1]
-
-    def get_element_type(self, element_name):
-        np_element_type = self.numpy_dtype.fields[element_name][0]
-        return BasicType(np_element_type, self.const)
-
-    def has_element(self, element_name):
-        return element_name in self.numpy_dtype.fields
-
-    def __eq__(self, other):
-        if not isinstance(other, StructType):
-            return False
-        else:
-            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
-
-    def __str__(self):
-        # structs are handled byte-wise
-        result = "uint8_t"
-        if self.const:
-            result += " const"
-        return result
-
-    def __repr__(self):
-        return str(self)
-
-    def __hash__(self):
-        return hash((self.numpy_dtype, self.const))
-
-
-class TypedImaginaryUnit(TypedSymbol):
-    # TODO: why is this an extra class???
-    # TODO: remove?
-    def __new__(cls, *args, **kwds):
-        obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
-        return obj
-
-    def __new_stage2__(cls, dtype):
-        obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
-                                                      "_i",
-                                                      dtype,
-                                                      imaginary=True)
-        return obj
-
-    headers = ['"cuda_complex.hpp"']
-
-    __xnew__ = staticmethod(__new_stage2__)
-    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
-
-    def __getnewargs__(self):
-        return (self.dtype,)
-
-    def __getnewargs_ex__(self):
-        return (self.dtype,), {}
diff --git a/pystencils/field.py b/pystencils/field.py
index dcb33ca99..146a1cacb 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -13,8 +13,8 @@ from sympy.core.cache import cacheit
 
 import pystencils
 from pystencils.alignedarray import aligned_empty
-from pystencils.data_types import StructType, TypedSymbol, create_type
-from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
+from pystencils.typing import StructType, TypedSymbol, create_type
+from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol
 from pystencils.stencil import (
     direction_string_to_offset, inverse_direction, offset_to_direction_string)
 from pystencils.sympyextensions import is_integer_sequence
@@ -137,6 +137,7 @@ def fields(description=None, index_dimensions=0, layout=None, field_type=FieldTy
         return result
 
 
+# TODO: why do we need this class? Why abstract?
 class AbstractField:
     class AbstractAccess:
         pass
@@ -472,27 +473,6 @@ class Field(AbstractField):
         assert FieldType.is_custom(self)
         return Field.Access(self, offset, index, is_absolute_access=True)
 
-    def interpolated_access(self,
-                            offset: Tuple,
-                            interpolation_mode='linear',
-                            address_mode='BORDER',
-                            allow_textures=True):
-        """Provides access to field values at non-integer positions
-
-        ``interpolated_access`` is similar to :func:`Field.absolute_access` except that
-        it allows non-integer offsets and automatic handling of out-of-bound accesses.
-
-        :param offset:              Tuple of spatial coordinates (can be floats)
-        :param interpolation_mode:  One of :class:`pystencils.interpolation_astnodes.InterpolationMode`
-        :param address_mode:        How boundaries are handled can be 'border', 'wrap', 'mirror', 'clamp'
-        :param allow_textures:      Allow implementation by texture accesses on GPUs
-        """
-        from pystencils.interpolation_astnodes import Interpolator
-        return Interpolator(self,
-                            interpolation_mode,
-                            address_mode,
-                            allow_textures=allow_textures).at(offset)
-
     def staggered_access(self, offset, index=None):
         """If this field is a staggered field, it can be accessed using half-integer offsets.
         For example, an offset of ``(0, sp.Rational(1,2))`` or ``"E"`` corresponds to the staggered point to the east
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index 67adac657..a13297e0d 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -2,7 +2,7 @@ import numpy as np
 
 from pystencils.backends.cbackend import get_headers
 from pystencils.backends.cuda_backend import generate_cuda
-from pystencils.data_types import StructType
+from pystencils.typing import StructType
 from pystencils.field import FieldType
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
 from pystencils.kernel_wrapper import KernelWrapper
diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py
index ae5db1b98..6f30b0a1c 100644
--- a/pystencils/gpucuda/indexing.py
+++ b/pystencils/gpucuda/indexing.py
@@ -5,7 +5,7 @@ import sympy as sp
 from sympy.core.cache import cacheit
 
 from pystencils.astnodes import Block, Conditional
-from pystencils.data_types import TypedSymbol, create_type
+from pystencils.typing import TypedSymbol, create_type
 from pystencils.integer_functions import div_ceil, div_floor
 from pystencils.slicing import normalize_slice
 from pystencils.sympyextensions import is_integer_sequence, prod
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index 39808eab0..96399ae1c 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -1,13 +1,13 @@
 import numpy as np
 
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
-from pystencils.data_types import StructType, TypedSymbol
+from pystencils.typing import StructType, TypedSymbol, add_types
 from pystencils.field import Field, FieldType
 from pystencils.enums import Target, Backend
 from pystencils.gpucuda.cudajit import make_python_function
 from pystencils.gpucuda.indexing import BlockIndexing
 from pystencils.transformations import (
-    add_types, get_base_buffer_index, get_common_shape, parse_base_pointer_info,
+    get_base_buffer_index, get_common_shape, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
 
 
diff --git a/pystencils/integer_functions.py b/pystencils/integer_functions.py
index efdaaaecf..1975a877e 100644
--- a/pystencils/integer_functions.py
+++ b/pystencils/integer_functions.py
@@ -1,7 +1,8 @@
+# TODO: move these functions to a dedicated `functions` module
 import numpy as np
 import sympy as sp
 
-from pystencils.data_types import cast_func, collate_types, create_type, get_type_of_expression
+from pystencils.typing import CastFunc, collate_types, create_type, get_type_of_expression
 from pystencils.sympyextensions import is_integer_sequence
 
 
@@ -12,9 +13,9 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
         args = []
         for a in (arg1, arg2):
             if isinstance(a, sp.Number) or isinstance(a, int):
-                args.append(cast_func(a, create_type("int")))
+                args.append(CastFunc(a, create_type("int")))
             elif isinstance(a, np.generic):
-                args.append(cast_func(a, a.dtype))
+                args.append(CastFunc(a, a.dtype))
             else:
                 args.append(a)
 
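Illustrative sketch (not part of the patch): with the renamed `CastFunc`, integer literals handed to the two-argument integer functions are wrapped with an explicit int type. The symbol and its width below are made up, and this assumes the wrapping shown above happens on construction.

    # Integer literals get an explicit type via CastFunc (sketch only).
    import sympy as sp
    from pystencils import TypedSymbol
    from pystencils.integer_functions import div_ceil
    from pystencils.typing import CastFunc

    i = TypedSymbol("i", "int64")                             # hypothetical loop counter
    expr = div_ceil(i, 4)
    assert any(isinstance(a, CastFunc) for a in expr.args)    # the literal 4 now carries a type
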
diff --git a/pystencils/integer_set_analysis.py b/pystencils/integer_set_analysis.py
index 82af791ca..2e37c643f 100644
--- a/pystencils/integer_set_analysis.py
+++ b/pystencils/integer_set_analysis.py
@@ -4,7 +4,7 @@ import islpy as isl
 import sympy as sp
 
 import pystencils.astnodes as ast
-from pystencils.transformations import parents_of_type
+from pystencils.typing import parents_of_type
 
 
 def remove_brackets(s):
diff --git a/pystencils/kerncraft_coupling/generate_benchmark.py b/pystencils/kerncraft_coupling/generate_benchmark.py
index 955098d2c..8d8d7d1da 100644
--- a/pystencils/kerncraft_coupling/generate_benchmark.py
+++ b/pystencils/kerncraft_coupling/generate_benchmark.py
@@ -8,7 +8,7 @@ from jinja2 import Environment, PackageLoader, StrictUndefined
 from pystencils.astnodes import PragmaBlock
 from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.cpu.cpujit import get_compiler_config, run_compile_step
-from pystencils.data_types import get_base_type
+from pystencils.typing import get_base_type
 from pystencils.enums import Backend
 from pystencils.include import get_pystencils_include_path
 from pystencils.integer_functions import modulo_ceil
diff --git a/pystencils/kerncraft_coupling/kerncraft_interface.py b/pystencils/kerncraft_coupling/kerncraft_interface.py
index 61867e518..bfb5a2d6a 100644
--- a/pystencils/kerncraft_coupling/kerncraft_interface.py
+++ b/pystencils/kerncraft_coupling/kerncraft_interface.py
@@ -21,7 +21,7 @@ from pystencils.sympyextensions import count_operations_in_ast
 from pystencils.transformations import filtered_tree_iteration
 from pystencils.utils import DotDict
 from pystencils.cpu.kernelcreation import add_openmp
-from pystencils.data_types import get_base_type
+from pystencils.typing.utilities import get_base_type
 from pystencils.sympyextensions import prod
 
 
diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py
new file mode 100644
index 000000000..55f141201
--- /dev/null
+++ b/pystencils/kernel_contrains_check.py
@@ -0,0 +1,150 @@
+from collections import namedtuple, defaultdict
+
+import numpy as np
+
+import pystencils.integer_functions
+import sympy as sp
+from pystencils import astnodes as ast, TypedSymbol
+from pystencils.bit_masks import flag_cond
+from pystencils.field import AbstractField
+from pystencils.transformations import NestedScopes
+from pystencils.typing import CastFunc, create_type, get_type_of_expression, collate_types
+from sympy.logic.boolalg import BooleanFunction
+
+
+class KernelConstraintsCheck:
+    # TODO: Logs
+    # TODO: specification
+    """Checks if the input to create_kernel is valid.
+
+    Test the following conditions:
+
+    - SSA Form for pure symbols:
+        -  Every pure symbol may occur only once as the left-hand side of an assignment
+        -  Every pure symbol that is read may not be written to later
+    - Independence / Parallelization condition:
+        - a field that is written may only be read at exactly the same spatial position
+
+    (Pure symbols are symbols that are not Field.Accesses)
+    """
+    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
+
+    def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
+        self._type_for_symbol = type_for_symbol
+
+        self.scopes = NestedScopes()
+        self._field_writes = defaultdict(set)
+        self.fields_read = set()
+        self.check_independence_condition = check_independence_condition
+        self.check_double_write_condition = check_double_write_condition
+
+    def process_assignment(self, assignment):
+        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
+        new_rhs = self.process_expression(assignment.rhs)
+        new_lhs = self._process_lhs(assignment.lhs)
+        return ast.SympyAssignment(new_lhs, new_rhs)
+
+    def process_expression(self, rhs, type_constants=True):
+
+        self._update_accesses_rhs(rhs)
+        if isinstance(rhs, AbstractField.AbstractAccess):
+            self.fields_read.add(rhs.field)
+            self.fields_read.update(rhs.indirect_addressing_fields)
+            return rhs
+        # TODO remove this
+        #elif isinstance(rhs, ImaginaryUnit):
+        #    return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
+        elif isinstance(rhs, TypedSymbol):
+            return rhs
+        elif isinstance(rhs, sp.Symbol):
+            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
+        elif type_constants and isinstance(rhs, np.generic):
+            return CastFunc(rhs, create_type(rhs.dtype))
+        elif type_constants and isinstance(rhs, sp.Number):
+            return CastFunc(rhs, create_type(self._type_for_symbol['_constant']))
+        # Very important that this clause comes before BooleanFunction
+        elif isinstance(rhs, sp.Equality):
+            if isinstance(rhs.args[1], sp.Number):
+                return sp.Equality(
+                    self.process_expression(rhs.args[0], type_constants),
+                    rhs.args[1])
+            else:
+                return sp.Equality(
+                    self.process_expression(rhs.args[0], type_constants),
+                    self.process_expression(rhs.args[1], type_constants))
+        elif isinstance(rhs, CastFunc):
+            return CastFunc(
+                self.process_expression(rhs.args[0], type_constants=False),
+                rhs.dtype)
+        elif isinstance(rhs, BooleanFunction) or \
+                type(rhs) in pystencils.integer_functions.__dict__.values():
+            new_args = [self.process_expression(a, type_constants) for a in rhs.args]
+            types_of_expressions = [get_type_of_expression(a) for a in new_args]
+            arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True)
+            new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
+                        else CastFunc(a, arg_type)
+                        for a in new_args]
+            return rhs.func(*new_args)
+        elif isinstance(rhs, flag_cond):
+            #   do not process the arguments to the bit shift - they must remain integers
+            processed_args = (self.process_expression(a) for a in rhs.args[2:])
+            return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
+        elif isinstance(rhs, sp.Mul):
+            new_args = [
+                self.process_expression(arg, type_constants)
+                if arg not in (-1, 1) else arg for arg in rhs.args
+            ]
+            return rhs.func(*new_args) if new_args else rhs
+        elif isinstance(rhs, sp.Indexed):
+            return rhs
+        else:
+            if isinstance(rhs, sp.Pow):
+                # don't process exponents -> they should remain integers
+                return sp.Pow(
+                    self.process_expression(rhs.args[0], type_constants),
+                    rhs.args[1])
+            else:
+                new_args = [
+                    self.process_expression(arg, type_constants)
+                    for arg in rhs.args
+                ]
+                return rhs.func(*new_args) if new_args else rhs
+
+    @property
+    def fields_written(self):
+        return set(k.field for k, v in self._field_writes.items() if len(v))
+
+    def _process_lhs(self, lhs):
+        assert isinstance(lhs, sp.Symbol)
+        self._update_accesses_lhs(lhs)
+        if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
+            return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
+        else:
+            return lhs
+
+    def _update_accesses_lhs(self, lhs):
+        if isinstance(lhs, AbstractField.AbstractAccess):
+            fai = self.FieldAndIndex(lhs.field, lhs.index)
+            self._field_writes[fai].add(lhs.offsets)
+            if self.check_double_write_condition and len(self._field_writes[fai]) > 1:
+                raise ValueError(
+                    f"Field {lhs.field.name} is written at two different locations")
+        elif isinstance(lhs, sp.Symbol):
+            if self.scopes.is_defined_locally(lhs):
+                raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
+            if lhs in self.scopes.free_parameters:
+                raise ValueError(f"Symbol {lhs.name} is written, after it has been read")
+            self.scopes.define_symbol(lhs)
+
+    def _update_accesses_rhs(self, rhs):
+        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
+            writes = self._field_writes[self.FieldAndIndex(
+                rhs.field, rhs.index)]
+            for write_offset in writes:
+                assert len(writes) == 1
+                if write_offset != rhs.offsets:
+                    raise ValueError("Violation of loop independence condition. Field "
+                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
+            self.fields_read.add(rhs.field)
+        elif isinstance(rhs, sp.Symbol):
+            self.scopes.access_symbol(rhs)
\ No newline at end of file
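Illustrative sketch (not part of the patch) of driving the relocated checker directly; the fields, the temporary symbol and the `type_for_symbol` dict are made up for this example.

    # Hypothetical usage of KernelConstraintsCheck after this commit.
    import sympy as sp
    from pystencils import Assignment, fields
    from pystencils.kernel_contrains_check import KernelConstraintsCheck

    src, dst = fields("src, dst: double[2D]")
    tmp = sp.Symbol("tmp")

    check = KernelConstraintsCheck(type_for_symbol={"tmp": "double", "_constant": "double"},
                                   check_independence_condition=True)
    typed = [check.process_assignment(Assignment(tmp, 2 * src[0, 0])),    # reads src
             check.process_assignment(Assignment(dst[0, 0], tmp))]        # writes dst
    print(check.fields_read, check.fields_written)                        # src was read, dst was written
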
diff --git a/pystencils/rng.py b/pystencils/rng.py
index 7c4f894f9..c75c3f972 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -2,7 +2,7 @@ import copy
 import numpy as np
 import sympy as sp
 
-from pystencils.data_types import TypedSymbol, cast_func
+from pystencils.typing import TypedSymbol, CastFunc
 from pystencils.astnodes import LoopOverCoordinate
 from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.sympyextensions import fast_subs
@@ -47,11 +47,11 @@ class RNGBase(CustomCodeNode):
     def get_code(self, dialect, vector_instruction_set, print_arg):
         code = "\n"
         for r in self.result_symbols:
-            if vector_instruction_set and not self.args[1].atoms(cast_func):
+            if vector_instruction_set and not self.args[1].atoms(CastFunc):
                 # this vector RNG has become scalar through substitution
                 code += f"{r.dtype} {r.name};\n"
             else:
-                code += f"{vector_instruction_set[r.dtype.base_name] if vector_instruction_set else r.dtype} " + \
+                code += f"{vector_instruction_set[r.dtype.c_name] if vector_instruction_set else r.dtype} " + \
                         f"{r.name};\n"
         args = [print_arg(a) for a in self.args] + ['' + r.name for r in self.result_symbols]
         code += (self._name + "(" + ", ".join(args) + ");\n")
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index f63328d81..1746a8b99 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -10,7 +10,7 @@ from sympy.functions import Abs
 from sympy.core.numbers import Zero
 
 from pystencils.assignment import Assignment
-from pystencils.data_types import cast_func, get_type_of_expression, PointerType, VectorType
+from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
 from pystencils.kernelparameters import FieldPointerSymbol
 
 T = TypeVar('T')
@@ -519,7 +519,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
             visit_children = False
         elif t.is_integer:
             pass
-        elif isinstance(t, cast_func):
+        elif isinstance(t, CastFunc):
             visit_children = False
             visit(t.args[0])
         elif t.func is fast_sqrt:
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 0c6d00658..beb5d287e 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -1,28 +1,21 @@
 import hashlib
 import pickle
 import warnings
-from typing import List, Dict
-from collections import OrderedDict, defaultdict, namedtuple
+from collections import OrderedDict
 from copy import deepcopy
 from types import MappingProxyType
 
-import numpy as np
 import sympy as sp
-from sympy.core.numbers import ImaginaryUnit
-from sympy.logic.boolalg import Boolean, BooleanFunction
 
 import pystencils.astnodes as ast
-import pystencils.integer_functions
 from pystencils.assignment import Assignment
-from pystencils.data_types import (
-    PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
-    get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
+from pystencils.typing import (
+    PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
 from pystencils.field import AbstractField, Field, FieldType
-from pystencils.kernelparameters import FieldPointerSymbol
+from pystencils.typing import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.slicing import normalize_slice
 from pystencils.integer_functions import int_div
-from pystencils.bit_masks import flag_cond
 
 
 class NestedScopes:
@@ -379,7 +372,10 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
     return base_buffer_index * buffer_index_size
 
 
-def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
+def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
+
+    if read_only_field_names is None:
+        read_only_field_names = set()
 
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
         if isinstance(expr, AbstractField.AbstractAccess):
@@ -522,7 +518,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None,
                 if isinstance(accessed_field_name, sp.Symbol):
                     accessed_field_name = accessed_field_name.name
                 new_type = field_access.field.dtype.get_element_type(accessed_field_name)
-                result = reinterpret_cast_func(result, new_type)
+                result = ReinterpretCastFunc(result, new_type)
 
             return visit_sympy_expr(result, enclosing_block, sympy_assignment)
         else:
@@ -804,298 +800,6 @@ def cleanup_blocks(node: ast.Node) -> None:
             cleanup_blocks(a)
 
 
-class KernelConstraintsCheck:
-    # TODO: Logs
-    # TODO: specification
-    """Checks if the input to create_kernel is valid.
-
-    Test the following conditions:
-
-    - SSA Form for pure symbols:
-        -  Every pure symbol may occur only once as left-hand-side of an assignment
-        -  Every pure symbol that is read, may not be written to later
-    - Independence / Parallelization condition:
-        - a field that is written may only be read at exact the same spatial position
-
-    (Pure symbols are symbols that are not Field.Accesses)
-    """
-    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
-
-    def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
-        self._type_for_symbol = type_for_symbol
-
-        self.scopes = NestedScopes()
-        self._field_writes = defaultdict(set)
-        self.fields_read = set()
-        self.check_independence_condition = check_independence_condition
-        self.check_double_write_condition = check_double_write_condition
-
-    def process_assignment(self, assignment):
-        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
-        new_rhs = self.process_expression(assignment.rhs)
-        new_lhs = self._process_lhs(assignment.lhs)
-        return ast.SympyAssignment(new_lhs, new_rhs)
-
-    def process_expression(self, rhs, type_constants=True):
-
-        self._update_accesses_rhs(rhs)
-        if isinstance(rhs, AbstractField.AbstractAccess):
-            self.fields_read.add(rhs.field)
-            self.fields_read.update(rhs.indirect_addressing_fields)
-            return rhs
-        elif isinstance(rhs, ImaginaryUnit):
-            return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
-        elif isinstance(rhs, TypedSymbol):
-            return rhs
-        elif isinstance(rhs, sp.Symbol):
-            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
-        elif type_constants and isinstance(rhs, np.generic):
-            return cast_func(rhs, create_type(rhs.dtype))
-        elif type_constants and isinstance(rhs, sp.Number):
-            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
-        # Very important that this clause comes before BooleanFunction
-        elif isinstance(rhs, sp.Equality):
-            if isinstance(rhs.args[1], sp.Number):
-                return sp.Equality(
-                    self.process_expression(rhs.args[0], type_constants),
-                    rhs.args[1])
-            else:
-                return sp.Equality(
-                    self.process_expression(rhs.args[0], type_constants),
-                    self.process_expression(rhs.args[1], type_constants))
-        elif isinstance(rhs, cast_func):
-            return cast_func(
-                self.process_expression(rhs.args[0], type_constants=False),
-                rhs.dtype)
-        elif isinstance(rhs, BooleanFunction) or \
-                type(rhs) in pystencils.integer_functions.__dict__.values():
-            new_args = [self.process_expression(a, type_constants) for a in rhs.args]
-            types_of_expressions = [get_type_of_expression(a) for a in new_args]
-            arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True)
-            new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
-                        else cast_func(a, arg_type)
-                        for a in new_args]
-            return rhs.func(*new_args)
-        elif isinstance(rhs, flag_cond):
-            #   do not process the arguments to the bit shift - they must remain integers
-            processed_args = (self.process_expression(a) for a in rhs.args[2:])
-            return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
-        elif isinstance(rhs, sp.Mul):
-            new_args = [
-                self.process_expression(arg, type_constants)
-                if arg not in (-1, 1) else arg for arg in rhs.args
-            ]
-            return rhs.func(*new_args) if new_args else rhs
-        elif isinstance(rhs, sp.Indexed):
-            return rhs
-        else:
-            if isinstance(rhs, sp.Pow):
-                # don't process exponents -> they should remain integers
-                return sp.Pow(
-                    self.process_expression(rhs.args[0], type_constants),
-                    rhs.args[1])
-            else:
-                new_args = [
-                    self.process_expression(arg, type_constants)
-                    for arg in rhs.args
-                ]
-                return rhs.func(*new_args) if new_args else rhs
-
-    @property
-    def fields_written(self):
-        return set(k.field for k, v in self._field_writes.items() if len(v))
-
-    def _process_lhs(self, lhs):
-        assert isinstance(lhs, sp.Symbol)
-        self._update_accesses_lhs(lhs)
-        if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
-            return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
-        else:
-            return lhs
-
-    def _update_accesses_lhs(self, lhs):
-        if isinstance(lhs, AbstractField.AbstractAccess):
-            fai = self.FieldAndIndex(lhs.field, lhs.index)
-            self._field_writes[fai].add(lhs.offsets)
-            if self.check_double_write_condition and len(self._field_writes[fai]) > 1:
-                raise ValueError(
-                    f"Field {lhs.field.name} is written at two different locations")
-        elif isinstance(lhs, sp.Symbol):
-            if self.scopes.is_defined_locally(lhs):
-                raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
-            if lhs in self.scopes.free_parameters:
-                raise ValueError(f"Symbol {lhs.name} is written, after it has been read")
-            self.scopes.define_symbol(lhs)
-
-    def _update_accesses_rhs(self, rhs):
-        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
-            writes = self._field_writes[self.FieldAndIndex(
-                rhs.field, rhs.index)]
-            for write_offset in writes:
-                assert len(writes) == 1
-                if write_offset != rhs.offsets:
-                    raise ValueError("Violation of loop independence condition. Field "
-                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
-            self.fields_read.add(rhs.field)
-        elif isinstance(rhs, sp.Symbol):
-            self.scopes.access_symbol(rhs)
-
-
-def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool,
-              check_double_write_condition: bool=True):
-    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
-
-    Additionally returns sets of all fields which are read/written
-
-    Args:
-        eqs: list of equations
-        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
-        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
-                                      kernels
-
-    Returns:
-        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
-         list of equations where symbols have been replaced by typed symbols
-    """
-    if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
-        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
-
-    type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
-
-    # TODO what does this do????
-    # TODO: ask Martin
-    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
-                                   check_double_write_condition=check_double_write_condition)
-
-    # TODO: check if this adds only types to leave nodes of AST, get type info
-    def visit(obj):
-        if isinstance(obj, (list, tuple)):
-            return [visit(e) for e in obj]
-        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
-            return check.process_assignment(obj)
-        elif isinstance(obj, ast.Conditional):
-            check.scopes.push()
-            # Disable double write check inside conditionals
-            # would be triggered by e.g. in-kernel boundaries
-            old_double_write = check.check_double_write_condition
-            check.check_double_write_condition = False
-            false_block = None if obj.false_block is None else visit(
-                obj.false_block)
-            result = ast.Conditional(check.process_expression(
-                obj.condition_expr, type_constants=False),
-                true_block=visit(obj.true_block),
-                false_block=false_block)
-            check.check_double_write_condition = old_double_write
-            check.scopes.pop()
-            return result
-        elif isinstance(obj, ast.Block):
-            check.scopes.push()
-            result = ast.Block([visit(e) for e in obj.args])
-            check.scopes.pop()
-            return result
-        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
-            return obj
-        else:
-            raise ValueError("Invalid object in kernel " + str(type(obj)))
-
-    typed_equations = visit(eqs)
-
-    return check.fields_read, check.fields_written, typed_equations
-
-
-def insert_casts(node):
-    """Checks the types and inserts casts and pointer arithmetic where necessary.
-
-    Args:
-        node: the head node of the ast
-
-    Returns:
-        modified AST
-    """
-    def cast(zipped_args_types, target_dtype):
-        """
-        Adds casts to the arguments if their type differs from the target type
-        :param zipped_args_types: a zipped list of args and types
-        :param target_dtype: The target data type
-        :return: args with possible casts
-        """
-        casted_args = []
-        for argument, data_type in zipped_args_types:
-            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
-                casted_args.append(cast_func(argument, target_dtype))
-            else:
-                casted_args.append(argument)
-        return casted_args
-
-    def pointer_arithmetic(expr_args):
-        """
-        Creates a valid pointer arithmetic function
-        :param expr_args: Arguments of the add expression
-        :return: pointer_arithmetic_func
-        """
-        pointer = None
-        new_args = []
-        for arg, data_type in expr_args:
-            if data_type.func is PointerType:
-                assert pointer is None
-                pointer = arg
-        for arg, data_type in expr_args:
-            if arg != pointer:
-                assert data_type.is_int() or data_type.is_uint()
-                new_args.append(arg)
-        new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
-        return pointer_arithmetic_func(pointer, new_args)
-
-    if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
-        return node
-    args = []
-    for arg in node.args:
-        args.append(insert_casts(arg))
-    # TODO indexed, LoopOverCoordinate
-    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
-        # TODO optimize pow, don't cast integer on double
-        types = [get_type_of_expression(arg) for arg in args]
-        assert len(types) > 0
-        # Never ever, ever collate to float type for boolean functions!
-        target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
-        zipped = list(zip(args, types))
-        if target.func is PointerType:
-            assert node.func is sp.Add
-            return pointer_arithmetic(zipped)
-        else:
-            return node.func(*cast(zipped, target))
-    elif node.func is ast.SympyAssignment:
-        lhs = args[0]
-        rhs = args[1]
-        target = get_type_of_expression(lhs)
-        if target.func is PointerType:
-            return node.func(*args)  # TODO fix, not complete
-        else:
-            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
-    elif node.func is ast.ResolvedFieldAccess:
-        return node
-    elif node.func is ast.Block:
-        for old_arg, new_arg in zip(node.args, args):
-            node.replace(old_arg, new_arg)
-        return node
-    elif node.func is ast.LoopOverCoordinate:
-        for old_arg, new_arg in zip(node.args, args):
-            node.replace(old_arg, new_arg)
-        return node
-    elif node.func is sp.Piecewise:
-        expressions = [expr for (expr, _) in args]
-        types = [get_type_of_expression(expr) for expr in expressions]
-        target = collate_types(types)
-        zipped = list(zip(expressions, types))
-        casted_expressions = cast(zipped, target)
-        args = [
-            arg.func(*[expr, arg.cond])
-            for (arg, expr) in zip(args, casted_expressions)
-        ]
-
-    return node.func(*args)
-
-
 def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
     """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
        first and last element"""
@@ -1118,73 +822,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i
 
 
 # --------------------------------------- Helper Functions -------------------------------------------------------------
-
-
-def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
-    """
-    Creates a default symbol name to type mapping.
-    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
-
-    Args:
-        eqs: list of equations
-        default_type: the type for non-boolean symbols
-    Returns:
-        dictionary, mapping symbol name to type
-    """
-    result = defaultdict(lambda: default_type)
-    if hasattr(default_type, 'numpy_dtype'):
-        result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
-    else:
-        result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
-    for eq in eqs:
-        if isinstance(eq, ast.Conditional):
-            result.update(typing_from_sympy_inspection(eq.true_block.args))
-            if eq.false_block:
-                result.update(typing_from_sympy_inspection(
-                    eq.false_block.args))
-        elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
-            continue
-        else:
-            from pystencils.cpu.vectorization import vec_all, vec_any
-            if isinstance(eq.rhs, (vec_all, vec_any)):
-                result[eq.lhs.name] = "bool"
-            # problematic case here is when rhs is a symbol: then it is impossible to decide here without
-            # further information what type the left hand side is - default fallback is the dict value then
-            if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
-                result[eq.lhs.name] = "bool"
-            try:
-                result[eq.lhs.name] = get_type_of_expression(eq.rhs,
-                                                             default_float_type=default_type,
-                                                             default_int_type=default_int_type,
-                                                             symbol_type_dict=result)
-            except Exception:
-                pass  # gracefully fail in case get_type_of_expression cannot determine type
-    return result
-
-
-def get_next_parent_of_type(node, parent_type):
-    """Returns the next parent node of given type or None, if root is reached.
-
-    Traverses the AST nodes parents until a parent of given type was found.
-    If no such parent is found, None is returned
-    """
-    parent = node.parent
-    while parent is not None:
-        if isinstance(parent, parent_type):
-            return parent
-        parent = parent.parent
-    return None
-
-
-def parents_of_type(node, parent_type, include_current=False):
-    """Generator for all parent nodes of given type"""
-    parent = node if include_current else node.parent
-    while parent is not None:
-        if isinstance(parent, parent_type):
-            yield parent
-        parent = parent.parent
-
-
 def get_optimal_loop_ordering(fields):
     """
     Determines the optimal loop order for a given set of fields.
@@ -1340,16 +977,3 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
         inner_loop.start = block_ctr
         inner_loop.stop = stop
     return coordinates_taken_into_account
-
-
-def adjust_c_single_precision_type(type_for_symbol):
-    """Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type."""
-    def single_factory():
-        return "single"
-
-    for symbol in type_for_symbol:
-        if type_for_symbol[symbol] == "float":
-            type_for_symbol[symbol] = single_factory()
-    if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float":
-        type_for_symbol.default_factory = single_factory
-    return type_for_symbol
diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py
new file mode 100644
index 000000000..55fb731c0
--- /dev/null
+++ b/pystencils/typing/__init__.py
@@ -0,0 +1,4 @@
+from pystencils.typing.utilities import *
+from pystencils.typing.types import *
+from pystencils.typing.typed_sympy import *
+from pystencils.typing.cast_functions import *
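Illustrative sketch (not part of the patch): with the star re-exports above, downstream code uses the import paths below, matching the import rewrites applied throughout this commit.

    # Import paths provided by the new pystencils.typing package:
    from pystencils.typing import TypedSymbol, StructType, PointerType, create_type    # formerly pystencils.data_types
    from pystencils.typing import CastFunc, collate_types, get_type_of_expression      # formerly cast_func and friends
    from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol      # formerly pystencils.kernelparameters
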
diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py
new file mode 100644
index 000000000..0c2da8d20
--- /dev/null
+++ b/pystencils/typing/cast_functions.py
@@ -0,0 +1,120 @@
+import numpy as np
+import sympy as sp
+from sympy.logic.boolalg import Boolean
+
+from pystencils.typing.types import AbstractType, BasicType, create_type
+from pystencils.typing.typed_sympy import TypedSymbol
+
+
+class CastFunc(sp.Function):
+    # TODO: documentation
+    # TODO: move function to `functions.py`
+    is_Atom = True
+
+    def __new__(cls, *args, **kwargs):
+        if len(args) != 2:
+            pass  # TODO: no-op check; any extra arguments are simply forwarded below
+        expr, dtype, *other_args = args
+        if not isinstance(dtype, AbstractType):
+            dtype = create_type(dtype)
+        # to work in conditions of sp.Piecewise, CastFunc has to be of type Boolean as well
+        # however, a CastFunc should only be a Boolean if its argument is a Boolean, otherwise this leads
+        # to problems when, for example, comparing CastFunc instances for equality
+        #
+        # lhs = bitwise_and(a, CastFunc(1, 'int'))
+        # rhs = CastFunc(0, 'int')
+        # print(sp.Ne(lhs, rhs))  # would give True if all CastFuncs were Booleans
+        # -> thus a separate class BooleanCastFunc is introduced
+        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)):
+            cls = BooleanCastFunc
+
+        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)
+
+    @property
+    def canonical(self):
+        if hasattr(self.args[0], 'canonical'):
+            return self.args[0].canonical
+        else:
+            raise NotImplementedError()
+
+    @property
+    def is_commutative(self):
+        return self.args[0].is_commutative
+
+    def _eval_evalf(self, *args, **kwargs):
+        return self.args[0].evalf()
+
+    @property
+    def dtype(self):
+        return self.args[1]
+
+    @property
+    def is_integer(self):
+        """
+        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate
+
+        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
+        """
+        if hasattr(self.dtype, 'numpy_dtype'):
+            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
+        else:
+            return super().is_integer
+
+    @property
+    def is_negative(self):
+        """
+        See :func:`.CastFunc.is_integer`
+        """
+        if hasattr(self.dtype, 'numpy_dtype'):
+            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
+                return False
+
+        return super().is_negative
+
+    @property
+    def is_nonnegative(self):
+        """
+        See :func:`.CastFunc.is_integer`
+        """
+        if self.is_negative is False:
+            return True
+        else:
+            return super().is_nonnegative
+
+    @property
+    def is_real(self):
+        """
+        See :func:`.CastFunc.is_integer`
+        """
+        if hasattr(self.dtype, 'numpy_dtype'):
+            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
+                   np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
+                   super().is_real
+        else:
+            return super().is_real
+
+
+class BooleanCastFunc(CastFunc, Boolean):
+    # TODO: documentation
+    pass
+
+
+class VectorMemoryAccess(CastFunc):
+    # TODO: documentation
+    # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
+    nargs = (6,)
+
+
+class ReinterpretCastFunc(CastFunc):
+    # TODO: documentation
+    pass
+
+
+class PointerArithmeticFunc(sp.Function, Boolean):
+    # TODO: documentation
+    @property
+    def canonical(self):
+        if hasattr(self.args[0], 'canonical'):
+            return self.args[0].canonical
+        else:
+            raise NotImplementedError()
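Illustrative sketch (not part of the patch) of the dispatch documented in `CastFunc.__new__` and of the dtype-derived predicates:

    # CastFunc vs. BooleanCastFunc dispatch (sketch only).
    import sympy as sp
    from pystencils.typing.cast_functions import CastFunc, BooleanCastFunc

    x = sp.Symbol("x")
    c = CastFunc(x, "int32")
    assert type(c) is CastFunc and c.is_integer    # int32 yields the integer predicate
    assert str(c.dtype) == "int32_t"               # dtype is stored as a BasicType

    b = CastFunc(sp.Gt(x, 0), "bool")              # a Boolean argument ...
    assert type(b) is BooleanCastFunc              # ... is rerouted to BooleanCastFunc
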
diff --git a/pystencils/kernelparameters.py b/pystencils/typing/typed_sympy.py
similarity index 52%
rename from pystencils/kernelparameters.py
rename to pystencils/typing/typed_sympy.py
index 8bd4341be..0a253f748 100644
--- a/pystencils/kernelparameters.py
+++ b/pystencils/typing/typed_sympy.py
@@ -1,30 +1,102 @@
-"""Special symbols representing kernel parameters related to fields/arrays.
-
-A `KernelFunction` node determines parameters that have to be passed to the function by searching for all undefined
-symbols. Some symbols are not directly defined by the user, but are related to the `Field`s used in the kernel:
-For each field a `FieldPointerSymbol` needs to be passed in, which is the pointer to the memory region where
-the field is stored. This pointer is represented by the `FieldPointerSymbol` class that additionally stores the
-name of the corresponding field. For fields where the size is not known at compile time, additionally shape and stride
-information has to be passed in at runtime. These values are represented by  `FieldShapeSymbol`
-and `FieldPointerSymbol`.
-
-The special symbols in this module store only the field name instead of a field reference. Storing a field reference
-directly leads to problems with copying and pickling behaviour due to the circular dependency of `Field` and
-e.g. `FieldShapeSymbol`, since a Field contains `FieldShapeSymbol`s in its shape, and a `FieldShapeSymbol`
-would reference back to the field.
-"""
+from typing import Union
+
+import numpy as np
+import sympy as sp
 from sympy.core.cache import cacheit
 
-from pystencils.data_types import (
-    PointerType, TypedSymbol, create_composite_type_from_string, get_base_type)
+from pystencils.typing.types import BasicType, create_type, PointerType
+from pystencils.typing.utilities import get_base_type
+
+
+def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
+    # TODO: add type hints and check whether dtype is a valid NumPy type
+    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
+
+    Args:
+        dtype (BasicType, np.dtype): a pystencils BasicType or a NumPy data type
+    Returns:
+        A dict of SymPy assumptions
+    """
+    if hasattr(dtype, 'numpy_dtype'):
+        dtype = dtype.numpy_dtype
+
+    assumptions = dict()
+
+    try:
+        if np.issubdtype(dtype, np.integer):
+            assumptions.update({'integer': True})
 
+        if np.issubdtype(dtype, np.unsignedinteger):
+            assumptions.update({'negative': False})
 
-# TODO: Why do we need extra classes? Why isn't TypedSymbol enough?
-# TODO: Replace with a factory function
+        if np.issubdtype(dtype, np.integer) or \
+                np.issubdtype(dtype, np.floating):
+            assumptions.update({'real': True})
+    except Exception:  # TODO this is dirty
+        pass
 
+    return assumptions
 
-SHAPE_DTYPE = create_composite_type_from_string("const int64")
-STRIDE_DTYPE = create_composite_type_from_string("const int64")
+
+class TypedSymbol(sp.Symbol):
+    def __new__(cls, *args, **kwds):
+        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
+        return obj
+
+    def __new_stage2__(cls, name, dtype, **kwargs):  # TODO does not match signature of sp.Symbol???
+        assumptions = assumptions_from_dtype(dtype)  # TODO should dtype be a np.dtype or our own Type???
+        assumptions.update(kwargs)
+        obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
+        try:
+            obj._dtype = create_type(dtype)
+        except (TypeError, ValueError):
+            # on error keep the string
+            obj._dtype = dtype
+        return obj
+
+    __xnew__ = staticmethod(__new_stage2__)
+    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def _hashable_content(self):
+        return super()._hashable_content(), hash(self._dtype)
+
+    def __getnewargs__(self):
+        return self.name, self.dtype
+
+    def __getnewargs_ex__(self):
+        return (self.name, self.dtype), self.assumptions0
+
+    @property
+    def canonical(self):
+        return self
+
+    @property
+    def reversed(self):
+        return self
+
+    @property
+    def headers(self):
+        headers = []
+        try:
+            if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
+                headers.append('"cuda_complex.hpp"')
+        except Exception:
+            pass
+        try:
+            if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
+                headers.append('"cuda_complex.hpp"')
+        except Exception:
+            pass
+
+        return headers
+
+
+SHAPE_DTYPE = BasicType('int64', const=True)
+STRIDE_DTYPE = BasicType('int64', const=True)
 
 
 class FieldStrideSymbol(TypedSymbol):
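Illustrative sketch (not part of the patch) of how `assumptions_from_dtype` feeds SymPy's assumption system when a `TypedSymbol` is created:

    # dtype-derived SymPy assumptions (sketch only).
    import numpy as np
    from pystencils.typing.typed_sympy import TypedSymbol, assumptions_from_dtype

    print(assumptions_from_dtype(np.dtype("uint32")))
    # {'integer': True, 'negative': False, 'real': True}

    i = TypedSymbol("i", "uint32")
    assert i.is_integer and i.is_nonnegative   # inferred from the dtype
    print(i.dtype)                             # uint32_t (a BasicType)
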
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
new file mode 100644
index 000000000..eabe87dbd
--- /dev/null
+++ b/pystencils/typing/types.py
@@ -0,0 +1,297 @@
+from abc import ABC, abstractmethod
+from typing import Union
+
+import numpy as np
+import sympy as sp
+import sympy.codegen.ast
+
+
+def is_supported_type(dtype: np.dtype):
+    scalar = dtype.type
+    c = np.issctype(dtype)
+    subclass = issubclass(scalar, np.floating) or issubclass(scalar, np.integer) or issubclass(scalar, np.bool_)
+    additional_checks = dtype.fields is None and dtype.hasobject is False and dtype.subdtype is None
+    return c and subclass and additional_checks
+
+
+def numpy_name_to_c(name: str) -> str:
+    """
+    Converts a np.dtype.name into a C type
+    Args:
+        name: np.dtype.name string
+    Returns:
+        type as a C string
+    """
+    if name == 'float64':
+        return 'double'
+    elif name == 'float32':
+        return 'float'
+    elif name.startswith('int'):
+        width = int(name[len("int"):])
+        return f"int{width}_t"
+    elif name.startswith('uint'):
+        width = int(name[len("uint"):])
+        return f"uint{width}_t"
+    elif name == 'bool':
+        return 'bool'
+    else:
+        raise NotImplementedError(f"Can't map numpy to C name for {name}")
+
+
+class AbstractType(sp.Atom, ABC):
+    # TODO: inherits from sp.Atom because of cast function (and maybe others)
+    # TODO: is this necessary?
+    def __new__(cls, *args, **kwargs):
+        return sp.Basic.__new__(cls)
+
+    def _sympystr(self, *args, **kwargs):
+        return str(self)
+
+    @property
+    @abstractmethod
+    def base_type(self) -> Union[None, 'BasicType']:
+        """
+        Returns: the BasicType of a Vector or Pointer type, None otherwise
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def item_size(self) -> int:
+        """
+        Returns: the size of one item of this type; the exact semantics differ between subclasses (TODO: clarify)
+        """
+        pass
+
+
+class BasicType(AbstractType):
+    # TODO: should be a sensible interface to np.dtype
+
+    def __init__(self, dtype: Union[np.dtype, 'BasicType', str], const: bool = False):
+        self.const = const
+        if isinstance(dtype, BasicType):
+            self.numpy_dtype = dtype.numpy_dtype  # TODO copy const as well??
+        else:
+            self.numpy_dtype = np.dtype(dtype)
+        assert is_supported_type(self.numpy_dtype), f'Type {self.numpy_dtype} is currently not supported!'
+
+    def __getnewargs__(self):
+        return self.numpy_dtype, self.const
+
+    def __getnewargs_ex__(self):
+        return (self.numpy_dtype, self.const), {}
+
+    @property
+    def base_type(self):
+        return None
+
+    @property
+    def sympy_dtype(self):
+        return getattr(sympy.codegen.ast, str(self.numpy_dtype))
+
+    @property
+    def item_size(self):  # TODO: what is this? Do we want self.numpy_dtype.itemsize instead?
+        return 1
+
+    def is_float(self):
+        return issubclass(self.numpy_dtype.type, np.floating)
+
+    def is_int(self):
+        return issubclass(self.numpy_dtype.type, np.integer)
+
+    def is_uint(self):
+        return issubclass(self.numpy_dtype.type, np.unsignedinteger)
+
+    def is_sint(self):
+        return issubclass(self.numpy_dtype.type, np.signedinteger)
+
+    def is_bool(self):
+        return issubclass(self.numpy_dtype.type, np.bool_)
+
+    @property
+    def c_name(self) -> str:
+        return numpy_name_to_c(self.numpy_dtype.name)
+
+    def __str__(self):
+        return f'{self.c_name}{" const" if self.const else ""}'
+
+    def __repr__(self):
+        return str(self)
+
+    def __eq__(self, other):
+        if not isinstance(other, BasicType):
+            return False
+        else:
+            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
+
+    def __hash__(self):
+        return hash(str(self))
+
+
+class VectorType(AbstractType):
+    # TODO: check with rest
+    instruction_set = None
+
+    def __init__(self, base_type: BasicType, width: int = 4):  # TODO default vector length is dangerous
+        self._base_type = base_type
+        self.width = width
+
+    @property
+    def base_type(self):
+        return self._base_type
+
+    @property
+    def item_size(self):
+        return self.width * self.base_type.item_size
+
+    def __eq__(self, other):
+        if not isinstance(other, VectorType):
+            return False
+        else:
+            return (self.base_type, self.width) == (other.base_type, other.width)
+
+    def __str__(self):
+        if self.instruction_set is None:
+            return f"{self.base_type}[{self.width}]"
+        else:
+            # TODO this seems questionable: the instruction_set should know how to print a type out!
+            # TODO this is error prone: base_type could have const=True. Use dtype instead
+            if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
+                return self.instruction_set['int']
+            elif self.base_type == create_type("float64"):
+                return self.instruction_set['double']
+            elif self.base_type == create_type("float32"):
+                return self.instruction_set['float']
+            elif self.base_type == create_type("bool"):
+                return self.instruction_set['bool']
+            else:
+                raise NotImplementedError()
+
+    def __hash__(self):
+        return hash((self.base_type, self.width))
+
+    def __getnewargs__(self):
+        return self._base_type, self.width
+
+    def __getnewargs_ex__(self):
+        return (self._base_type, self.width), {}
+
+
+class PointerType(AbstractType):
+    def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True):
+        self._base_type = base_type
+        self.const = const
+        self.restrict = restrict
+
+    def __getnewargs__(self):
+        return self.base_type, self.const, self.restrict
+
+    def __getnewargs_ex__(self):
+        return (self.base_type, self.const, self.restrict), {}
+
+    @property
+    def alias(self):
+        return not self.restrict
+
+    @property
+    def base_type(self):
+        return self._base_type
+
+    @property
+    def item_size(self):
+        return self.base_type.item_size
+
+    def __eq__(self, other):
+        if not isinstance(other, PointerType):
+            return False
+        else:
+            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
+
+    def __str__(self):
+        return f'{str(self.base_type)} * {"RESTRICT " if self.restrict else ""}{"const" if self.const else ""}'
+
+    def __repr__(self):
+        return str(self)
+
+    def __hash__(self):
+        return hash((self._base_type, self.const, self.restrict))
+
+
+class StructType(AbstractType):
+    # TODO: Docs. This is a struct. A list of types (with C offsets)
+    # TODO StructType did not previously inherit from AbstractType
+    # TODO: This is basically like a BasicType... only as struct
+    def __init__(self, numpy_type, const=False):
+        self.const = const
+        self._dtype = np.dtype(numpy_type)
+
+    def __getnewargs__(self):
+        return self.numpy_dtype, self.const
+
+    def __getnewargs_ex__(self):
+        return (self.numpy_dtype, self.const), {}
+
+    @property
+    def base_type(self):
+        return None
+
+    @property
+    def numpy_dtype(self):
+        return self._dtype
+
+    @property
+    def item_size(self):
+        return self.numpy_dtype.itemsize
+
+    def get_element_offset(self, element_name):
+        return self.numpy_dtype.fields[element_name][1]
+
+    def get_element_type(self, element_name):
+        np_element_type = self.numpy_dtype.fields[element_name][0]
+        return BasicType(np_element_type, self.const)
+
+    def has_element(self, element_name):
+        return element_name in self.numpy_dtype.fields
+
+    def __eq__(self, other):
+        if not isinstance(other, StructType):
+            return False
+        else:
+            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
+
+    def __str__(self):
+        # structs are addressed byte-wise in the generated code, hence uint8_t
+        # TODO: revisit this representation
+        result = "uint8_t"
+        if self.const:
+            result += " const"
+        return result
+
+    def __repr__(self):
+        return str(self)
+
+    def __hash__(self):
+        return hash((self.numpy_dtype, self.const))
+
+
+def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractType:
+    """Creates an instance of an AbstractType subclass from a type specification.
+
+    Args:
+        specification: AbstractType instance, numpy dtype, or a type string
+
+    Returns:
+        the given AbstractType instance, or a new BasicType/StructType parsed from the specification
+    """
+    # TODO: This mostly duplicates what BasicType already implements; its only added value is
+    # TODO: differentiating between BasicType and StructType.
+    # TODO: Why are Vector and Pointer types not supported here?
+    if isinstance(specification, AbstractType):
+        return specification
+    else:
+        numpy_dtype = np.dtype(specification)
+        if numpy_dtype.fields is None:
+            return BasicType(numpy_dtype, const=False)
+        else:
+            return StructType(numpy_dtype, const=False)
+
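+# Example (illustrative sketch): how create_type dispatches between BasicType and StructType.
+# The structured dtype below is made up purely for illustration.
+#   create_type('float64')                                        # scalar spec        -> BasicType
+#   create_type(np.dtype([('x', np.float64), ('y', np.int32)]))   # structured dtype   -> StructType
+#   create_type(create_type('int32'))                             # AbstractType input -> returned unchanged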
diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py
new file mode 100644
index 000000000..8187d929e
--- /dev/null
+++ b/pystencils/typing/utilities.py
@@ -0,0 +1,494 @@
+from collections import defaultdict
+from functools import partial
+from typing import Tuple, Union, List, Dict
+
+import numpy as np
+import sympy as sp
+from sympy.codegen import Assignment
+from sympy.logic.boolalg import Boolean, BooleanFunction
+
+import pystencils
+from pystencils import astnodes as ast
+from pystencils.cache import memorycache, memorycache_if_hashable
+from pystencils.kernel_contrains_check import KernelConstraintsCheck
+from pystencils.typing.cast_functions import CastFunc, PointerArithmeticFunc
+from pystencils.typing.typed_sympy import TypedSymbol
+from pystencils.typing.types import AbstractType, BasicType, VectorType, PointerType, StructType, create_type
+from pystencils.utils import all_equal
+
+
+def typed_symbols(names, dtype, *args):
+    """Creates TypedSymbols with the same naming conventions as sympy.symbols, all with the given dtype."""
+    # TODO: add type hints
+    symbols = sp.symbols(names, *args)
+    if isinstance(symbols, tuple):
+        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
+    else:
+        return TypedSymbol(str(symbols), dtype)
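+
+# Example (illustrative sketch): typed_symbols mirrors sp.symbols, but returns TypedSymbol objects.
+#   x, y = typed_symbols('x y', 'float64')   # tuple of TypedSymbols, all typed as double
+#   n = typed_symbols('n', np.uint32)        # a single TypedSymbol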
+
+
+# noinspection PyPep8Naming
+class address_of(sp.Function):
+    """Address-of operator, i.e. the unary '&' operator in C."""
+    is_Atom = True
+
+    def __new__(cls, arg):
+        obj = sp.Function.__new__(cls, arg)
+        return obj
+
+    @property
+    def canonical(self):
+        if hasattr(self.args[0], 'canonical'):
+            return self.args[0].canonical
+        else:
+            raise NotImplementedError()
+
+    @property
+    def is_commutative(self):
+        return self.args[0].is_commutative
+
+    @property
+    def dtype(self):
+        if hasattr(self.args[0], 'dtype'):
+            return PointerType(self.args[0].dtype, restrict=True)
+        else:
+            # TODO: 'void' is not a proper BasicType here; either allow a void BasicType or raise an exception
+            return PointerType('void', restrict=True)
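+
+# Example usage (cf. pystencils_tests/test_address_of.py): taking the address of a field access
+# and storing it as an integer, e.g.
+#   y[0, 0] <- CastFunc(address_of(x[0, 0]), create_type('int64'))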
+
+
+def get_base_type(data_type):
+    """Returns the innermost base type of (possibly nested) pointer types, e.g. Pointer(Pointer(int)) -> int."""
+    # TODO: this is unsafe (assumes a base_type attribute) and should eventually be removed
+    while data_type.base_type is not None:
+        data_type = data_type.base_type
+    return data_type
+
+
+def peel_off_type(dtype, type_to_peel_off):
+    """Removes all leading layers of type_to_peel_off from dtype, e.g. a VectorType is peeled down to its base type."""
+    # TODO: only used once, in collate_types; could be inlined there
+    while type(dtype) is type_to_peel_off:
+        dtype = dtype.base_type
+    return dtype
+
+
+############################# This is basically our type system ########################################################
+def collate_types(types,
+                  forbid_collation_to_complex=False,  # TODO: the type system shouldn't need this
+                  forbid_collation_to_float=False,  # TODO: the type system shouldn't need this
+                  default_float_type='float64',  # TODO: AST leaves should be typed; expressions should infer the correct type
+                  default_int_type='int64'):  # TODO: AST leaves should be typed; expressions should infer the correct type
+    """
+    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
+    Uses the collation rules from numpy.
+    """
+    # TODO: use np.can_cast and np.promote_types and np.result_type and np.find_common_type
+    if forbid_collation_to_complex:
+        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
+        if not types:
+            return create_type(default_float_type)
+
+    if forbid_collation_to_float:
+        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
+        if not types:
+            return create_type(default_int_type)
+
+    # Pointer arithmetic case i.e. pointer + integer is allowed
+    if any(type(t) is PointerType for t in types):
+        pointer_type = None
+        for t in types:
+            if type(t) is PointerType:
+                if pointer_type is not None:
+                    raise ValueError("Cannot collate the combination of two pointer types")
+                pointer_type = t
+            elif type(t) is BasicType:
+                if not (t.is_int() or t.is_uint()):
+                    raise ValueError("Invalid pointer arithmetic")
+            else:
+                raise ValueError("Invalid pointer arithmetic")
+        return pointer_type
+
+    # peel off vector types; if at least one vector type occurred, the result will also be a vector type
+    vector_type = [t for t in types if type(t) is VectorType]
+    if not all_equal(t.width for t in vector_type):
+        raise ValueError("Collation failed because of vector types with different width")
+    types = [peel_off_type(t, VectorType) for t in types]
+
+    # now we should have a list of basic types - struct types are not yet supported
+    assert all(type(t) is BasicType for t in types)
+
+    if any(t.is_float() for t in types):
+        types = tuple(t for t in types if t.is_float())
+    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
+    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
+    result = BasicType(result_numpy_type)
+    if vector_type:
+        result = VectorType(result, vector_type[0].width)
+    return result
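+
+# Examples (illustrative sketch) of the collation rules implemented above:
+#   collate_types([create_type('float32'), create_type('float64')])            -> double
+#   collate_types([create_type('int64'), create_type('float32')])              -> float (floats win over ints)
+#   collate_types([PointerType(create_type('float64')), create_type('int64')]) -> the pointer type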
+
+
+@memorycache_if_hashable(maxsize=2048)
+def get_type_of_expression(expr,
+                           default_float_type='double',  # TODO: no default should be needed; AST leaves should carry a type
+                           default_int_type='int',  # TODO: no default should be needed; AST leaves should carry a type
+                           symbol_type_dict=None):  # TODO: no default should be needed; AST leaves should carry a type
+    from pystencils.astnodes import ResolvedFieldAccess
+    from pystencils.cpu.vectorization import vec_all, vec_any
+
+    if default_float_type == 'float':
+        default_float_type = 'float32'
+
+    if not symbol_type_dict:
+        symbol_type_dict = defaultdict(lambda: create_type('double'))
+
+    get_type = partial(get_type_of_expression,
+                       default_float_type=default_float_type,
+                       default_int_type=default_int_type,
+                       symbol_type_dict=symbol_type_dict)
+
+    expr = sp.sympify(expr)
+    if isinstance(expr, sp.Integer):
+        return create_type(default_int_type)
+    elif expr.is_real is False:
+        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
+    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
+        return create_type(default_float_type)
+    elif isinstance(expr, ResolvedFieldAccess):
+        return expr.field.dtype
+    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
+        return expr.field.dtype
+    elif isinstance(expr, TypedSymbol):
+        return expr.dtype
+    elif isinstance(expr, sp.Symbol):
+        if symbol_type_dict:
+            return symbol_type_dict[expr.name]
+        else:
+            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
+    elif isinstance(expr, CastFunc):
+        return expr.args[1]
+    elif isinstance(expr, (vec_any, vec_all)):
+        return create_type("bool")
+    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
+        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
+        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
+        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
+            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
+        return collated_result_type
+    elif isinstance(expr, sp.Indexed):
+        typed_symbol = expr.base.label
+        return typed_symbol.dtype.base_type
+    elif isinstance(expr, (Boolean, BooleanFunction)):
+        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
+        result = create_type("bool")
+        vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
+        if vec_args:
+            result = VectorType(result, width=vec_args[0].width)
+        return result
+    elif isinstance(expr, sp.Pow):
+        base_type = get_type(expr.args[0])
+        if expr.exp.is_integer:
+            return base_type
+        else:
+            return collate_types([create_type(default_float_type), base_type])
+    elif isinstance(expr, (sp.Sum, sp.Product)):
+        return get_type(expr.args[0])
+    elif isinstance(expr, sp.Expr):
+        expr: sp.Expr
+        if expr.args:
+            types = tuple(get_type(a) for a in expr.args)
+            # collate_types checks numpy_dtype in the special cases
+            if any(not hasattr(t, 'numpy_dtype') for t in types):
+                forbid_collation_to_complex = False
+                forbid_collation_to_float = False
+            else:
+                forbid_collation_to_complex = expr.is_real is True
+                forbid_collation_to_float = expr.is_integer is True
+            return collate_types(
+                types,
+                forbid_collation_to_complex=forbid_collation_to_complex,
+                forbid_collation_to_float=forbid_collation_to_float,
+                default_float_type=default_float_type,
+                default_int_type=default_int_type)
+        else:
+            if expr.is_integer:
+                return create_type(default_int_type)
+            else:
+                return create_type(default_float_type)
+
+    raise NotImplementedError(f"Could not determine type for {expr} of type {type(expr)}")
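+
+# Examples (illustrative sketch) of the type deduction above:
+#   get_type_of_expression(sp.Integer(3))                                 -> default int type
+#   get_type_of_expression(TypedSymbol('a', create_type('float32')))      -> float
+#   get_type_of_expression(TypedSymbol('a', create_type('float32')) + 3)  -> float (the int literal is absorbed)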
+
+
+############################# End This is basically our type system ##################################################
+
+
+# TODO this seems quite wrong...
+sympy_version = sp.__version__.split('.')
+if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
+    # __setstate__ would bypass the constructor, so we remove it
+    sp.Number.__getstate__ = sp.Basic.__getstate__
+    del sp.Basic.__getstate__
+
+
+    class FunctorWithStoredKwargs:
+        def __init__(self, func, **kwargs):
+            self.func = func
+            self.kwargs = kwargs
+
+        def __call__(self, *args):
+            return self.func(*args, **self.kwargs)
+
+
+    # __reduce_ex__ would strip kwargs, so we override it
+    def basic_reduce_ex(self, protocol):
+        if hasattr(self, '__getnewargs_ex__'):
+            args, kwargs = self.__getnewargs_ex__()
+        else:
+            args, kwargs = self.__getnewargs__(), {}
+        if hasattr(self, '__getstate__'):
+            state = self.__getstate__()
+        else:
+            state = None
+        return FunctorWithStoredKwargs(type(self), **kwargs), args, state
+
+
+    sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
+    sp.Basic.__reduce_ex__ = basic_reduce_ex
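+
+# The override above keeps pickling of sympy subclasses that carry keyword arguments intact, e.g. (sketch):
+#   import pickle
+#   s = TypedSymbol('a', create_type('float64'))
+#   s2 = pickle.loads(pickle.dumps(s))   # round-trips with its dtype preserved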
+
+
+def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], check_independence_condition: bool,
+              check_double_write_condition: bool = True):
+    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
+
+    Additionally returns sets of all fields which are read/written
+
+    Args:
+        eqs: list of equations
+        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
+        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
+                                      kernels
+        check_double_write_condition: check that no field is written at two different locations within one kernel;
+                                      temporarily disabled inside conditionals (e.g. for in-kernel boundaries)
+
+    Returns:
+        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
+         list of equations where symbols have been replaced by typed symbols
+    """
+    if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
+        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
+
+    type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
+
+    # TODO: document what KernelConstraintsCheck does here (ask Martin)
+    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
+                                   check_double_write_condition=check_double_write_condition)
+
+    # TODO: check whether this adds types only to the leaf nodes of the AST, and get the type info
+    def visit(obj):
+        if isinstance(obj, (list, tuple)):
+            return [visit(e) for e in obj]
+        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
+            return check.process_assignment(obj)
+        elif isinstance(obj, ast.Conditional):
+            check.scopes.push()
+            # Disable double write check inside conditionals
+            # would be triggered by e.g. in-kernel boundaries
+            old_double_write = check.check_double_write_condition
+            check.check_double_write_condition = False
+            false_block = None if obj.false_block is None else visit(obj.false_block)
+            result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
+                                     true_block=visit(obj.true_block),
+                                     false_block=false_block)
+            check.check_double_write_condition = old_double_write
+            check.scopes.pop()
+            return result
+        elif isinstance(obj, ast.Block):
+            check.scopes.push()
+            result = ast.Block([visit(e) for e in obj.args])
+            check.scopes.pop()
+            return result
+        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
+            return obj
+        else:
+            raise ValueError("Invalid object in kernel " + str(type(obj)))
+
+    typed_equations = visit(eqs)
+
+    return check.fields_read, check.fields_written, typed_equations
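+
+# Typical usage (illustrative sketch; 'assignments' is a hypothetical list of Assignment objects):
+#   fields_read, fields_written, typed_eqs = add_types(assignments, 'double', check_independence_condition=True)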
+
+
+def insert_casts(node):
+    """Checks the types and inserts casts and pointer arithmetic where necessary.
+
+    Args:
+        node: the head node of the ast
+
+    Returns:
+        modified AST
+    """
+    def cast(zipped_args_types, target_dtype):
+        """
+        Adds casts to the arguments if their type differs from the target type
+        :param zipped_args_types: a zipped list of args and types
+        :param target_dtype: The target data type
+        :return: args with possible casts
+        """
+        casted_args = []
+        for argument, data_type in zipped_args_types:
+            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
+                casted_args.append(CastFunc(argument, target_dtype))
+            else:
+                casted_args.append(argument)
+        return casted_args
+
+    def pointer_arithmetic(expr_args):
+        """
+        Creates a valid pointer arithmetic function
+        :param expr_args: Arguments of the add expression
+        :return: pointer_arithmetic_func
+        """
+        pointer = None
+        new_args = []
+        for arg, data_type in expr_args:
+            if data_type.func is PointerType:
+                assert pointer is None
+                pointer = arg
+        for arg, data_type in expr_args:
+            if arg != pointer:
+                assert data_type.is_int() or data_type.is_uint()
+                new_args.append(arg)
+        new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
+        return PointerArithmeticFunc(pointer, new_args)
+
+    if isinstance(node, sp.AtomicExpr) or isinstance(node, CastFunc):
+        return node
+    args = []
+    for arg in node.args:
+        args.append(insert_casts(arg))
+    # TODO indexed, LoopOverCoordinate
+    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
+        # TODO optimize pow, don't cast integer on double
+        types = [get_type_of_expression(arg) for arg in args]
+        assert len(types) > 0
+        # Never ever, ever collate to float type for boolean functions!
+        # node.func is a class, so issubclass (not isinstance) is the correct check here
+        target = collate_types(types, forbid_collation_to_float=issubclass(node.func, BooleanFunction))
+        zipped = list(zip(args, types))
+        if target.func is PointerType:
+            assert node.func is sp.Add
+            return pointer_arithmetic(zipped)
+        else:
+            return node.func(*cast(zipped, target))
+    elif node.func is ast.SympyAssignment:
+        lhs = args[0]
+        rhs = args[1]
+        target = get_type_of_expression(lhs)
+        if target.func is PointerType:
+            return node.func(*args)  # TODO fix, not complete
+        else:
+            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
+    elif node.func is ast.ResolvedFieldAccess:
+        return node
+    elif node.func is ast.Block:
+        for old_arg, new_arg in zip(node.args, args):
+            node.replace(old_arg, new_arg)
+        return node
+    elif node.func is ast.LoopOverCoordinate:
+        for old_arg, new_arg in zip(node.args, args):
+            node.replace(old_arg, new_arg)
+        return node
+    elif node.func is sp.Piecewise:
+        expressions = [expr for (expr, _) in args]
+        types = [get_type_of_expression(expr) for expr in expressions]
+        target = collate_types(types)
+        zipped = list(zip(expressions, types))
+        casted_expressions = cast(zipped, target)
+        args = [
+            arg.func(*[expr, arg.cond])
+            for (arg, expr) in zip(args, casted_expressions)
+        ]
+
+    return node.func(*args)
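+
+# Example (illustrative sketch): for a mixed-precision addition the narrower operand gets a cast, e.g.
+#   insert_casts(TypedSymbol('x', create_type('float64')) + TypedSymbol('y', create_type('float32')))
+#   should yield  x + CastFunc(y, double)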
+
+
+def adjust_c_single_precision_type(type_for_symbol):
+    """Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type."""
+    def single_factory():
+        return "single"
+
+    for symbol in type_for_symbol:
+        if type_for_symbol[symbol] == "float":
+            type_for_symbol[symbol] = single_factory()
+    if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float":
+        type_for_symbol.default_factory = single_factory
+    return type_for_symbol
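+
+# Example (illustrative sketch):
+#   adjust_c_single_precision_type({'a': 'float', 'b': 'double'})  ->  {'a': 'single', 'b': 'double'}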
+
+
+def get_next_parent_of_type(node, parent_type):
+    """Returns the next parent node of given type or None, if root is reached.
+
+    Traverses the AST node's parents until a parent of the given type is found.
+    If no such parent is found, None is returned.
+    """
+    parent = node.parent
+    while parent is not None:
+        if isinstance(parent, parent_type):
+            return parent
+        parent = parent.parent
+    return None
+
+
+def parents_of_type(node, parent_type, include_current=False):
+    """Generator for all parent nodes of given type"""
+    parent = node if include_current else node.parent
+    while parent is not None:
+        if isinstance(parent, parent_type):
+            yield parent
+        parent = parent.parent
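+
+# Example (illustrative sketch; 'assignment_node' is a hypothetical AST node):
+#   innermost_loop = get_next_parent_of_type(assignment_node, ast.LoopOverCoordinate)
+#   all_loops = list(parents_of_type(assignment_node, ast.LoopOverCoordinate))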
+
+
+def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
+    """
+    Creates a default mapping from symbol name to type.
+    If a sympy Boolean is assigned to a symbol, its type is assumed to be 'bool'; otherwise the default type
+    (usually 'double') is used.
+
+    Args:
+        eqs: list of equations
+        default_type: the type for non-boolean symbols
+        default_int_type: the integer type used when an expression is inferred to be integral
+    Returns:
+        dictionary mapping symbol name to type
+    """
+    result = defaultdict(lambda: default_type)
+    if hasattr(default_type, 'numpy_dtype'):
+        result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
+    else:
+        result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
+    for eq in eqs:
+        if isinstance(eq, ast.Conditional):
+            result.update(typing_from_sympy_inspection(eq.true_block.args))
+            if eq.false_block:
+                result.update(typing_from_sympy_inspection(eq.false_block.args))
+        elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
+            continue
+        else:
+            from pystencils.cpu.vectorization import vec_all, vec_any
+            if isinstance(eq.rhs, (vec_all, vec_any)):
+                result[eq.lhs.name] = "bool"
+            # problematic case here is when rhs is a symbol: then it is impossible to decide here without
+            # further information what type the left hand side is - default fallback is the dict value then
+            if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
+                result[eq.lhs.name] = "bool"
+            try:
+                result[eq.lhs.name] = get_type_of_expression(eq.rhs,
+                                                             default_float_type=default_type,
+                                                             default_int_type=default_int_type,
+                                                             symbol_type_dict=result)
+            except Exception:
+                pass  # gracefully fail in case get_type_of_expression cannot determine type
+    return result
\ No newline at end of file
diff --git a/pystencils_tests/test_abs.py b/pystencils_tests/test_abs.py
index cf71bc04c..7bf7a1a45 100644
--- a/pystencils_tests/test_abs.py
+++ b/pystencils_tests/test_abs.py
@@ -1,7 +1,7 @@
 import sympy
 
 import pystencils as ps
-from pystencils.data_types import cast_func, create_type
+from pystencils.typing import CastFunc, create_type
 
 
 def test_abs():
@@ -10,7 +10,7 @@ def test_abs():
     default_int_type = create_type('int64')
 
     assignments = ps.AssignmentCollection({
-        x[0, 0]: sympy.Abs(cast_func(y[0, 0], default_int_type))
+        x[0, 0]: sympy.Abs(CastFunc(y[0, 0], default_int_type))
     })
 
     config = ps.CreateKernelConfig(target=ps.Target.GPU)
diff --git a/pystencils_tests/test_address_of.py b/pystencils_tests/test_address_of.py
index 659f5d92f..1cb9c8ed1 100644
--- a/pystencils_tests/test_address_of.py
+++ b/pystencils_tests/test_address_of.py
@@ -3,7 +3,7 @@ Test of pystencils.data_types.address_of
 """
 import sympy as sp
 import pystencils
-from pystencils.data_types import PointerType, address_of, cast_func, create_type
+from pystencils.typing import PointerType, address_of, CastFunc, create_type
 from pystencils.simp.simplifications import sympy_cse
 
 
@@ -17,14 +17,14 @@ def test_address_of():
 
     assignments = pystencils.AssignmentCollection({
         s: address_of(x[0, 0]),
-        y[0, 0]: cast_func(s, create_type('int64'))
+        y[0, 0]: CastFunc(s, create_type('int64'))
     }, {})
 
     ast = pystencils.create_kernel(assignments)
     pystencils.show_code(ast)
 
     assignments = pystencils.AssignmentCollection({
-        y[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64'))
+        y[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64'))
     }, {})
 
     ast = pystencils.create_kernel(assignments)
@@ -36,8 +36,8 @@ def test_address_of_with_cse():
     s = pystencils.TypedSymbol('s', PointerType(create_type('int64')))
 
     assignments = pystencils.AssignmentCollection({
-        y[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + s,
-        x[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + 1
+        y[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) + s,
+        x[0, 0]: CastFunc(address_of(x[0, 0]), create_type('int64')) + 1
     }, {})
 
     ast = pystencils.create_kernel(assignments)
diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py
index 9d9f71952..7f3894825 100644
--- a/pystencils_tests/test_complex_numbers.py
+++ b/pystencils_tests/test_complex_numbers.py
@@ -16,7 +16,7 @@ from sympy.functions import im, re
 
 import pystencils
 from pystencils import AssignmentCollection
-from pystencils.data_types import TypedSymbol, create_type
+from pystencils.typing import TypedSymbol, create_type
 
 X, Y = pystencils.fields('x, y: complex64[2d]')
 A, B = pystencils.fields('a, b: float32[2d]')
diff --git a/pystencils_tests/test_cuda_known_functions.py b/pystencils_tests/test_cuda_known_functions.py
index 32b7d9b76..7e465da9f 100644
--- a/pystencils_tests/test_cuda_known_functions.py
+++ b/pystencils_tests/test_cuda_known_functions.py
@@ -5,7 +5,7 @@ import pytest
 import pystencils
 from pystencils.astnodes import get_dummy_symbol
 from pystencils.backends.cuda_backend import CudaSympyPrinter
-from pystencils.data_types import address_of
+from pystencils.typing import address_of
 from pystencils.enums import Target
 
 
diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py
index 596f9f4da..14c751336 100644
--- a/pystencils_tests/test_field.py
+++ b/pystencils_tests/test_field.py
@@ -4,7 +4,7 @@ import sympy as sp
 
 import pystencils as ps
 from pystencils import TypedSymbol
-from pystencils.data_types import create_type
+from pystencils.typing import create_type
 from pystencils.field import Field, FieldType, layout_string_to_tuple
 
 
diff --git a/pystencils_tests/test_floor_ceil_int_optimization.py b/pystencils_tests/test_floor_ceil_int_optimization.py
index 7ec81b05b..ce06f0559 100644
--- a/pystencils_tests/test_floor_ceil_int_optimization.py
+++ b/pystencils_tests/test_floor_ceil_int_optimization.py
@@ -11,7 +11,7 @@
 import sympy as sp
 
 import pystencils
-from pystencils.data_types import create_type
+from pystencils.typing import create_type
 
 
 def test_floor_ceil_int_optimization():
diff --git a/pystencils_tests/test_global_definitions.py b/pystencils_tests/test_global_definitions.py
index c08557018..fa51ccc9d 100644
--- a/pystencils_tests/test_global_definitions.py
+++ b/pystencils_tests/test_global_definitions.py
@@ -2,7 +2,7 @@ import sympy
 
 import pystencils.astnodes
 from pystencils.backends.cbackend import CBackend
-from pystencils.data_types import TypedSymbol
+from pystencils.typing import TypedSymbol
 
 
 class BogusDeclaration(pystencils.astnodes.Node):
diff --git a/pystencils_tests/test_kernel_data_type.py b/pystencils_tests/test_kernel_data_type.py
index 2fbab3ff1..25ca56c2b 100644
--- a/pystencils_tests/test_kernel_data_type.py
+++ b/pystencils_tests/test_kernel_data_type.py
@@ -5,7 +5,7 @@ import pytest
 from sympy.abc import x, y
 
 from pystencils import Assignment, create_kernel, fields, CreateKernelConfig
-from pystencils.transformations import adjust_c_single_precision_type
+from pystencils.typing import adjust_c_single_precision_type
 
 
 @pytest.mark.parametrize("data_type", ("float", "double"))
diff --git a/pystencils_tests/test_match_subs_for_assignment_collection.py b/pystencils_tests/test_match_subs_for_assignment_collection.py
index 9bcc5ad6b..7bb0ec509 100644
--- a/pystencils_tests/test_match_subs_for_assignment_collection.py
+++ b/pystencils_tests/test_match_subs_for_assignment_collection.py
@@ -11,12 +11,12 @@
 import sympy as sp
 
 import pystencils
-from pystencils.data_types import create_type
+from pystencils.typing import create_type
 
 
 def test_wild_typed_symbol():
     x = pystencils.fields('x:  float32[3d]')
-    typed_symbol = pystencils.data_types.TypedSymbol('a', create_type('float64'))
+    typed_symbol = pystencils.typing.TypedSymbol('a', create_type('float64'))
 
     assert x.center().match(sp.Wild('w1'))
     assert typed_symbol.match(sp.Wild('w1'))
diff --git a/pystencils_tests/test_pickle_support.py b/pystencils_tests/test_pickle_support.py
index 462645198..87268a777 100644
--- a/pystencils_tests/test_pickle_support.py
+++ b/pystencils_tests/test_pickle_support.py
@@ -1,7 +1,7 @@
 from copy import copy, deepcopy
 
 from pystencils.field import Field
-from pystencils.data_types import TypedSymbol
+from pystencils.typing import TypedSymbol
 
 
 def test_field_access():
diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py
index d1f509e65..b29f15eb7 100644
--- a/pystencils_tests/test_random.py
+++ b/pystencils_tests/test_random.py
@@ -6,7 +6,7 @@ import pystencils as ps
 from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
 from pystencils.cpu.cpujit import get_compiler_config
-from pystencils.data_types import TypedSymbol
+from pystencils.typing import TypedSymbol
 from pystencils.enums import Target
 
 RNGs = {('philox', 'float'): PhiloxFourFloats, ('philox', 'double'): PhiloxTwoDoubles,
diff --git a/pystencils_tests/test_sum_prod.py b/pystencils_tests/test_sum_prod.py
index 2f6bf7359..235644db2 100644
--- a/pystencils_tests/test_sum_prod.py
+++ b/pystencils_tests/test_sum_prod.py
@@ -13,7 +13,7 @@ import sympy as sp
 import sympy.abc
 
 import pystencils as ps
-from pystencils.data_types import create_type
+from pystencils.typing import create_type
 
 
 @pytest.mark.parametrize('default_assignment_simplifications', [False, True])
diff --git a/pystencils_tests/test_transformations.py b/pystencils_tests/test_transformations.py
index 9b0024980..3ede70a85 100644
--- a/pystencils_tests/test_transformations.py
+++ b/pystencils_tests/test_transformations.py
@@ -1,7 +1,7 @@
 import pystencils as ps
 from pystencils import TypedSymbol
 from pystencils.astnodes import LoopOverCoordinate, SympyAssignment
-from pystencils.data_types import create_type
+from pystencils.typing import create_type
 from pystencils.transformations import filtered_tree_iteration, get_loop_hierarchy, get_loop_counter_symbol_hierarchy
 
 
diff --git a/pystencils_tests/test_type_interference.py b/pystencils_tests/test_type_interference.py
index 953b87742..179fa2836 100644
--- a/pystencils_tests/test_type_interference.py
+++ b/pystencils_tests/test_type_interference.py
@@ -1,14 +1,14 @@
 from sympy.abc import a, b, c, d, e, f
 
 import pystencils
-from pystencils.data_types import cast_func, create_type
+from pystencils.typing import CastFunc, create_type
 
 
 def test_type_interference():
     x = pystencils.fields('x:  float32[3d]')
     assignments = pystencils.AssignmentCollection({
-        a: cast_func(10, create_type('float64')),
-        b: cast_func(10, create_type('uint16')),
+        a: CastFunc(10, create_type('float64')),
+        b: CastFunc(10, create_type('uint16')),
         e: 11,
         c: b,
         f: c + b,
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index 75ba2c5e3..5c2b008e4 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -1,21 +1,9 @@
 import sympy as sp
 import numpy as np
-import pytest
-import ctypes
 
 import pystencils as ps
-from pystencils import data_types
-from pystencils.data_types import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \
-    typed_symbols, type_all_numbers, matrix_symbols, cast_func, pointer_arithmetic_func, ctypes_from_llvm, PointerType
-
-
-def test_parsing():
-    assert str(data_types.create_composite_type_from_string("const double *")) == "double const *"
-    assert str(data_types.create_composite_type_from_string("double const *")) == "double const *"
-
-    t1 = data_types.create_composite_type_from_string("const double * const * const restrict")
-    t2 = data_types.create_composite_type_from_string(str(t1))
-    assert t1 == t2
+from pystencils.typing import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \
+    typed_symbols, type_all_numbers, matrix_symbols, CastFunc, PointerArithmeticFunc, PointerType
 
 
 def test_collation():
@@ -133,7 +121,7 @@ def test_Basic_data_type():
     assert typed_symbols("s", bool).dtype.is_other()
     assert typed_symbols("s", np.void).dtype.is_other()
 
-    assert typed_symbols("s", np.float64).dtype.base_name == 'double'
+    assert typed_symbols("s", np.float64).dtype.c_name == 'double'
     # removed for old sympy version
     # assert typed_symbols(("s"), np.float64).dtype.sympy_dtype == typed_symbols(("s"), float).dtype.sympy_dtype
 
@@ -157,15 +145,15 @@ def test_Basic_data_type():
 
 
 def test_cast_func():
-    assert cast_func(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical
+    assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical
 
-    a = cast_func(5, np.uint)
+    a = CastFunc(5, np.uint)
     assert a.is_negative is False
     assert a.is_nonnegative
 
 
 def test_pointer_arithmetic_func():
-    assert pointer_arithmetic_func(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical
+    assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical
 
 
 def test_division():
-- 
GitLab