diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 37bc43f979110c4393eabc636793215d713bc1f7..c49077bb27a2af158624e7222d999973487f0951 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -101,6 +101,41 @@ minimal-sympy-master:
   tags:
     - docker
 
+
+pycodegen-integration:
+  image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full
+  stage: test
+  when: manual
+  script:
+    - git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@i10git.cs.fau.de/pycodegen/pycodegen.git
+    - cd pycodegen
+    - git submodule sync --recursive
+    - git submodule update --init --recursive
+    - git submodule foreach git fetch origin   # compare the latest master version!
+    - git submodule foreach git reset --hard origin/master
+    - cd pystencils
+    - git remote add test $CI_REPOSITORY_URL
+    - git fetch test
+    - git reset --hard $CI_COMMIT_SHA
+    - cd ..
+    - export PYTHONPATH=`pwd`/pystencils:`pwd`/lbmpy:`pwd`/pygrandchem:`pwd`/pystencils_walberla:`pwd`/lbmpy_walberla
+    - ./install_walberla.sh
+    - export NUM_CORES=$(nproc --all)
+    - mkdir -p ~/.config/matplotlib
+    - echo "backend:template" > ~/.config/matplotlib/matplotlibrc
+    - cd pystencils
+    - py.test -v -n $NUM_CORES .
+    - cd ../lbmpy
+    - py.test -v -n $NUM_CORES .
+    - cd ../pygrandchem
+    - py.test -v -n $NUM_CORES .
+    - cd ../walberla/build/
+    - make CodegenJacobiCPU CodegenJacobiGPU MicroBenchmarkGpuLbm LbCodeGenerationExample
+  tags:
+    - docker
+    - cuda
+    - AVX
+
 # -------------------- Linter & Documentation --------------------------------------------------------------------------
 
 
diff --git a/pystencils/cache.py b/pystencils/cache.py
index ecf9727f0aba74b2b5cccae8d0ed0d781cac7ccf..5df15ae7c7498e6e849c93f2f071435560a2c415 100644
--- a/pystencils/cache.py
+++ b/pystencils/cache.py
@@ -1,10 +1,14 @@
 import os
+from collections import Hashable
+from functools import partial
+from itertools import chain
 
 try:
     from functools import lru_cache as memorycache
 except ImportError:
     from backports.functools_lru_cache import lru_cache as memorycache
 
+
 try:
     from joblib import Memory
     from appdirs import user_cache_dir
@@ -22,6 +26,20 @@ except ImportError:
         return o
 
 
+def _wrapper(wrapped_func, cached_func, *args, **kwargs):
+    if all(isinstance(a, Hashable) for a in chain(args, kwargs.values())):
+        return cached_func(*args, **kwargs)
+    else:
+        return wrapped_func(*args, **kwargs)
+
+
+def memorycache_if_hashable(maxsize=128, typed=False):
+
+    def wrapper(func):
+        return partial(_wrapper, func, memorycache(maxsize, typed)(func))
+
+    return wrapper
+
 # Disable memory cache:
 # disk_cache = lambda o: o
 # disk_cache_no_fallback = lambda o: o
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index fc7f8a4f0e916922297a7314d432627503ba8464..45099302b9615164fe329d21f8f85c631d45b1e3 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -1,11 +1,13 @@
 import ctypes
+from collections import defaultdict
+from functools import partial
 
 import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.logic.boolalg import Boolean
 
-from pystencils.cache import memorycache
+from pystencils.cache import memorycache, memorycache_if_hashable
 from pystencils.utils import all_equal
 
 try:
@@ -408,11 +410,22 @@ def collate_types(types):
     return result
 
 
-@memorycache(maxsize=2048)
-def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
+@memorycache_if_hashable(maxsize=2048)
+def get_type_of_expression(expr,
+                           default_float_type='double',
+                           default_int_type='int',
+                           symbol_type_dict=None):
     from pystencils.astnodes import ResolvedFieldAccess
     from pystencils.cpu.vectorization import vec_all, vec_any
 
+    if not symbol_type_dict:
+        symbol_type_dict = defaultdict(lambda: create_type('double'))
+
+    get_type = partial(get_type_of_expression,
+                       default_float_type=default_float_type,
+                       default_int_type=default_int_type,
+                       symbol_type_dict=symbol_type_dict)
+
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
         return create_type(default_int_type)
