diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 4d9a77e479ca1602d67157c56f7cad197334281e..7b819db3cb3b3097a32f22657a85b02ae6b981e5 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -6,10 +6,10 @@ from typing import Any, List, Optional, Sequence, Set, Union
 import sympy as sp
 
 import pystencils
-from pystencils.typing import TypedSymbol, CastFunc, create_type, get_next_parent_of_type
+from pystencils.typing import create_type, get_next_parent_of_type, CastFunc
 from pystencils.enums import Target, Backend
 from pystencils.field import Field
-from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
+from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
 from pystencils.sympyextensions import fast_subs
 
 NodeOrExpr = Union['Node', sp.Expr]
diff --git a/pystencils/boundaries/boundaryhandling.py b/pystencils/boundaries/boundaryhandling.py
index 4ad3ab3ffba2d4d8e774121ab3123a4683d9ccf9..52a314d75e5783085d73772439d91d0c44fb4d02 100644
--- a/pystencils/boundaries/boundaryhandling.py
+++ b/pystencils/boundaries/boundaryhandling.py
@@ -10,7 +10,7 @@ from pystencils.cache import memorycache
 from pystencils.typing import TypedSymbol, create_type
 from pystencils.datahandling.pycuda import PyCudaArrayHandler
 from pystencils.field import Field
-from pystencils.kernelparameters import FieldPointerSymbol
+from pystencils.typing.typed_sympy import FieldPointerSymbol
 
 try:
     # noinspection PyPep8Naming
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index f2dc0ff933a1e51d6bfc67cc0b90963b478f010d..77a4a7d7920b2a64fb5129d86e9e342c9bd63906 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -59,7 +59,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         else:
             raise ValueError("Term has to be field access or symbol")
 
-    # TODO 1) check kernel
+    # TODO 1) check kernel -> do general checks elsewhere
     # TODO 2) add leaf types
     fields_read, fields_written, assignments = add_types(
         assignments, type_info, not skip_independence_check, check_double_write_condition=not allow_double_writes)
diff --git a/pystencils/field.py b/pystencils/field.py
index 91b33eed39c45998256b9e58e1705404e1f7d437..4a29a1be201cf34cc396e5da355dc77fdba23446 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -319,7 +319,7 @@ class Field:
         assert isinstance(field_type, FieldType)
         assert len(shape) == len(strides)
         self.field_type = field_type
-        self._dtype = create_type(dtype)
+        self._dtype = create_type(dtype)  # TODO do we have AoS???
         self._layout = normalize_layout(layout)
         self.shape = shape
         self.strides = strides
@@ -619,7 +619,7 @@ class Field:
         self.coordinate_origin = -sp.Matrix([i / 2 for i in self.spatial_shape])
 
     # noinspection PyAttributeOutsideInit,PyUnresolvedReferences
-    class Access(TypedSymbol, Field.Access):
+    class Access(TypedSymbol):
         """Class representing a relative access into a `Field`.
 
         This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index a13297e0d7a222f40af25ccefb2623304a9f2f62..b6fb901750895b341d44fde26040ff3b91d0e9e9 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -6,7 +6,7 @@ from pystencils.typing import StructType
 from pystencils.field import FieldType
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
 from pystencils.kernel_wrapper import KernelWrapper
-from pystencils.kernelparameters import FieldPointerSymbol
+from pystencils.typing.typed_sympy import FieldPointerSymbol
 
 USE_FAST_MATH = True
 
diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py
index 842e70ad93cbdc3cd16710c212ecfd51b71b4456..a2b0740c987098a74984789f4c76d7bf35445f83 100644
--- a/pystencils/kernel_contrains_check.py
+++ b/pystencils/kernel_contrains_check.py
@@ -9,8 +9,15 @@ from pystencils.field import Field
 from pystencils.transformations import NestedScopes
 
 
+accepted_functions = [
+    sp.Pow,
+    sp.sqrt,  # TODO why not a class??
+    # TODO trigonometric functions
+]
+
+
 class KernelConstraintsCheck:
-    # TODO: specification
+    # TODO: proper specification
     # TODO: More checks :)
     """Checks if the input to create_kernel is valid.
 
