diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 420e57c709e1697885ea686fe2596ad4c417b5ff..ae4c0e38894043326548efb63d08c6d79103486c 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -257,22 +257,26 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
 
     handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs)
 
-    def visit_expr(expr, default_type='double'):  # TODO Vectorization Revamp: get rid of default_type
+    # TODO Vectorization Revamp: get rid of default_type
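+    # force_vectorize: broadcast scalar TypedSymbols to the vector width; set
+    # when the enclosing assignment already has a vector-typed lhs or rhs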
+    def visit_expr(expr, default_type='double', force_vectorize=False):
         if isinstance(expr, VectorMemoryAccess):
-            return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
+            return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type, force_vectorize),
+                                      *expr.args[5:])
         elif isinstance(expr, CastFunc):
             cast_type = expr.args[1]
-            arg = visit_expr(expr.args[0])
+            arg = visit_expr(expr.args[0], default_type, force_vectorize)
             assert cast_type in [BasicType('float32'), BasicType('float64')],\
                 f'Vectorization cannot vectorize type {cast_type}'
             return expr.func(arg, VectorType(cast_type, instruction_set['width']))
         elif expr.func is sp.Abs and 'abs' not in instruction_set:
-            new_arg = visit_expr(expr.args[0], default_type)
+            new_arg = visit_expr(expr.args[0], default_type, force_vectorize)
             base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \
                 else get_type_of_expression(expr.args[0])
             pw = sp.Piecewise((-new_arg, new_arg < CastFunc(0, base_type.numpy_dtype)),
                               (new_arg, True))
-            return visit_expr(pw, default_type)
+            return visit_expr(pw, default_type, force_vectorize)
         elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
             if expr.func is sp.Mul and expr.args[0] == -1:
                 # special treatment for the unary minus: make sure that the -1 has the same type as the argument
@@ -287,7 +291,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
                     if dtype is np.float32:
                         default_type = 'float'
                     expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
-            new_args = [visit_expr(a, default_type) for a in expr.args]
+            new_args = [visit_expr(a, default_type, force_vectorize) for a in expr.args]
             arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
             if not any(type(t) is VectorType for t in arg_types):
                 return expr
@@ -306,7 +310,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
                 exp = expr.args[0].exp
                 expr = sp.UnevaluatedExpr(sp.Mul(*([base] * +exp), evaluate=False))
 
-            new_args = [visit_expr(a, default_type) for a in expr.args[0].args]
+            new_args = [visit_expr(a, default_type, force_vectorize) for a in expr.args[0].args]
             arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
 
             target_type = collate_types(arg_types)
@@ -318,11 +322,11 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
                 for a, t in zip(new_args, arg_types)]
             return expr.func(expr.args[0].func(*casted_args, evaluate=False))
         elif expr.func is sp.Pow:
-            new_arg = visit_expr(expr.args[0], default_type)
+            new_arg = visit_expr(expr.args[0], default_type, force_vectorize)
             return expr.func(new_arg, expr.args[1])
         elif expr.func == sp.Piecewise:
-            new_results = [visit_expr(a[0], default_type) for a in expr.args]
-            new_conditions = [visit_expr(a[1], default_type) for a in expr.args]
+            new_results = [visit_expr(a[0], default_type, force_vectorize) for a in expr.args]
+            new_conditions = [visit_expr(a[1], default_type, force_vectorize) for a in expr.args]
             types_of_results = [get_type_of_expression(a) for a in new_results]
             types_of_conditions = [get_type_of_expression(a) for a in new_conditions]
 
@@ -341,7 +345,16 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
                                  for a, t in zip(new_conditions, types_of_conditions)]
 
             return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
-        elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)):
+        elif isinstance(expr, TypedSymbol):
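+            # scalar symbols (e.g. constants hoisted before the loop) are
+            # broadcast to a full vector via a cast when force_vectorize is set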
+            if force_vectorize:
+                expr_type = get_type_of_expression(expr)
+                if type(expr_type) is not VectorType:
+                    vector_type = VectorType(expr_type, instruction_set['width'])
+                    return CastFunc(expr, vector_type)
+            return expr
+        elif isinstance(expr, (sp.Number, BooleanAtom)):
             return expr
         else:
             raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
@@ -357,11 +370,18 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
                 # continue
                 subs_expr = fast_subs(assignment.rhs, substitution_dict,
                                       skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
-                assignment.rhs = visit_expr(subs_expr, default_type)
-                rhs_type = get_type_of_expression(assignment.rhs)
+
+                # If either side contains a vectorized subexpression, both sides
+                # must be fully vectorized.
+                lhs_type = get_type_of_expression(assignment.lhs)
+                rhs_type = get_type_of_expression(subs_expr)
+                lhs_vectorized = type(lhs_type) is VectorType
+                rhs_vectorized = type(rhs_type) is VectorType
+
+                assignment.rhs = visit_expr(subs_expr, default_type, force_vectorize=lhs_vectorized or rhs_vectorized)
+
                 if isinstance(assignment.lhs, TypedSymbol):
-                    lhs_type = assignment.lhs.dtype
-                    if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
+                    if rhs_vectorized and not lhs_vectorized:
                         new_lhs_type = VectorType(lhs_type, rhs_type.width)
                         new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
                         substitution_dict[assignment.lhs] = new_lhs
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 3718770b9384e0f01dd899c3afed298208fc95a9..83320a91f60677e05ab484f7e83b428df47adc64 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -6,6 +6,7 @@ import pystencils.config
 import sympy as sp
 
 import pystencils as ps
+import pystencils.astnodes as ast
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
 from pystencils.cpu.vectorization import vectorize
 from pystencils.enums import Target
@@ -40,6 +41,49 @@ def test_vector_type_propagation(instruction_set=instruction_set):
     np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
 
 
+def test_vectorize_moved_constants1(instruction_set=instruction_set):
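+    # regression test: a constant moved out of the loop must be broadcast to a
+    # vector when used inside the vectorized loop body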
+    opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
+
+    f = ps.fields("f: [1D]")
+    x = ast.TypedSymbol("x", np.float64)
+
+    kernel_func = ps.create_kernel(
+        [ast.SympyAssignment(x, 2.0), ast.SympyAssignment(f[0], x)],
+        cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop],  # explicitly move constants
+        cpu_vectorize_info=opt,
+    )
+    ps.show_code(kernel_func)  # fails if `x` on the rhs was not correctly vectorized
+    kernel = kernel_func.compile()
+
+    f_arr = np.zeros(9)
+    kernel(f=f_arr)
+
+    assert np.all(f_arr == 2)
+
+
+def test_vectorize_moved_constants2(instruction_set=instruction_set):
+    opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
+
+    f = ps.fields("f: [1D]")
+    x = ast.TypedSymbol("x", np.float64)
+    y = ast.TypedSymbol("y", np.float64)
+
+    kernel_func = ps.create_kernel(
+        [ast.SympyAssignment(x, 2.0), ast.SympyAssignment(y, 3.0), ast.SympyAssignment(f[0], x + y)],
+        cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop],  # explicitly move constants
+        cpu_vectorize_info=opt,
+    )
+    ps.show_code(kernel_func)  # fails if `x` or `y` on the rhs was not correctly vectorized
+    kernel = kernel_func.compile()
+
+    f_arr = np.zeros(9)
+    kernel(f=f_arr)
+
+    assert np.all(f_arr == 5)
+
+
 @pytest.mark.parametrize('openmp', [True, False])
 def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
     domain_size = (24, 24)