@@ -423,14 +436,17 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
     elif isinstance(expr, TypedSymbol):
         return expr.dtype
     elif isinstance(expr, sp.Symbol):
-        raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
+        if symbol_type_dict:
+            return symbol_type_dict[expr.name]
+        else:
+            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
     elif isinstance(expr, cast_func):
         return expr.args[1]
-    elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
+    elif isinstance(expr, (vec_any, vec_all)):
         return create_type("bool")
     elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
-        collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
-        collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
+        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
+        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
         if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
             collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
         return collated_result_type
@@ -440,16 +456,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
     elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
         # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
         result = create_type("bool")
-        vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)]
+        vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
         if vec_args:
             result = VectorType(result, width=vec_args[0].width)
         return result
-    elif isinstance(expr, sp.Pow):
-        return get_type_of_expression(expr.args[0])
+    elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
+        return get_type(expr.args[0])
     elif isinstance(expr, sp.Expr):
         expr: sp.Expr
         if expr.args:
-            types = tuple(get_type_of_expression(a) for a in expr.args)
+            types = tuple(get_type(a) for a in expr.args)
             return collate_types(types)
         else:
             if expr.is_integer:
diff --git a/pystencils/test_type_interference.py b/pystencils/test_type_interference.py
new file mode 100644
index 0000000000000000000000000000000000000000..0daa6f9d2a36948184d537809457b4aaf8001d29
--- /dev/null
+++ b/pystencils/test_type_interference.py
@@ -0,0 +1,26 @@
+from sympy.abc import a, b, c, d, e, f
+
+import pystencils
+from pystencils.data_types import cast_func, create_type
+
+
+def test_type_interference():
+    x = pystencils.fields('x:  float32[3d]')
+    assignments = pystencils.AssignmentCollection({
+        a: cast_func(10, create_type('float64')),
+        b: cast_func(10, create_type('uint16')),
+        e: 11,
+        c: b,
+        f: c + b,
+        d: c + b + x.center + e,
+        x.center: c + b + x.center
+    })
+
+    ast = pystencils.create_kernel(assignments)
+
+    code = str(pystencils.show_code(ast))
+    print(code)
+    assert 'double a' in code
+    assert 'uint16_t b' in code
+    assert 'uint16_t f' in code
+    assert 'int64_t e' in code
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 39a9abddbad8dab1ae6c4e3361a43ff1885c4ee8..554b9cf9b2257e6b0ef97b24239ea210baa8b63b 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -147,7 +147,10 @@ def get_common_shape(field_set):
     return shape
 
 
-def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
+def make_loop_over_domain(body,
+                          iteration_slice=None,
+                          ghost_layers=None,
+                          loop_order=None):
     """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
 
     Args:
@@ -189,17 +192,21 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
         if iteration_slice is None:
             begin = ghost_layers[loop_coordinate][0]
             end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
-            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
+            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate,
+                                              begin, end, 1)
             current_body = ast.Block([new_loop])
         else:
             slice_component = iteration_slice[loop_coordinate]
             if type(slice_component) is slice:
                 sc = slice_component
-                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
+                new_loop = ast.LoopOverCoordinate(current_body,
+                                                  loop_coordinate, sc.start,
+                                                  sc.stop, sc.step)
                 current_body = ast.Block([new_loop])
             else:
-                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
-                                                 sp.sympify(slice_component))
+                assignment = ast.SympyAssignment(
+                    ast.LoopOverCoordinate.get_loop_counter_symbol(
+                        loop_coordinate), sp.sympify(slice_component))
                 current_body.insert_front(assignment)
 
     return current_body, ghost_layers
@@ -238,9 +245,11 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
         offset += field.strides[coordinate_id] * coordinate_value
 
         if coordinate_id < field.spatial_dimensions:
-            offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
+            offset += field.strides[coordinate_id] * field_access.offsets[
+                coordinate_id]
             if type(field_access.offsets[coordinate_id]) is int:
-                name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
+                name += "_%d%d" % (coordinate_id,
+                                   field_access.offsets[coordinate_id])
             else:
                 list_to_hash.append(field_access.offsets[coordinate_id])
         else:
@@ -257,7 +266,8 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
     return new_ptr, offset
 
 
-def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
+def parse_base_pointer_info(base_pointer_specification, loop_order,
+                            spatial_dimensions, index_dimensions):
     """
     Creates base pointer specification for :func:`resolve_field_accesses` function.
 