@@ -26,28 +33,52 @@ class KernelConstraintsCheck:
     """
     FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
 
-    def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
-        self._type_for_symbol = type_for_symbol
-
+    def __init__(self, check_independence_condition, check_double_write_condition=True):
         self.scopes = NestedScopes()
         self.field_writes = defaultdict(set)
         self.fields_read = set()
         self.check_independence_condition = check_independence_condition
         self.check_double_write_condition = check_double_write_condition
 
+    def visit(self, obj):
+        if isinstance(obj, (list, tuple)):
+            [self.visit(e) for e in obj]
+        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
+            self.process_assignment(obj)
+        elif isinstance(obj, ast.Conditional):
+            self.scopes.push()
+            # Disable double write check inside conditionals
+            # would be triggered by e.g. in-kernel boundaries
+            old_double_write = self.check_double_write_condition
+            self.check_double_write_condition = False
+            if obj.false_block:
+                self.visit(obj.false_block)
+            self.process_expression(obj.condition_expr)
+            self.check_double_write_condition = old_double_write
+            self.scopes.pop()
+        elif isinstance(obj, ast.Block):
+            self.scopes.push()
+            [self.visit(e) for e in obj.args]
+            self.scopes.pop()
+        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
+            pass
+        else:
+            raise ValueError(f'Invalid object in kernel {type(obj)}')
+
     def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]):
         # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
         self.process_expression(assignment.rhs)
         self.process_lhs(assignment.lhs)
 
-    def process_expression(self, rhs, type_constants=True):
+    def process_expression(self, rhs):
+        # TODO constraint for accepted functions
         self.update_accesses_rhs(rhs)
         if isinstance(rhs, Field.Access):
             self.fields_read.add(rhs.field)
             self.fields_read.update(rhs.indirect_addressing_fields)
         else:
             for arg in rhs.args:
-                self.process_expression(arg, type_constants)
+                self.process_expression(arg)
 
     @property
     def fields_written(self):
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 1746a8b9994292bee2b74aeaa4aacefb4931f5f7..0a9aea653fc716e2ce2f5c129cfb62f30e7c44f9 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -11,7 +11,7 @@ from sympy.core.numbers import Zero
 
 from pystencils.assignment import Assignment
 from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
-from pystencils.kernelparameters import FieldPointerSymbol
+from pystencils.typing.typed_sympy import FieldPointerSymbol
 
 T = TypeVar('T')
 
diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py
index 55fb731c0bd45c006d38c72cd20558fdf2dd6d17..2221b812b82e08976e2b4bdc73a1181605a0fcad 100644
--- a/pystencils/typing/__init__.py
+++ b/pystencils/typing/__init__.py
@@ -1,4 +1,6 @@
-from pystencils.typing.utilities import *
+
+
 from pystencils.typing.types import *
 from pystencils.typing.typed_sympy import *
 from pystencils.typing.cast_functions import *
+from pystencils.typing.utilities import *
diff --git a/pystencils/leaf_typing.py b/pystencils/typing/leaf_typing.py
similarity index 51%
rename from pystencils/leaf_typing.py
rename to pystencils/typing/leaf_typing.py
index 789bb4a8d8601e6a8cbabb5c87277c9e3ddc15c9..b6ef0362f17b8da213ea2157b11ebb19ead784c5 100644
--- a/pystencils/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -1,5 +1,5 @@
-from collections import namedtuple, defaultdict
-from typing import List, Union
+from collections import namedtuple
+from typing import Union, Dict, Tuple, Any
 
 import numpy as np
 
@@ -9,13 +9,13 @@ import sympy as sp
 from pystencils import astnodes as ast, TypedSymbol
 from pystencils.bit_masks import flag_cond
 from pystencils.field import Field
-from pystencils.transformations import NestedScopes
-from pystencils.typing import CastFunc, create_type, get_type_of_expression, collate_types
+from pystencils.typing import AbstractType, BasicType, CastFunc, create_type, get_type_of_expression, collate_types
+from pystencils.utils import ContextVar
 from sympy.codegen import Assignment
 from sympy.logic.boolalg import BooleanFunction
 
 
-class KernelConstraintsCheck: # TODO rename
+class TypeAdder:
     # TODO: Logs
     # TODO: specification
     # TODO: split this into checker and leaf typing
@@ -33,33 +33,95 @@ class KernelConstraintsCheck: # TODO rename
     """
     FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
 
