diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 0ae5e36400b7d05dd57c11874657096d36181078..ab0aa8a4333ccdc3a19c379d3e8d361514fed93d 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -233,11 +233,17 @@ class CBackend:
                                    self.sympy_printer.doprint(node.rhs))
         else:
             lhs_type = get_type_of_expression(node.lhs)
+            printed_mask = ""
             if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
-                arg, data_type, aligned, nontemporal = node.lhs.args
+                arg, data_type, aligned, nontemporal, mask = node.lhs.args
                 instr = 'storeU'
                 if aligned:
                     instr = 'stream' if nontemporal else 'storeA'
+                if mask is not True:
+                    instr = 'maskStore' if aligned else 'maskStoreU'
+                    printed_mask = self.sympy_printer.doprint(mask)
+                    if self._vector_instruction_set['dataTypePrefix']['double'] == '__m256d':
+                        printed_mask = "_mm256_castpd_si256({})".format(printed_mask)
 
                 rhs_type = get_type_of_expression(node.rhs)
                 if type(rhs_type) is not VectorType:
@@ -246,7 +252,8 @@ class CBackend:
                     rhs = node.rhs
 
                 return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
-                                                                  self.sympy_printer.doprint(rhs)) + ';'
+                                                                  self.sympy_printer.doprint(rhs),
+                                                                  printed_mask) + ';'
             else:
                 return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
 
@@ -450,7 +457,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
     def _print_Function(self, expr):
         if isinstance(expr, vector_memory_access):
-            arg, data_type, aligned, _ = expr.args
+            arg, data_type, aligned, _, mask = expr.args
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format("& " + self._print(arg))
         elif isinstance(expr, cast_func):
diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py
index 4415b9ab66e22af1bd671dd581d66ed7e159d855..be8523e90913d7f62eb908ace8cbdf0189c31fbf 100644
--- a/pystencils/backends/simd_instruction_sets.py
+++ b/pystencils/backends/simd_instruction_sets.py
@@ -32,7 +32,24 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         'storeU': 'storeu[0,1]',
         'storeA': 'store[0,1]',
         'stream': 'stream[0,1]',
+        'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
+        'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]'
     }
+    if instruction_set == 'avx512':
+        base_names.update({
+            'maskStore': 'mask_store[0, 2, 1]',
+            'maskStoreU': 'mask_storeu[0, 2, 1]',
+            'maskLoad': 'mask_load[2, 1, 0]',
+            'maskLoadU': 'mask_loadu[2, 1, 0]'
+        })
+    if instruction_set == 'avx':
+        base_names.update({
+            'maskStore': 'maskstore[0, 2, 1]',
+            'maskStoreU': 'maskstore[0, 2, 1]',
+            'maskLoad': 'maskload[0, 1]',
+            'maskLoadU': 'maskload[0, 1]'
+        })
+
     for comparison_op, constant in comparisons.items():
         base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,)
 
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index d2f206722972cdc62cdc9f398012cf579b27a96d..9556494576df91c2f614ea241ad44baf2e6f9f41 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -123,7 +123,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
                 nontemporal = False
                 if hasattr(indexed, 'field'):
                     nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
-                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal)
+                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True)
         if not successful:
             warnings.warn("Could not vectorize loop because of non-consecutive memory access")
             continue
@@ -136,6 +136,30 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
         fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                   skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
 
+        mask_conditionals(loop_node)
+
+
+def mask_conditionals(loop_body):
+    """Push branch conditions into stores: every vector_memory_access inside a conditional gets the conjunction of all enclosing branch masks; the conditional itself is guarded by vec_any."""
+    def visit_node(node, mask):
+        if isinstance(node, ast.Conditional):
+            true_mask = sp.And(node.condition_expr, mask)
+            visit_node(node.true_block, true_mask)
+            if node.false_block:
+                false_mask = sp.And(sp.Not(node.condition_expr), mask)
+                # recurse into the false block (recursing on `node` itself would loop forever)
+                visit_node(node.false_block, false_mask)
+            node.condition_expr = vec_any(node.condition_expr)
+        elif isinstance(node, ast.SympyAssignment):
+            if mask is not True:
+                s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4]))
+                     for ma in node.atoms(vector_memory_access)}
+                node.subs(s)
+        else:
+            for arg in node.args:
+                visit_node(arg, mask)
+
+    visit_node(loop_body, mask=True)
 
 def insert_vector_casts(ast_node):
     """Inserts necessary casts from scalar values to vector values."""
@@ -143,8 +167,10 @@ def insert_vector_casts(ast_node):
     handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)
 
     def visit_expr(expr):
-
-        if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
+        if isinstance(expr, vector_memory_access):
+            return vector_memory_access(expr.args[0], expr.args[1], expr.args[2], expr.args[3],
+                                        visit_expr(expr.args[4]))
+        elif isinstance(expr, cast_func):
             return expr
         elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
             new_args = [visit_expr(a) for a in expr.args]
@@ -199,10 +225,12 @@ def insert_vector_casts(ast_node):
                         new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
                         substitution_dict[assignment.lhs] = new_lhs
                         assignment.lhs = new_lhs
-                elif isinstance(assignment.lhs.func, cast_func):
-                    lhs_type = assignment.lhs.args[1]
-                    if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
-                        assignment.rhs = cast_func(assignment.rhs, lhs_type)
+                elif isinstance(assignment.lhs, vector_memory_access):
+                    assignment.lhs = visit_expr(assignment.lhs)
+                #elif isinstance(assignment.lhs, cast_func): # TODO check if necessary
+                #    lhs_type = assignment.lhs.args[1]
+                #    if type(lhs_type) is VectorType and type(rhs_type) is not VectorType:
+                #        assignment.rhs = cast_func(assignment.rhs, lhs_type)
             elif isinstance(arg, ast.Conditional):
                 arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict,
                                                skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 20eb94d6b83474bfab46dc491a05b3b1ed2191d9..a5e876c974e372fdeeb946a9ed2fb1cf1ba08c50 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -190,7 +190,8 @@ class boolean_cast_func(cast_func, Boolean):
 
 # noinspection PyPep8Naming
 class vector_memory_access(cast_func):
-    nargs = (4,)
+    # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none)
+    nargs = (5,)
 
 
 # noinspection PyPep8Naming
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index f43b5bdc6e400fe004e84e2acab188f88a2562ca..1c244be0537b401587b76c348f89ce64c04cdee2 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -891,10 +891,10 @@ class KernelConstraintsCheck:
         if isinstance(lhs, AbstractField.AbstractAccess):
             fai = self.FieldAndIndex(lhs.field, lhs.index)
             self._field_writes[fai].add(lhs.offsets)
-            if len(self._field_writes[fai]) > 1:
-                raise ValueError(
-                    "Field {} is written at two different locations".format(
-                        lhs.field.name))
+            # FIXME(review): double-write check disabled to allow conditional (masked)
+            #                writes to the same field location; this also hides genuine errors.
+            #if len(self._field_writes[fai]) > 1:
+            #    raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
         elif isinstance(lhs, sp.Symbol):
             if self.scopes.is_defined_locally(lhs):
                 raise ValueError(