@@ -295,11 +305,13 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
 
         def add_new_element(elem):
             if elem >= spatial_dimensions + index_dimensions:
-                raise ValueError("Coordinate %d does not exist" % (elem,))
+                raise ValueError("Coordinate %d does not exist" % (elem, ))
             new_group.append(elem)
             if elem in specified_coordinates:
-                raise ValueError("Coordinate %d specified two times" % (elem,))
+                raise ValueError("Coordinate %d specified two times" %
+                                 (elem, ))
             specified_coordinates.add(elem)
+
         for element in spec_group:
             if type(element) is int:
                 add_new_element(element)
@@ -320,7 +332,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
                 index = int(element[len("index"):])
                 add_new_element(spatial_dimensions + index)
             else:
-                raise ValueError("Unknown specification %s" % (element,))
+                raise ValueError("Unknown specification %s" % (element, ))
 
         result.append(new_group)
 
@@ -345,30 +357,42 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         base buffer index - required by 'resolve_buffer_accesses' function
     """
     if loop_counters is None or loop_iterations is None:
-        loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)]
+        loops = [
+            l for l in filtered_tree_iteration(
+                ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)
+        ]
         loops.reverse()
-        parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
+        parents_of_innermost_loop = list(
+            parents_of_type(loops[0],
+                            ast.LoopOverCoordinate,
+                            include_current=True))
         assert len(loops) == len(parents_of_innermost_loop)
-        assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
+        assert all(l1 is l2
+                   for l1, l2 in zip(loops, parents_of_innermost_loop))
 
         loop_iterations = [(l.stop - l.start) / l.step for l in loops]
         loop_counters = [l.loop_counter_symbol for l in loops]
 
     field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
-    buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
+    buffer_accesses = {
+        fa
+        for fa in field_accesses if FieldType.is_buffer(fa.field)
+    }
     loop_counters = [v * len(buffer_accesses) for v in loop_counters]
 
     base_buffer_index = loop_counters[0]
     stride = 1
     for idx, var in enumerate(loop_counters[1:]):
         cur_stride = loop_iterations[idx]
-        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
+        stride *= int(cur_stride) if isinstance(cur_stride,
+                                                float) else cur_stride
         base_buffer_index += var * stride
     return base_buffer_index
 
 
-def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
-
+def resolve_buffer_accesses(ast_node,
+                            base_buffer_index,
+                            read_only_field_names=set()):
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
         if isinstance(expr, AbstractField.AbstractAccess):
             field_access = expr
@@ -378,17 +402,24 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
                 return expr
 
             buffer = field_access.field
-            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)
+            field_ptr = FieldPointerSymbol(
+                buffer.name,
+                buffer.dtype,
+                const=buffer.name in read_only_field_names)
 
             buffer_index = base_buffer_index
             if len(field_access.index) > 1:
-                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')
+                raise RuntimeError(
+                    'Only indexing dimensions up to 1 are currently supported in buffers!'
+                )
 
             if len(field_access.index) > 0:
                 cell_index = field_access.index[0]
                 buffer_index += cell_index
 
-            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
+            result = ast.ResolvedFieldAccess(field_ptr, buffer_index,
+                                             field_access.field,
+                                             field_access.offsets,
                                              field_access.index)
 
             return visit_sympy_expr(result, enclosing_block, sympy_assignment)
@@ -396,16 +427,23 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
             if isinstance(expr, ast.ResolvedFieldAccess):
                 return expr
 
-            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
-            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
+            new_args = [
+                visit_sympy_expr(e, enclosing_block, sympy_assignment)
+                for e in expr.args
+            ]
+            kwargs = {
+                'evaluate': False
+            } if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
             return expr.func(*new_args, **kwargs) if new_args else expr
 
     def visit_node(sub_ast):
         if isinstance(sub_ast, ast.SympyAssignment):
             enclosing_block = sub_ast.parent
             assert type(enclosing_block) is ast.Block
-            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
-            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
+            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block,
+                                           sub_ast)
+            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block,
+                                           sub_ast)
         else:
             for i, a in enumerate(sub_ast.args):
                 visit_node(a)
@@ -413,7 +451,8 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
     return visit_node(ast_node)
 
 
-def resolve_field_accesses(ast_node, read_only_field_names=set(),
+def resolve_field_accesses(ast_node,
+                           read_only_field_names=set(),
                            field_to_base_pointer_info=MappingProxyType({}),
                            field_to_fixed_coordinates=MappingProxyType({})):
     """