-    def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
-        self._type_for_symbol = type_for_symbol
-
-        self.scopes = NestedScopes()
-        self.field_writes = defaultdict(set)
-        self.fields_read = set()
-        self.check_independence_condition = check_independence_condition
-        self.check_double_write_condition = check_double_write_condition
+    def __init__(self, default_symbol_type: BasicType, type_for_symbol: Dict[str, BasicType],
+                 default_number_float: BasicType, default_number_int: BasicType):
+        self.type_for_symbol = type_for_symbol
+        self.default_symbol_type = ContextVar(default_symbol_type)
+        self.default_number_float = ContextVar(default_number_float)
+        self.default_number_int = ContextVar(default_number_int)
+
+    def get_symbol_type(self, symbol: str) -> BasicType:
+        return self.type_for_symbol.get(symbol, self.default_symbol_type.get())
+
+    # TODO: check if this adds only types to leave nodes of AST, get type info
+    def visit(self, obj):
+        if isinstance(obj, (list, tuple)):
+            return [self.visit(e) for e in obj]
+        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
+            return self.process_assignment(obj)
+        elif isinstance(obj, ast.Conditional):
+            false_block = None if obj.false_block is None else self.visit(
+                obj.false_block)
+            result = ast.Conditional(self.process_expression(
+                obj.condition_expr, type_constants=False),
+                true_block=self.visit(obj.true_block),
+                false_block=false_block)
+            return result
+        elif isinstance(obj, ast.Block):
+            result = ast.Block([self.visit(e) for e in obj.args])
+            return result
+        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
+            return obj
+        else:
+            raise ValueError("Invalid object in kernel " + str(type(obj)))
 
     def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment:
         # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
         new_rhs = self.process_expression(assignment.rhs)
+        # TODO check type rhs lhs
         new_lhs = self.process_lhs(assignment.lhs)
         return ast.SympyAssignment(new_lhs, new_rhs)
 
+    # Type System Specification
+    # - Defined Types: TypedSymbol, Field, Field.Access, ...?
+    # - Indexed: always unsigned_integer64
+    # - Undefined Types: Symbol - Is specified in Config in the dict or as 'default_type'
+    # - Constants/Numbers: Are either integer or floating. The precision and sign is specified via config
+    #       - Example: 1.4 config:float32 -> float32
+    # - Expressions deduce types from arguments
+    # - Functions deduce types from arguments
+    # - default_type and default_float and default_int can be given for a list of assignment, or
+    #   individually as a list for assignment
+
+    # Possible Problems - Do we need to support this?
+    # - Mixture in expression with int and float
+    # - Mixture in expression with uint64 and sint64
+
+    def figure_out_type(self, expr) -> Tuple[Any, BasicType]:  #TODO or abstract type?
+        # Trivial cases
+        if isinstance(expr, Field.Access):
+            return expr, expr.dtype
+        elif isinstance(expr, TypedSymbol):
+            return expr, expr.dtype
+        elif isinstance(expr, sp.Symbol):
+            t = TypedSymbol(expr.name, self.get_symbol_type(expr.name))  # TODO with or without name
+            return t, t.dtype
+        elif isinstance(expr, np.generic):
+            assert False, f'Why do we have a np.generic in rhs???? {expr}'
+        elif isinstance(expr, sp.Number):
+            if expr.is_Float:
+                data_type = self.default_number_float.get()
+            elif expr.is_Integer:
+                data_type = self.default_number_int.get()
+            return expr, data_type
+        # TODO add everything in between
+        elif isinstance(expr, sp.Mul):
+            # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
+            types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
+            return None  # TODO collate types
+        elif isinstance(expr, sp.Indexed):
+            self.apply_type(expr, BasicType('uintp'))  # TODO double check
+            return None
+        elif isinstance(expr, sp.Pow):
+            # TODO sp.Pow should know a type
+            return None  # TODO
+        else:
+            types = [self.figure_out_type(arg) for arg in expr.args]
+            # TODO collate
+            return None  # TODO
 
