diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index ab35ff10819b1056a8495cc86d4afd879e4551f9..b5dc66e10c40af4849e7904f8dcfff267166b7f7 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -491,10 +491,7 @@ class CustomSympyPrinter(CCodePrinter):
             return f"&({self._print(expr.args[0])})"
         elif isinstance(expr, CastFunc):
             arg, data_type = expr.args
-            if isinstance(arg, sp.Number) and arg.is_finite:
-                return self._typed_number(arg, data_type)
-            else:
-                return f"(({data_type})({self._print(arg)}))"
+            return f"(({data_type})({self._print(arg)}))"
         elif isinstance(expr, fast_division):
             return f"({self._print(expr.args[0] / expr.args[1])})"
         elif isinstance(expr, fast_sqrt):
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index b3817c8f5c11ac07a8084f215ac04768deb15c8c..4e0573c89e9b66c92f81a97a97256460e75f711e 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -4,7 +4,7 @@ import sympy as sp
 import numpy as np
 
 import pystencils.astnodes as ast
-from pystencils.assignment import Assignment
+from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.config import CreateKernelConfig
 from pystencils.enums import Target, Backend
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
@@ -17,12 +17,8 @@ from pystencils.transformations import (
     move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses,
     resolve_field_accesses, split_inner_loop)
 
-from pystencils.kernel_contrains_check import KernelConstraintsCheck
 
-AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
-
-
-def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConfig, split_groups) -> KernelFunction:
+def create_kernel(assignments: AssignmentCollection, config: CreateKernelConfig) -> KernelFunction:
     """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
     Loops are created according to the field accesses in the equations.
@@ -31,8 +27,6 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
         assignments: list of sympy equations, containing accesses to :class:`pystencils.field.Field`.
         Defining the update rules of the kernel
         config: create kernel config
-        split_groups: Specification on how to split up inner loop into multiple loops. For details see
-                      transformation :func:`pystencils.transformation.split_inner_loop`
 
     Returns:
         AST node representing a function, that can be printed as C or CUDA code
@@ -41,8 +35,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
     type_info = config.data_type
     iteration_slice = config.iteration_slice
     ghost_layers = config.ghost_layers
-    skip_independence_check = config.skip_independence_check
-    allow_double_writes = config.allow_double_writes
+    fields_written = assignments.bound_fields
+    fields_read = assignments.free_fields
+
+    split_groups = ()
+    if 'split_groups' in assignments.simplification_hints:
+        split_groups = assignments.simplification_hints['split_groups']
+    assignments = assignments.all_assignments
 
     # TODO: try to delete
     def type_symbol(term):
@@ -56,12 +55,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
         else:
             raise ValueError("Term has to be field access or symbol")
 
-    check = KernelConstraintsCheck(check_independence_condition=skip_independence_check,
-                                   check_double_write_condition=allow_double_writes)
-    check.visit(assignments)
-
-    fields_read = check.fields_read
-    fields_written = check.fields_written
+    # TODO move add_types to create_domain_kernel or create_kernel
 
     assignments = add_types(assignments, config)
 
@@ -78,7 +72,6 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
     ast_node = KernelFunction(loop_node, Target.CPU, Backend.C, compile_function=make_python_function,
                               ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
 
-    # TODO move split groups here
     if split_groups:
         typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
         split_inner_loop(ast_node, typed_split_groups)
@@ -100,7 +93,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, config: CreateKernelConf
     return ast_node
 
 
-def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
+def create_indexed_kernel(assignments: AssignmentCollection, index_fields, function_name="kernel",
                           type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
     """
     Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py
index 74f56666651563d5d49df141819cc44a64835889..42204822a29321d611189ae9f9afd021344a5305 100644
--- a/pystencils/kernel_contrains_check.py
+++ b/pystencils/kernel_contrains_check.py
@@ -4,6 +4,7 @@ from typing import Union
 import sympy as sp
 from sympy.codegen import Assignment
 