@@ -430,8 +469,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
     Returns
         transformed AST
     """
-    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
-    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
+    field_to_base_pointer_info = OrderedDict(
+        sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
+    field_to_fixed_coordinates = OrderedDict(
+        sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
 
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
         if isinstance(expr, AbstractField.AbstractAccess):
@@ -439,20 +480,29 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
             field = field_access.field
 
             if field_access.indirect_addressing_fields:
-                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
-                                    for off in field_access.offsets)
-                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
-                                    if isinstance(ind, sp.Basic) else ind
-                                    for ind in field_access.index)
+                new_offsets = tuple(
+                    visit_sympy_expr(off, enclosing_block, sympy_assignment)
+                    for off in field_access.offsets)
+                new_indices = tuple(
+                    visit_sympy_expr(ind, enclosing_block, sympy_assignment
+                                     ) if isinstance(ind, sp.Basic) else ind
+                    for ind in field_access.index)
                 field_access = Field.Access(field_access.field, new_offsets,
-                                            new_indices, field_access.is_absolute_access)
+                                            new_indices,
+                                            field_access.is_absolute_access)
 
             if field.name in field_to_base_pointer_info:
                 base_pointer_info = field_to_base_pointer_info[field.name]
             else:
-                base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]
+                base_pointer_info = [
+                    list(
+                        range(field.index_dimensions + field.spatial_dimensions))
+                ]
 
-            field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)
+            field_ptr = FieldPointerSymbol(
+                field.name,
+                field.dtype,
+                const=field.name in read_only_field_names)
 
             def create_coordinate_dict(group_param):
                 coordinates = {}
@@ -460,12 +510,15 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                     if e < field.spatial_dimensions:
                         if field.name in field_to_fixed_coordinates:
                             if not field_access.is_absolute_access:
-                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
+                                coordinates[e] = field_to_fixed_coordinates[
+                                    field.name][e]
                             else:
                                 coordinates[e] = 0
                         else:
                             if not field_access.is_absolute_access:
-                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
+                                coordinates[
+                                    e] = ast.LoopOverCoordinate.get_loop_counter_symbol(
+                                        e)
                             else:
                                 coordinates[e] = 0
                         coordinates[e] *= field.dtype.item_size
@@ -474,9 +527,11 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
                             assert field.index_dimensions == 1
                             accessed_field_name = field_access.index[0]
                             assert isinstance(accessed_field_name, str)
-                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
+                            coordinates[e] = field.dtype.get_element_offset(
+                                accessed_field_name)
                         else:
-                            coordinates[e] = field_access.index[e - field.spatial_dimensions]
+                            coordinates[e] = field_access.index[
+                                e - field.spatial_dimensions]
 
                 return coordinates
 
@@ -484,19 +539,27 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
 
             for group in reversed(base_pointer_info[1:]):
                 coord_dict = create_coordinate_dict(group)
-                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
+                new_ptr, offset = create_intermediate_base_pointer(
+                    field_access, coord_dict, last_pointer)
                 if new_ptr not in enclosing_block.symbols_defined:
-                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
-                    enclosing_block.insert_before(new_assignment, sympy_assignment)
+                    new_assignment = ast.SympyAssignment(new_ptr,
+                                                         last_pointer + offset,
+                                                         is_const=False)
+                    enclosing_block.insert_before(new_assignment,
+                                                  sympy_assignment)
                 last_pointer = new_ptr
 
             coord_dict = create_coordinate_dict(base_pointer_info[0])
-            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
-            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
-                                             field_access.offsets, field_access.index)
+            _, offset = create_intermediate_base_pointer(
+                field_access, coord_dict, last_pointer)
+            result = ast.ResolvedFieldAccess(last_pointer, offset,
+                                             field_access.field,
+                                             field_access.offsets,
+                                             field_access.index)
 
             if isinstance(get_base_type(field_access.field.dtype), StructType):
-                new_type = field_access.field.dtype.get_element_type(field_access.index[0])
+                new_type = field_access.field.dtype.get_element_type(
+                    field_access.index[0])
                 result = reinterpret_cast_func(result, new_type)
 
             return visit_sympy_expr(result, enclosing_block, sympy_assignment)
@@ -504,20 +567,28 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
             if isinstance(expr, ast.ResolvedFieldAccess):
                 return expr
 
-            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
-            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
+            new_args = [
+                visit_sympy_expr(e, enclosing_block, sympy_assignment)
+                for e in expr.args
+            ]
+            kwargs = {
+                'evaluate': False
+            } if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
             return expr.func(*new_args, **kwargs) if new_args else expr
 
     def visit_node(sub_ast):
         if isinstance(sub_ast, ast.SympyAssignment):
             enclosing_block = sub_ast.parent
             assert type(enclosing_block) is ast.Block
-            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
-            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
+            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block,
+                                           sub_ast)
+            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block,
+                                           sub_ast)
         elif isinstance(sub_ast, ast.Conditional):
             enclosing_block = sub_ast.parent
             assert type(enclosing_block) is ast.Block
-            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
+            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr,
+                                                      enclosing_block, sub_ast)
             visit_node(sub_ast.true_block)
             if sub_ast.false_block:
                 visit_node(sub_ast.false_block)
@@ -561,11 +632,14 @@ def move_constants_before_loop(ast_node):
             element = element.parent
         return last_block, last_block_child
 
-    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
+    def check_if_assignment_already_in_block(assignment,
+                                             target_block,
+                                             rhs_or_lhs=True):
         for arg in target_block.args:
             if type(arg) is not ast.SympyAssignment:
                 continue
-            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
+            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (
+                    not rhs_or_lhs and arg.lhs == assignment.lhs):
                 return arg
         return None
 
@@ -588,21 +662,24 @@ def move_constants_before_loop(ast_node):
             # Before traversing the next child, all symbols are substituted first.
             child.subs(substitute_variables)
 
-            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
+            if not isinstance(
+                    child, ast.SympyAssignment):  # only move SympyAssignments
                 block.append(child)
                 continue
 
             target, child_to_insert_before = find_block_to_move_to(child)
-            if target == block:     # movement not possible
+            if target == block:  # movement not possible
                 target.append(child)
             else:
                 if isinstance(child, ast.SympyAssignment):
-                    exists_already = check_if_assignment_already_in_block(child, target, False)
+                    exists_already = check_if_assignment_already_in_block(
+                        child, target, False)
                 else:
                     exists_already = False
 
                 if not exists_already:
-                    rhs_identical = check_if_assignment_already_in_block(child, target, True)
+                    rhs_identical = check_if_assignment_already_in_block(
+                        child, target, True)
                     if rhs_identical:
                         # there is already an assignment out there with the same rhs
                         # -> replace all lhs symbols in this block with the lhs of the outer assignment
@@ -617,7 +694,9 @@ def move_constants_before_loop(ast_node):
                     # -> symbol has to be renamed
                     assert isinstance(child.lhs, TypedSymbol)
                     new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
-                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
+                    target.insert_before(
+                        ast.SympyAssignment(new_symbol, child.rhs),
+                        child_to_insert_before)
                     substitute_variables[child.lhs] = new_symbol
 
 
@@ -633,7 +712,9 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
     """
     all_loops = ast_node.atoms(ast.LoopOverCoordinate)
     inner_loop = [l for l in all_loops if l.is_innermost_loop]