-    # Expression
-    # 1) ask children if they are cocksure about a type
-    # 1b) Postpone clueless children (see 5)
-    # cocksure: Children have somewhere type from Field.Access, TypedSymbol, CastFunction or Function^TM
-    # clueless: Children without Field.Access,...
-    # 1c) none child is cocksure -> do nothing a return None, wait for recall from parent
-    # 2) collate_type of children
-    # 3) apply collated type on children
-    # 4) issue warnings of casts on cocksure children
-    # 5a) resume on clueless children with the collated type as default datatype, issue warning
-    # 5b) or apply special circumstances
+    def apply_type(self, expr, data_type: AbstractType):
+        pass
 
     def process_expression(self, rhs, type_constants=True):  # TODO default_type as parameter
         if isinstance(rhs, Field.Access):
@@ -115,13 +177,6 @@ class KernelConstraintsCheck: # TODO rename
             new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
             return rhs.func(*new_args) if new_args else rhs
 
-    @property
-    def fields_written(self):
-        """
-        Return all rhs fields
-        """
-        return set(k.field for k, v in self.field_writes.items() if len(v))
-
     def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]):
         if not isinstance(lhs, (Field.Access, TypedSymbol)):
             return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
diff --git a/pystencils/typing/typed_sympy.py b/pystencils/typing/typed_sympy.py
index 0a253f748082aa7baf3359a716dcd0a873cb02fb..dffffe9e26763e0474a3d5ec3a5d59c28c3a1270 100644
--- a/pystencils/typing/typed_sympy.py
+++ b/pystencils/typing/typed_sympy.py
@@ -5,7 +5,6 @@ import sympy as sp
 from sympy.core.cache import cacheit
 
 from pystencils.typing.types import BasicType, create_type, PointerType
-from pystencils.typing.utilities import get_base_type
 
 
 def assumptions_from_dtype(dtype: Union[BasicType, np.dtype]):
@@ -44,6 +43,7 @@ class TypedSymbol(sp.Symbol):
         return obj
 
     def __new_stage2__(cls, name, dtype, **kwargs):  # TODO does not match signature of sp.Symbol???
+        # TODO: also Symbol should be allowed  ---> see sympy Variable
         assumptions = assumptions_from_dtype(dtype)  # TODO should by dtype a np.dtype or our Type???
         assumptions.update(kwargs)
         obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
@@ -59,10 +59,10 @@ class TypedSymbol(sp.Symbol):
 
     @property
     def dtype(self):
-        return self._dtype
+        return self.numpy_dtype
 
     def _hashable_content(self):
-        return super()._hashable_content(), hash(self._dtype)
+        return super()._hashable_content(), hash(self.numpy_dtype)
 
     def __getnewargs__(self):
         return self.name, self.dtype
@@ -160,6 +160,8 @@ class FieldPointerSymbol(TypedSymbol):
         return obj
 
     def __new_stage2__(cls, field_name, field_dtype, const):
+        from pystencils.typing.utilities import get_base_type
+
         name = f"_data_{field_name}"
         dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
         obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