+from pystencils.simp import AssignmentCollection
 from pystencils import astnodes as ast, TypedSymbol
 from pystencils.field import Field
 from pystencils.transformations import NestedScopes
@@ -41,7 +42,9 @@ class KernelConstraintsCheck:
         self.check_double_write_condition = check_double_write_condition
 
     def visit(self, obj):
-        if isinstance(obj, list) or isinstance(obj, tuple):
+        if isinstance(obj, AssignmentCollection):
+            [self.visit(e) for e in obj.all_assignments]
+        elif isinstance(obj, list) or isinstance(obj, tuple):
             [self.visit(e) for e in obj]
         elif isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
             self.process_assignment(obj)
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 5ee7817f170beb24688aca6ce3940e63ae3a0b89..2635eb035e7b3ff7ef84fec071bdbd8ef96019d8 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -12,7 +12,7 @@ from pystencils.enums import Target, Backend
 from pystencils.field import Field, FieldType
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
-from pystencils.simp.simplifications import apply_sympy_optimisations
+from pystencils.kernel_contrains_check import KernelConstraintsCheck
 from pystencils.simplificationfactory import create_simplification_strategy
 from pystencils.stencil import direction_string_to_offset, inverse_direction_string
 from pystencils.transformations import (
@@ -62,6 +62,8 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
     if isinstance(assignments, Assignment):
         assignments = [assignments]
     assert assignments, "Assignments must not be empty!"
+    if isinstance(assignments, list):
+        assignments = AssignmentCollection(assignments)
 
     if config.index_fields:
         return create_indexed_kernel(assignments, config=config)
@@ -69,7 +71,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
         return create_domain_kernel(assignments, config=config)
 
 
-def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelConfig):
+def create_domain_kernel(assignments: AssignmentCollection, *, config: CreateKernelConfig):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
 
@@ -82,6 +84,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
         can be compiled with through its 'compile()' member
 
     Example:
+        # TODO change to assignment collection
         >>> import pystencils as ps
         >>> import numpy as np
         >>> s, d = ps.fields('s, d: [2D]')
@@ -98,6 +101,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
                [0., 4., 4., 4., 0.],
                [0., 0., 0., 0., 0.]])
     """
+
     # --- applying first default simplifications
     try:
         if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection):
@@ -107,20 +111,18 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
         warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
                       f"AssignmentCollection due to the following problem :{e}")
 
-    # TODO: shift to CPU
-    # ----  Normalizing parameters
-    split_groups = ()
-    if isinstance(assignments, AssignmentCollection):
-        if 'split_groups' in assignments.simplification_hints:
-            split_groups = assignments.simplification_hints['split_groups']
-        assignments = assignments.all_assignments
+    assignments.evaluate_terms()
 
-    try:
-        if config.default_assignment_simplifications:
-            assignments = apply_sympy_optimisations(assignments)
-    except Exception as e:
-        warnings.warn(f"It was not possible to apply the default SymPy optimisations to the "
-                      f"Assignments due to the following problem :{e}")
+    # --- eval
+    # TODO split apply_sympy_optimisations and do the eval here
+
+    # FUTURE WORK from here we shouldn't NEED sympy
+    # --- check constrains
+    check = KernelConstraintsCheck(check_independence_condition=config.skip_independence_check,
+                                   check_double_write_condition=config.allow_double_writes)
+    check.visit(assignments)
+    assert assignments.bound_fields == check.fields_written, f'WTF'
+    assert assignments.rhs_fields == check.fields_read, f'WTF'
 
     # ----  Creating ast
     ast = None
@@ -128,7 +130,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
         if config.backend == Backend.C:
             from pystencils.cpu import add_openmp, create_kernel
             # TODO: data type keyword should be unified to data_type
-            ast = create_kernel(assignments, config=config, split_groups=split_groups)
+            ast = create_kernel(assignments, config=config)
             for optimization in config.cpu_prepend_optimizations:
                 optimization(ast)
             omp_collapse = None
@@ -170,7 +172,7 @@ def create_domain_kernel(assignments: List[Assignment], *, config: CreateKernelC
     return ast
 
 
-def create_indexed_kernel(assignments: List[Assignment], *, config: CreateKernelConfig):
+def create_indexed_kernel(assignments: AssignmentCollection, *, config: CreateKernelConfig):
     """
     Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
     coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.