-    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
+    assert len(
+        inner_loop
+    ) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
     inner_loop = inner_loop[0]
     assert type(inner_loop.body) is ast.Block
     outer_loop = [l for l in all_loops if l.is_outermost_loop]
@@ -664,28 +745,38 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
             if not isinstance(symbol, AbstractField.AbstractAccess):
                 assert type(symbol) is TypedSymbol
                 new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
-                symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts,
-                                                                      shape=(1,))[inner_loop.loop_counter_symbol]
+                symbols_with_temporary_array[symbol] = sp.IndexedBase(
+                    new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
 
         assignment_group = []
         for assignment in inner_loop.body.args:
             if assignment.lhs in symbols_resolved:
-                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
-                if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
+                new_rhs = assignment.rhs.subs(
+                    symbols_with_temporary_array.items())
+                if not isinstance(assignment.lhs, AbstractField.AbstractAccess
+                                  ) and assignment.lhs in symbol_group:
                     assert type(assignment.lhs) is TypedSymbol
-                    new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
-                    new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
+                    new_ts = TypedSymbol(assignment.lhs.name,
+                                         PointerType(assignment.lhs.dtype))
+                    new_lhs = sp.IndexedBase(
+                        new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
                 else:
                     new_lhs = assignment.lhs
                 assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
         assignment_groups.append(assignment_group)
 
-    new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
+    new_loops = [
+        inner_loop.new_loop_with_different_body(ast.Block(group))
+        for group in assignment_groups
+    ]
     inner_loop.parent.replace(inner_loop, ast.Block(new_loops))
 
     for tmp_array in symbols_with_temporary_array:
-        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
-        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
+        tmp_array_pointer = TypedSymbol(tmp_array.name,
+                                        PointerType(tmp_array.dtype))
+        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer,
+                                                   inner_loop.stop,
+                                                   inner_loop.start)
         free_node = ast.TemporaryMemoryFree(alloc_node)
         outer_loop.parent.insert_front(alloc_node)
         outer_loop.parent.append(free_node)
@@ -715,15 +806,17 @@ def cut_loop(loop_node, cutting_points):
         elif new_end - new_start == 0:
             pass
         else:
-            new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
-                                              new_start, new_end, loop_node.step)
+            new_loop = ast.LoopOverCoordinate(
+                deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
+                new_start, new_end, loop_node.step)
             new_loops.append(new_loop)
         new_start = new_end
     loop_node.parent.replace(loop_node, new_loops)
     return new_loops
 
 