index eabe87dbd837940f0ec48edb33962befb482d354..318b0932ac1733c20f474721c26a6246b78d874a 100644
--- a/pystencils/typing/types.py
+++ b/pystencils/typing/types.py
@@ -38,7 +38,7 @@ def numpy_name_to_c(name: str) -> str:
         raise NotImplementedError(f"Can't map numpy to C name for {name}")
 
 
-class AbstractType(sp.Atom, ABC):
+class AbstractType(sp.Atom):
     # TODO: inherits from sp.Atom because of cast function (and maybe others)
     # TODO: is this necessary?
     def __new__(cls, *args, **kwargs):
diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py
index a7a506f0cbfcd724d0d9f6b0f6e900568f94dda1..2f3d175daa3aa80d8138977d69a8a9c06a6543ef 100644
--- a/pystencils/typing/utilities.py
+++ b/pystencils/typing/utilities.py
@@ -1,18 +1,16 @@
 from collections import defaultdict
 from functools import partial
-from typing import Tuple, Union, List, Dict
+from typing import Tuple, List, Dict
 
 import numpy as np
 import sympy as sp
-from pystencils import astnodes as ast
-from pystencils.kernel_contrains_check import KernelConstraintsCheck
+# from pystencils.typing.leaf_typing import TypeAdder  # TODO this should be leaf_typing
 from sympy.codegen import Assignment
 from sympy.logic.boolalg import Boolean, BooleanFunction
 
 import pystencils
-from pystencils.cache import memorycache, memorycache_if_hashable
-from pystencils.utils import all_equal
-from pystencils.typing.types import AbstractType, BasicType, VectorType, PointerType, StructType, create_type
+from pystencils.cache import memorycache_if_hashable
+from pystencils.typing.types import BasicType, VectorType, PointerType, create_type
 from pystencils.typing.cast_functions import CastFunc, PointerArithmeticFunc
 from pystencils.typing.typed_sympy import TypedSymbol
 
@@ -74,49 +72,53 @@ def peel_off_type(dtype, type_to_peel_off):
     return dtype
 
 
-
 ############################# This is basically our type system ########################################################
-def collate_types(types,
-                  forbid_collation_to_complex=False,  # TODO: type system shouldn't need this!!!
-                  forbid_collation_to_float=False,  # TODO: type system shouldn't need this!!!
-                  default_float_type='float64',
-                  # TODO: AST leaves should be typed. Expressions should be able to find out correct type
-                  default_int_type='int64'):  # TODO: AST leaves should be typed. Expressions should be able to find out correct type
+
+def result_type(*args: np.dtype):
+    s = sorted(args, key=lambda x: x.itemsize)
+
+    def kind_to_value(kind: str) -> int:
+        if kind == 'f':
+            return 3
+        elif kind == 'i':
+            return 2
+        elif kind == 'u':
+            return 1
+        elif kind == 'b':
+            return 0
+        else:
+            raise NotImplementedError(f'{kind=} is not a supported kind of a type. See "numpy.dtype.kind" for options')
+    s = sorted(s, key=lambda x: kind_to_value(x.kind))
+    return s[-1]
+
+
+def collate_types(types):
     """
     Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
     Uses the collation rules from numpy.
     """
     # TODO: use np.can_cast and np.promote_types and np.result_type and np.find_common_type
-    if forbid_collation_to_complex:
-        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
-        if not types:
-            return create_type(default_float_type)
-
-    if forbid_collation_to_float:
-        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
-        if not types:
-            return create_type(default_int_type)
-
-    # Pointer arithmetic case i.e. pointer + integer is allowed
-    if any(type(t) is PointerType for t in types):
-        pointer_type = None
-        for t in types:
-            if type(t) is PointerType:
-                if pointer_type is not None:
-                    raise ValueError("Cannot collate the combination of two pointer types")
-                pointer_type = t
-            elif type(t) is BasicType:
-                if not (t.is_int() or t.is_uint()):
-                    raise ValueError("Invalid pointer arithmetic")
-            else:
-                raise ValueError("Invalid pointer arithmetic")
-        return pointer_type
 