@@ -212,6 +214,8 @@ import pystencils.kernel_creation_config        >>> import pystencils as ps
                [0. , 0. , 0. , 4.3, 0. ],
                [0. , 0. , 0. , 0. , 0. ]])
     """
+    # TODO do this in backends
+    assignments = assignments.all_assignments
     ast = None
     if config.target == Target.CPU and config.backend == Backend.C:
         from pystencils.cpu import add_openmp, create_indexed_kernel
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index f493b08e931bf69f0b9713255c0eee87e9bace66..7309e7d87bb3402ece30f5008baed0514a99b770 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -3,6 +3,7 @@ from copy import copy
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
 
 import sympy as sp
+from sympy.codegen.rewriting import ReplaceOptim, optimize
 
 import pystencils
 from pystencils.assignment import Assignment
@@ -107,16 +108,21 @@ class AssignmentCollection:
         return self.subexpressions + self.main_assignments
 
     @property
-    def free_symbols(self) -> Set[sp.Symbol]:
-        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
-        free_symbols = set()
+    def rhs_symbols(self) -> Set[sp.Symbol]:
+        """All symbols used in the assignment collection, which occur on the rhs of any assignment."""
+        rhs_symbols = set()
         for eq in self.all_assignments:
             if isinstance(eq, Assignment):
-                free_symbols.update(eq.rhs.atoms(sp.Symbol))
+                rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
             elif isinstance(eq, pystencils.astnodes.Node):
-                free_symbols.update(eq.undefined_symbols)
+                rhs_symbols.update(eq.undefined_symbols)
 
-        return free_symbols - self.bound_symbols
+        return rhs_symbols
+
+    @property
+    def free_symbols(self) -> Set[sp.Symbol]:
+        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return self.rhs_symbols - self.bound_symbols
 
     @property
     def bound_symbols(self) -> Set[sp.Symbol]:
@@ -132,10 +138,15 @@ class AssignmentCollection:
             assignment.symbols_defined for assignment in self.all_assignments
             if isinstance(assignment, pystencils.astnodes.Node)
         ]
-        )
+                                                    )
 
         return bound_symbols_set
 
+    @property
+    def rhs_fields(self):
+        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
+        return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
+
     @property
     def free_fields(self):
         """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
@@ -152,7 +163,7 @@ class AssignmentCollection:
         return (set(
             [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]
         ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance(
-                assignment, pystencils.astnodes.Node)]
+            assignment, pystencils.astnodes.Node)]
                 ))
 
     @property
@@ -214,6 +225,7 @@ class AssignmentCollection:
             return {s: func(*args, **kwargs) for s, func in lambdas.items()}
 
         return f