-def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
+def simplify_conditionals(node: ast.Node,
+                          loop_counter_simplification: bool = False) -> None:
     """Removes conditionals that are always true/false.
 
     Args:
@@ -739,14 +832,18 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
         if conditional.condition_expr == sp.true:
             conditional.parent.replace(conditional, [conditional.true_block])
         elif conditional.condition_expr == sp.false:
-            conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
+            conditional.parent.replace(
+                conditional,
+                [conditional.false_block] if conditional.false_block else [])
         elif loop_counter_simplification:
             try:
                 # noinspection PyUnresolvedReferences
                 from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
                 simplify_loop_counter_dependent_conditional(conditional)
             except ImportError:
-                warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
+                warnings.warn(
+                    "Integer simplifications in conditionals skipped, because ISLpy package not installed"
+                )
 
 
 def cleanup_blocks(node: ast.Node) -> None:
@@ -808,18 +905,28 @@ class KernelConstraintsCheck:
         elif type_constants and isinstance(rhs, sp.Number):
             return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
         elif isinstance(rhs, sp.Mul):
-            new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
+            new_args = [
+                self.process_expression(arg, type_constants)
+                if arg not in (-1, 1) else arg for arg in rhs.args
+            ]
             return rhs.func(*new_args) if new_args else rhs
         elif isinstance(rhs, sp.Indexed):
             return rhs
         elif isinstance(rhs, cast_func):
-            return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype)
+            return cast_func(
+                self.process_expression(rhs.args[0], type_constants=False),
+                rhs.dtype)
         else:
             if isinstance(rhs, sp.Pow):
                 # don't process exponents -> they should remain integers
-                return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
+                return sp.Pow(
+                    self.process_expression(rhs.args[0], type_constants),
+                    rhs.args[1])
             else:
-                new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
+                new_args = [
+                    self.process_expression(arg, type_constants)
+                    for arg in rhs.args
+                ]
                 return rhs.func(*new_args) if new_args else rhs
 
     @property
@@ -829,7 +936,7 @@ class KernelConstraintsCheck:
     def _process_lhs(self, lhs):
         assert isinstance(lhs, sp.Symbol)
         self._update_accesses_lhs(lhs)
-        if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol):
+        if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
             return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
         else:
             return lhs
@@ -839,22 +946,32 @@ class KernelConstraintsCheck:
             fai = self.FieldAndIndex(lhs.field, lhs.index)
             self._field_writes[fai].add(lhs.offsets)
             if len(self._field_writes[fai]) > 1:
-                raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
+                raise ValueError(
+                    "Field {} is written at two different locations".format(
+                        lhs.field.name))
         elif isinstance(lhs, sp.Symbol):
             if self.scopes.is_defined_locally(lhs):
-                raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
+                raise ValueError(
+                    "Assignments not in SSA form, multiple assignments to {}".
+                    format(lhs.name))
             if lhs in self.scopes.free_parameters:
-                raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
+                raise ValueError(
+                    "Symbol {} is written, after it has been read".format(
+                        lhs.name))
             self.scopes.define_symbol(lhs)
 
     def _update_accesses_rhs(self, rhs):
-        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
-            writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
+        if isinstance(rhs, AbstractField.AbstractAccess
+                      ) and self.check_independence_condition:
+            writes = self._field_writes[self.FieldAndIndex(
+                rhs.field, rhs.index)]
             for write_offset in writes:
                 assert len(writes) == 1
                 if write_offset != rhs.offsets:
-                    raise ValueError("Violation of loop independence condition. Field "
-                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
+                    raise ValueError(
+                        "Violation of loop independence condition. Field "
+                        "{} is read at {} and written at {}".format(
+                            rhs.field, rhs.offsets, write_offset))
             self.fields_read.add(rhs.field)
         elif isinstance(rhs, sp.Symbol):
             self.scopes.access_symbol(rhs)
@@ -875,21 +992,29 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
         ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
          list of equations where symbols have been replaced by typed symbols
     """