-    # peel of vector types, if at least one vector type occurred the result will also be the vector type
-    vector_type = [t for t in types if type(t) is VectorType]
-    if not all_equal(t.width for t in vector_type):
-        raise ValueError("Collation failed because of vector types with different width")
-    types = [peel_off_type(t, VectorType) for t in types]
+    # # Pointer arithmetic case i.e. pointer + integer is allowed
+    # if any(type(t) is PointerType for t in types):
+    #     pointer_type = None
+    #     for t in types:
+    #         if type(t) is PointerType:
+    #             if pointer_type is not None:
+    #                 raise ValueError("Cannot collate the combination of two pointer types")
+    #             pointer_type = t
+    #         elif type(t) is BasicType:
+    #             if not (t.is_int() or t.is_uint()):
+    #                 raise ValueError("Invalid pointer arithmetic")
+    #         else:
+    #             raise ValueError("Invalid pointer arithmetic")
+    #     return pointer_type
+    #
+    # # peel of vector types, if at least one vector type occurred the result will also be the vector type
+    # vector_type = [t for t in types if type(t) is VectorType]
+    # if not all_equal(t.width for t in vector_type):
+    #     raise ValueError("Collation failed because of vector types with different width")
+    # types = [peel_off_type(t, VectorType) for t in types]
 
     # now we should have a list of basic types - struct types are not yet supported
     assert all(type(t) is BasicType for t in types)
@@ -126,8 +128,8 @@ def collate_types(types,
     # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
     result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
     result = BasicType(result_numpy_type)
-    if vector_type:
-        result = VectorType(result, vector_type[0].width)
+    # if vector_type:
+    #     result = VectorType(result, vector_type[0].width)
     return result
 
 
@@ -166,6 +168,7 @@ def get_type_of_expression(expr,
     elif isinstance(expr, TypedSymbol):
         return expr.dtype
     elif isinstance(expr, sp.Symbol):
+        # TODO delete if case
         if symbol_type_dict:
             return symbol_type_dict[expr.name]
         else:
@@ -288,36 +291,7 @@ def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype],
     check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
                                    check_double_write_condition=check_double_write_condition)
 
-    # TODO: check if this adds only types to leave nodes of AST, get type info
-    def visit(obj):
-        if isinstance(obj, (list, tuple)):
-            return [visit(e) for e in obj]
-        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
-            return check.process_assignment(obj)
-        elif isinstance(obj, ast.Conditional):
-            check.scopes.push()
-            # Disable double write check inside conditionals
-            # would be triggered by e.g. in-kernel boundaries
-            old_double_write = check.check_double_write_condition
-            check.check_double_write_condition = False
-            false_block = None if obj.false_block is None else visit(
-                obj.false_block)
-            result = ast.Conditional(check.process_expression(
-                obj.condition_expr, type_constants=False),
-                true_block=visit(obj.true_block),
-                false_block=false_block)
-            check.check_double_write_condition = old_double_write
-            check.scopes.pop()
-            return result
-        elif isinstance(obj, ast.Block):
-            check.scopes.push()
-            result = ast.Block([visit(e) for e in obj.args])
-            check.scopes.pop()
-            return result
-        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
-            return obj
-        else:
-            raise ValueError("Invalid object in kernel " + str(type(obj)))
+
 
     typed_equations = visit(eqs)
 
@@ -333,6 +307,8 @@ def insert_casts(node):
     Returns:
         modified AST
     """
+    from pystencils.astnodes import SympyAssignment, ResolvedFieldAccess, LoopOverCoordinate, Block
+
     def cast(zipped_args_types, target_dtype):
         """
         Adds casts to the arguments if their type differs from the target type
@@ -385,7 +361,7 @@ def insert_casts(node):
             return pointer_arithmetic(zipped)
         else:
             return node.func(*cast(zipped, target))
-    elif node.func is ast.SympyAssignment:
+    elif node.func is SympyAssignment:
         lhs = args[0]
         rhs = args[1]
         target = get_type_of_expression(lhs)