+
     # ---------------------------- Creating new modified collections ---------------------------------------------------
 
     def copy(self,
@@ -353,10 +365,26 @@ class AssignmentCollection:
         new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
         return self.copy(new_assignment, kept_subexpressions)
 
+    def evaluate_terms(self):
+
+        evaluate_constant_terms = ReplaceOptim(
+            lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
+            lambda p: p.evalf())
+
+        sympy_optimisations = [evaluate_constant_terms]
+
+        self.subexpressions = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
+                       if hasattr(a, 'lhs')
+                       else a for a in self.subexpressions]
+
+        self.main_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
+                       if hasattr(a, 'lhs')
+                       else a for a in self.main_assignments]
     # ----------------------------------------- Display and Printing   -------------------------------------------------
 
     def _repr_html_(self):
         """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
+
         def make_html_equation_table(equations):
             no_border = 'style="border:none"'
             html_table = '<table style="border:none; width: 100%; ">'
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index c36ba558e0a788a9d1751f2057d4647104d7e5a1..955f6b73d37421c8d51cb04b871945571ffce7bd 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -3,12 +3,10 @@ from typing import Callable, List, Sequence, Union
 from collections import defaultdict
 
 import sympy as sp
-from sympy.codegen.rewriting import optims_c99, optimize
-from sympy.codegen.rewriting import ReplaceOptim
 
 from pystencils.assignment import Assignment
-from pystencils.astnodes import Node, SympyAssignment
-from pystencils.field import Field, Field
+from pystencils.astnodes import Node
+from pystencils.field import Field
 from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
 
 
@@ -227,22 +225,29 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
     return f
 
 
-def apply_sympy_optimisations(assignments):
-    """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
-        and applies the default sympy optimisations. See sympy.codegen.rewriting
-    """
-
-    # Evaluates all constant terms
-    evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
-                                           lambda p: p.evalf())
-
-    sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
-
-    assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
-                   if hasattr(a, 'lhs')
-                   else a for a in assignments]
-    assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
-    for a in chain.from_iterable(assignments_nodes):
-        a.optimize(sympy_optimisations)
-
-    return assignments
+# TODO Markus
+# TODO: make this really work for Assignmentcollections
+# TODO: this function should ONLY evaluate
+# TODO: do the optims_c99 elsewhere optionally
+# def apply_sympy_optimisations(ac: AssignmentCollection):
+#     """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
+#         and applies the default sympy optimisations. See sympy.codegen.rewriting
+#     """
+#
+#     # Evaluates all constant terms
+#
+#     assignments = ac.all_assignments
+#
+#     evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
+#                                            lambda p: p.evalf())
+#
+#     sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
+#
+#     assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
+#                    if hasattr(a, 'lhs')
+#                    else a for a in assignments]
+#     assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
+#     for a in chain.from_iterable(assignments_nodes):
+#         a.optimize(sympy_optimisations)
+#
+#     return AssignmentCollection(assignments)
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index 6c51691ff8a56a0a27829ccd66e4c409cf461545..df36c0d9175b4b01572e94570c9a4ce0cd41e5c9 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -5,6 +5,11 @@ import logging
 import numpy as np
 
 import sympy as sp
+from sympy import Piecewise
+from sympy.functions.elementary.piecewise import ExprCondPair
+from sympy.codegen import Assignment
+from sympy.logic.boolalg import BooleanFunction
+from sympy.logic.boolalg import BooleanAtom
 
 from pystencils import astnodes as ast
 from pystencils.field import Field
@@ -13,8 +18,6 @@ from pystencils.typing.utilities import get_type_of_expression, collate_types
 from pystencils.typing.cast_functions import CastFunc
 from pystencils.typing.typed_sympy import TypedSymbol
 from pystencils.utils import ContextVar
-from sympy.codegen import Assignment
-from sympy.logic.boolalg import BooleanFunction
 
 
 class TypeAdder:
@@ -93,6 +96,8 @@ class TypeAdder:
     def figure_out_type(self, expr) -> Tuple[Any, BasicType]:  # TODO or abstract type? vector type?
         # Trivial cases
         from pystencils.field import Field
+        import pystencils.integer_functions
+        from pystencils.bit_masks import flag_cond
 
         if isinstance(expr, Field.Access):
             return expr, expr.dtype
@@ -104,24 +109,56 @@ class TypeAdder:
         elif isinstance(expr, np.generic):
             assert False, f'Why do we have a np.generic in rhs???? {expr}'
         elif isinstance(expr, sp.Number):
-            if expr.is_Float:
-                data_type = self.default_number_float.get()
-            elif expr.is_Integer:
+            if expr.is_Integer:
                 data_type = self.default_number_int.get()