-    if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
+    if isinstance(type_for_symbol,
+                  str) or not hasattr(type_for_symbol, '__getitem__'):
         type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
 
-    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
+    # assignments = ast.Block(eqs).atoms(ast.Assignment)
+    # type_for_symbol.update( {a.lhs: get_type_of_expression(a.rhs) for a in assignments})
+    # print(type_for_symbol)
+    check = KernelConstraintsCheck(type_for_symbol,
+                                   check_independence_condition)
 
     def visit(obj):
-        if isinstance(obj, list) or isinstance(obj, tuple):
+        if isinstance(obj, (list, tuple)):
             return [visit(e) for e in obj]
-        if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
+        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
             return check.process_assignment(obj)
         elif isinstance(obj, ast.Conditional):
             check.scopes.push()
-            false_block = None if obj.false_block is None else visit(obj.false_block)
-            result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
-                                     true_block=visit(obj.true_block), false_block=false_block)
+            false_block = None if obj.false_block is None else visit(
+                obj.false_block)
+            result = ast.Conditional(check.process_expression(
+                obj.condition_expr, type_constants=False),
+                true_block=visit(obj.true_block),
+                false_block=false_block)
             check.scopes.pop()
             return result
         elif isinstance(obj, ast.Block):
@@ -897,7 +1022,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
             result = ast.Block([visit(e) for e in obj.args])
             check.scopes.pop()
             return result
-        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
+        elif isinstance(
+                obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
             return obj
         else:
             raise ValueError("Invalid object in kernel " + str(type(obj)))
@@ -956,7 +1082,8 @@ def insert_casts(node):
     for arg in node.args:
         args.append(insert_casts(arg))
     # TODO indexed, LoopOverCoordinate
-    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
+    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne,
+                     sp.Lt, sp.Le, sp.Gt, sp.Ge):
         # TODO optimize pow, don't cast integer on double
         types = [get_type_of_expression(arg) for arg in args]
         assert len(types) > 0
@@ -974,7 +1101,8 @@ def insert_casts(node):
         if target.func is PointerType:
             return node.func(*args)  # TODO fix, not complete
         else:
-            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
+            return node.func(
+                lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
     elif node.func is ast.ResolvedFieldAccess:
         return node
     elif node.func is ast.Block:
@@ -991,19 +1119,30 @@ def insert_casts(node):
         target = collate_types(types)
         zipped = list(zip(expressions, types))
         casted_expressions = cast(zipped, target)
-        args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]
+        args = [
+            arg.func(*[expr, arg.cond])
+            for (arg, expr) in zip(args, casted_expressions)
+        ]
 
     return node.func(*args)
 
 
-def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
+def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction
+                                            ) -> None:
     """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
 
-    all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
-    assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
+    all_inner_loops = [
+        l for l in function_node.atoms(ast.LoopOverCoordinate)
+        if l.is_innermost_loop
+    ]
+    assert len(
+        all_inner_loops
+    ) == 1, "Transformation works only on kernels with exactly one inner loop"
     inner_loop = all_inner_loops.pop()
 
-    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
+    for loop in parents_of_type(inner_loop,
+                                ast.LoopOverCoordinate,
+                                include_current=True):
         cut_loop(loop, [loop.stop - 1])
 
     simplify_conditionals(function_node.body, loop_counter_simplification=True)
@@ -1016,7 +1155,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -
 # --------------------------------------- Helper Functions -------------------------------------------------------------
 
 
-def typing_from_sympy_inspection(eqs, default_type="double"):
+def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
     """
     Creates a default symbol name to type mapping.
     If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