@@ -393,13 +369,13 @@ def insert_casts(node):
             return node.func(*args)  # TODO fix, not complete
         else:
             return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
-    elif node.func is ast.ResolvedFieldAccess:
+    elif node.func is ResolvedFieldAccess:
         return node
-    elif node.func is ast.Block:
+    elif node.func is Block:
         for old_arg, new_arg in zip(node.args, args):
             node.replace(old_arg, new_arg)
         return node
-    elif node.func is ast.LoopOverCoordinate:
+    elif node.func is LoopOverCoordinate:
         for old_arg, new_arg in zip(node.args, args):
             node.replace(old_arg, new_arg)
         return node
@@ -464,18 +440,19 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
     Returns:
         dictionary, mapping symbol name to type
     """
+    from pystencils.astnodes import SympyAssignment, Conditional, Node
     result = defaultdict(lambda: default_type)
     if hasattr(default_type, 'numpy_dtype'):
         result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
     else:
         result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
     for eq in eqs:
-        if isinstance(eq, ast.Conditional):
+        if isinstance(eq, Conditional):
             result.update(typing_from_sympy_inspection(eq.true_block.args))
             if eq.false_block:
                 result.update(typing_from_sympy_inspection(
                     eq.false_block.args))
-        elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
+        elif isinstance(eq, Node) and not isinstance(eq, SympyAssignment):
             continue
         else:
             from pystencils.cpu.vectorization import vec_all, vec_any
diff --git a/pystencils/utils.py b/pystencils/utils.py
index 3afdbc582ef7dece1933dbaf5b00be149f9cbd30..dc8d35ee64dcfdb0ef6f9f687526fe3379ce8fbd 100644
--- a/pystencils/utils.py
+++ b/pystencils/utils.py
@@ -220,3 +220,17 @@ class LinearEquationSystem:
                 break
             result -= 1
         self.next_zero_row = result
+
+
+class ContextVar:
+    def __init__(self, value):
+        self.stack = [value]
+
+    @contextmanager
+    def __call__(self, new_value):
+        self.stack.append(new_value)
+        yield self
+        self.stack.pop()
+
+    def get(self):
+        return self.stack[-1]
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index 5c2b008e4ba4b0bd1fbf28e96fbef8affeef0e4c..774306d8d7c36ab2ff4a026e89f4ed4e78785c4e 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -3,7 +3,40 @@ import numpy as np
 
 import pystencils as ps
 from pystencils.typing import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type, \
-    typed_symbols, type_all_numbers, matrix_symbols, CastFunc, PointerArithmeticFunc, PointerType
+    typed_symbols, CastFunc, PointerArithmeticFunc, PointerType, result_type
+
+
+def test_result_type():
+    i = np.dtype('int32')
+    l = np.dtype('int64')
+    ui = np.dtype('uint32')
+    ul = np.dtype('uint64')
+    f = np.dtype('float32')
+    d = np.dtype('float64')
+    b = np.dtype('bool')
+
+    assert result_type(i, l) == l
+    assert result_type(l, i) == l
+    assert result_type(ui, i) == i
+    assert result_type(ui, l) == l
+    assert result_type(ul, i) == i
+    assert result_type(ul, l) == l
+    assert result_type(d, f) == d
+    assert result_type(f, d) == d
+    assert result_type(i, f) == f
+    assert result_type(l, f) == f
+    assert result_type(ui, f) == f
+    assert result_type(ul, f) == f
+    assert result_type(i, d) == d
+    assert result_type(l, d) == d
+    assert result_type(ui, d) == d
+    assert result_type(ul, d) == d
+    assert result_type(b, i) == i
+    assert result_type(b, l) == l
+    assert result_type(b, ui) == ui
+    assert result_type(b, ul) == ul
+    assert result_type(b, f) == f
+    assert result_type(b, d) == d
 
 
 def test_collation():