diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py
index 73c18688cc7d34bc23caed66528bae3f148dba65..ad4e967f0131dde5347c20de467afbab4892e149 100644
--- a/pystencils/bit_masks.py
+++ b/pystencils/bit_masks.py
@@ -1,5 +1,5 @@
 import sympy as sp
-from pystencils.typing import get_type_of_expression
+# from pystencils.typing import get_type_of_expression
 
 
 # noinspection PyPep8Naming
@@ -22,13 +22,14 @@ class flag_cond(sp.Function):
 
     def __new__(cls, flag_bit, mask_expression, *expressions):
 
-        flag_dtype = get_type_of_expression(flag_bit)
-        if not flag_dtype.is_int():
-            raise ValueError('Argument flag_bit must be of integer type.')
-
-        mask_dtype = get_type_of_expression(mask_expression)
-        if not mask_dtype.is_int():
-            raise ValueError('Argument mask_expression must be of integer type.')
+        # TODO reintroduce checking
+        # flag_dtype = get_type_of_expression(flag_bit)
+        # if not flag_dtype.is_int():
+        #     raise ValueError('Argument flag_bit must be of integer type.')
+        #
+        # mask_dtype = get_type_of_expression(mask_expression)
+        # if not mask_dtype.is_int():
+        #     raise ValueError('Argument mask_expression must be of integer type.')
 
         return super().__new__(cls, flag_bit, mask_expression, *expressions)
 
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index b620c9c7e1257881522db57b8a1f5aad138a8485..ef81b529d25d68613f47afd8913f6f1eb3c59dc8 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -175,7 +175,10 @@ class TypeAdder:
             raise NotImplementedError('integer_functions')
         elif isinstance(expr, flag_cond):
             #   do not process the arguments to the bit shift - they must remain integers
-            raise NotImplementedError('flag_cond')
+            args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
+            collated_type = collate_types([t for _, t in args_types])
+            new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
+            return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
         #elif isinstance(expr, sp.Mul):
         #    raise NotImplementedError('sp.Mul')
         #    # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
diff --git a/pystencils_tests/test_bit_masks.py b/pystencils_tests/test_bit_masks.py
index 57371976f416abdf52274852666860c3c92dcdf2..423fc13cc63569d3b6277983ca1e9210a3bbe9c9 100644
--- a/pystencils_tests/test_bit_masks.py
+++ b/pystencils_tests/test_bit_masks.py
@@ -1,11 +1,15 @@
+import pytest
 import numpy as np
+
+import pystencils as ps
 from pystencils import Field, Assignment, create_kernel
 from pystencils.bit_masks import flag_cond
 
 
-def test_flag_condition():
+@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64])
+def test_flag_condition(mask_type):
     f_arr = np.zeros((2, 2, 2), dtype=np.float64)
-    mask_arr = np.zeros((2, 2), dtype=np.uint64)
+    mask_arr = np.zeros((2, 2), dtype=mask_type)
 
     mask_arr[0, 1] = (1 << 3)
     mask_arr[1, 0] = (1 << 5)
@@ -16,7 +20,7 @@ def test_flag_condition():
 
     v1 = 42.3
     v2 = 39.7
-    v3 = 119.87
+    v3 = 119
 
     assignments = [
         Assignment(f(0), flag_cond(3, mask(0), v1)),
@@ -25,6 +29,8 @@ def test_flag_condition():
 
     kernel = create_kernel(assignments).compile()
     kernel(f=f_arr, mask=mask_arr)
+    code = ps.get_code_str(kernel)
+    assert '119.0' in code
 
     reference = np.zeros((2, 2, 2), dtype=np.float64)
     reference[0, 1, 0] = v1