+            elif expr.is_Float or expr.is_Rational:
+                data_type = self.default_number_float.get()
             else:
                 assert False, f'{sp.Number} is neither Float nor Integer'
             return expr, data_type
-        # TODO add everything in between
+        elif isinstance(expr, BooleanAtom):
+            return expr, BasicType('bool')
+        elif isinstance(expr, sp.Equality):
+            new_args = [self.figure_out_type(arg)[0] for arg in expr.args]
+            new_eq = sp.Equality(*new_args)
+            return new_eq, BasicType('bool')
+        elif isinstance(expr, CastFunc):
+            raise NotImplementedError('CastFunc')
+        elif isinstance(expr, BooleanFunction) or \
+                type(expr, ) in pystencils.integer_functions.__dict__.values():
+            raise NotImplementedError('BooleanFunction')
+        elif isinstance(expr, flag_cond):
+            #   do not process the arguments to the bit shift - they must remain integers
+            raise NotImplementedError('flag_cond')
         elif isinstance(expr, sp.Mul):
+            raise NotImplementedError('sp.Mul')
             # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
-            args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
-            return None  # TODO collate types
+            # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
         elif isinstance(expr, sp.Indexed):
-            self.apply_type(expr, BasicType('uintp'))  # TODO double check
-            return None
+            raise NotImplementedError('sp.Indexed')
         elif isinstance(expr, sp.Pow):
-            # TODO sp.Pow should know a type
-            return None  # TODO
+            args_types = [self.figure_out_type(arg) for arg in expr.args]
+            collated_type = collate_types([t for _, t in args_types])
+            return expr, collated_type
+        elif isinstance(expr, ExprCondPair):
+            expr_expr, expr_type = self.figure_out_type(expr.expr)
+            condition, condition_type = self.figure_out_type(expr.cond)
+            if condition_type != BasicType('bool'):
+                logging.warning(f'Condition "{condition}" is of type "{condition_type}" and not "bool"')
+            return expr.func(expr_expr, condition), expr_type
+        elif isinstance(expr, Piecewise):
+            args_types = [self.figure_out_type(arg) for arg in expr.args]
+            collated_type = collate_types([t for _, t in args_types])
+            new_args = []
+            for a, t in args_types:
+                if t != collated_type:
+                    if isinstance(a, ExprCondPair):
+                        new_args.append(a.func(CastFunc(a.expr, collated_type), a.cond))
+                    else:
+                        new_args.append(CastFunc(a, collated_type))
+                else:
+                    new_args.append(a)
+            return expr.func(*new_args) if new_args else expr, collated_type
         else:
             args_types = [self.figure_out_type(arg) for arg in expr.args]
             collated_type = collate_types([t for _, t in args_types])
@@ -190,6 +227,6 @@ class TypeAdder:
 
     def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]):
         if not isinstance(lhs, (Field.Access, TypedSymbol)):
-            return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
+            return TypedSymbol(lhs.name, self.type_for_symbol[lhs.name])
         else:
             return lhs
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index 795e5c26c91f7a6ac2294650c4050c5822b226f5..d1b0d33cb84910d56679acf7269ffa55ae86f1e3 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -84,7 +84,7 @@ def test_mixed_add(dtype1, dtype2):
     assert test_f[0] == constant+constant
 
 
-# TODO redo following tests
+# TODO vector
 def test_collation():
     double_type = BasicType('float64')
     float_type = BasicType('float32')
@@ -95,8 +95,9 @@ def test_collation():
     assert collate_types([double4_type, float4_type]) == double4_type
 
 
+# TODO this
 def test_vector_type():
-    double_type = BasicType("double")
+    double_type = BasicType('float64')
     float_type = BasicType('float32')
     double4_type = VectorType(double_type, 4)
     float4_type = VectorType(float_type, 4)
@@ -147,36 +148,33 @@ def test_assumptions():
     assert (x.shape[0] + 1).is_real
 
 