@@ -1032,17 +1171,25 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
         if isinstance(eq, ast.Conditional):
             result.update(typing_from_sympy_inspection(eq.true_block.args))
             if eq.false_block:
-                result.update(typing_from_sympy_inspection(eq.false_block.args))
+                result.update(typing_from_sympy_inspection(
+                    eq.false_block.args))
         elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
             continue
         else:
             from pystencils.cpu.vectorization import vec_all, vec_any
-            if isinstance(eq.rhs, vec_all) or isinstance(eq.rhs, vec_any):
+            if isinstance(eq.rhs, (vec_all, vec_any)):
                 result[eq.lhs.name] = "bool"
             # problematic case here is when rhs is a symbol: then it is impossible to decide here without
             # further information what type the left hand side is - default fallback is the dict value then
             if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
                 result[eq.lhs.name] = "bool"
+            try:
+                result[eq.lhs.name] = get_type_of_expression(eq.rhs,
+                                                             default_float_type=default_type,
+                                                             default_int_type=default_int_type,
+                                                             symbol_type_dict=result)
+            except Exception:
+                pass  # gracefully fail in case get_type_of_expression cannot determine type
     return result
 
 
@@ -1084,13 +1231,17 @@ def get_optimal_loop_ordering(fields):
     ref_field = next(iter(fields))
     for field in fields:
         if field.spatial_dimensions != ref_field.spatial_dimensions:
-            raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
-                             + str({f.name: f.spatial_shape for f in fields}))
+            raise ValueError(
+                "All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
+                + str({f.name: f.spatial_shape
+                       for f in fields}))
 
     layouts = set([field.layout for field in fields])
     if len(layouts) > 1:
-        raise ValueError("Due to different layout of the fields no optimal loop ordering exists "
-                         + str({f.name: f.layout for f in fields}))
+        raise ValueError(
+            "Due to different layout of the fields no optimal loop ordering exists "
+            + str({f.name: f.layout
+                   for f in fields}))
     layout = list(layouts)[0]
     return list(layout)
 
@@ -1135,7 +1286,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
     """
     inner_loops = []
     inner_loop_counters = set()
-    for loop in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment):
+    for loop in filtered_tree_iteration(ast_node,
+                                        ast.LoopOverCoordinate,
+                                        stop_type=ast.SympyAssignment):
         if loop.is_innermost_loop:
             inner_loops.append(loop)
             inner_loop_counters.add(loop.coordinate_to_loop_over)
@@ -1146,8 +1299,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
     inner_loop_counter = inner_loop_counters.pop()
 
     parameters = ast_node.get_parameters()
-    stride_params = [p.symbol for p in parameters
-                     if p.is_field_stride and p.symbol.coordinate == inner_loop_counter]
+    stride_params = [
+        p.symbol for p in parameters
+        if p.is_field_stride and p.symbol.coordinate == inner_loop_counter
+    ]
     subs_dict = {stride_param: 1 for stride_param in stride_params}
     if subs_dict:
         ast_node.subs(subs_dict)
@@ -1163,7 +1318,10 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
     Returns:
         number of dimensions blocked
     """
-    loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)]
+    loops = [
+        l for l in filtered_tree_iteration(
+            ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
+    ]
     body = ast_node.body
 
     coordinates = []
@@ -1183,8 +1341,12 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
     outer_loop = None
     for coord in reversed(coordinates):
         body = ast.Block([outer_loop]) if outer_loop else body
-        outer_loop = ast.LoopOverCoordinate(body, coord, loop_starts[coord], loop_stops[coord],
-                                            step=block_size[coord], is_block_loop=True)
+        outer_loop = ast.LoopOverCoordinate(body,
+                                            coord,
+                                            loop_starts[coord],
+                                            loop_stops[coord],
+                                            step=block_size[coord],
+                                            is_block_loop=True)
 
     ast_node.body = ast.Block([outer_loop])
 
@@ -1193,7 +1355,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
         coord = inner_loop.coordinate_to_loop_over
         block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
         loop_range = inner_loop.stop - inner_loop.start
-        if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0:
+        if sp.sympify(
+                loop_range).is_number and loop_range % block_size[coord] == 0:
             stop = block_ctr + block_size[coord]
         else:
             stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])