-def test_sqrt_of_integer():
+@pytest.mark.parametrize('dtype', ('float64', 'float32'))
+def test_sqrt_of_integer(dtype):
     """Regression test for bug where sqrt(3) was classified as integer"""
-    f = ps.fields("f: [1D]")
-    tmp = sp.symbols("tmp")
-
-    assignments = [ps.Assignment(tmp, sp.sqrt(3)),
-                   ps.Assignment(f[0], tmp)]
-    arr_double = np.array([1], dtype=np.float64)
-    kernel = ps.create_kernel(assignments).compile()
-    kernel(f=arr_double)
-    assert 1.7 < arr_double[0] < 1.8
-
-    f = ps.fields("f: float32[1D]")
-    tmp = sp.symbols("tmp")
+    f = ps.fields(f'f: {dtype}[1D]')
+    tmp = sp.symbols('tmp')
 
     assignments = [ps.Assignment(tmp, sp.sqrt(3)),
                    ps.Assignment(f[0], tmp)]
-    arr_single = np.array([1], dtype=np.float32)
-    config = pystencils.config.CreateKernelConfig(data_type="float32")
-    kernel = ps.create_kernel(assignments, config=config).compile()
-    kernel(f=arr_single)
+    arr = np.array([1], dtype=dtype)
+    config = pystencils.config.CreateKernelConfig(data_type=dtype)
 
-    code = ps.get_code_str(kernel.ast)
+    ast = ps.create_kernel(assignments, config=config)
+    kernel = ast.compile()
+    kernel(f=arr)
+    assert 1.7 < arr[0] < 1.8
 
-    assert "1.7320508075688772f" in code
-    assert 1.7 < arr_single[0] < 1.8
+    code = ps.get_code_str(ast)
+    constant = '1.7320508075688772f'
+    if dtype == 'float32':
+        assert constant in code
+    else:
+        assert constant not in code
 
 
-def test_integer_comparision():
-    f = ps.fields("f [2D]")
+@pytest.mark.parametrize('dtype', ('float64', 'float32'))
+def test_integer_comparision(dtype):
+    f = ps.fields(f"f: {dtype}[2D]")
     d = sp.Symbol("dir")
 
     ur = ps.Assignment(f[0, 0], sp.Piecewise((0, sp.Equality(d, 1)), (f[0, 0], True)))
@@ -184,9 +182,17 @@ def test_integer_comparision():
     ast = ps.create_kernel(ur)
     code = ps.get_code_str(ast)
 
-    assert "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));" in code
+    print(code)
+    # There should be an explicit cast for the integer zero to the type of the field on the rhs
+    if dtype == 'float64':
+        t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((double)(0))): (_data_f_00[_stride_f_1*ctr_1]));"
+    else:
+        t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (((float)(0))): (_data_f_00[_stride_f_1*ctr_1]));"
+
+    assert t in code
 
 
+# TODO this
 def test_Basic_data_type():
     assert typed_symbols(("s", "f"), np.uint) == typed_symbols("s, f", np.uint)
     t_symbols = typed_symbols(("s", "f"), np.uint)
@@ -223,6 +229,7 @@ def test_Basic_data_type():
     assert TypedSymbol("s", np.uint).reversed == TypedSymbol("s", np.uint)
 
 
+# TODO this
 def test_cast_func():
     assert CastFunc(TypedSymbol("s", np.uint), np.int64).canonical == TypedSymbol("s", np.uint).canonical
 
@@ -235,6 +242,7 @@ def test_pointer_arithmetic_func():
     assert PointerArithmeticFunc(TypedSymbol("s", np.uint), 1).canonical == TypedSymbol("s", np.uint).canonical
 
 
+# TODO this
 def test_division():
     f = ps.fields('f(10): float32[2D]')
     m, tau = sp.symbols("m, tau")
@@ -248,6 +256,7 @@ def test_division():
     assert "1.0f" in code
 
 
+# TODO this
 def test_pow():
     f = ps.fields('f(10): float32[2D]')
     m, tau = sp.symbols("m, tau")