diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 9df186470086f8115fe0c832917a8676d04aa7bf..916274314b1651aae29619ef2a6a59e15b1d3cc6 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -12,7 +12,7 @@ from ...sympyextensions.typed_sympy import TypedSymbol
 
 from ..symbols import PsSymbol
 from ..arrays import PsLinearizedArray
-from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType
+from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify
 from ..constraints import KernelParamsConstraint
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
@@ -63,8 +63,8 @@ class KernelCreationContext:
         default_dtype: PsNumericType = DEFAULTS.numeric_dtype,
         index_dtype: PsIntegerType = DEFAULTS.index_dtype,
     ):
-        self._default_dtype = default_dtype
-        self._index_dtype = index_dtype
+        self._default_dtype = deconstify(default_dtype)
+        self._index_dtype = deconstify(index_dtype)
 
         self._symbols: dict[str, PsSymbol] = dict()
 
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 8ef6edd24d54aa9a9992c04adf51abe620ab4813..fc085e2be99f61204cde92438811a3b4e41c8bf7 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -71,7 +71,9 @@ class TypeContext:
     """
 
     def __init__(
-        self, target_type: PsType | None = None, require_nonconst: bool = False
+        self,
+        target_type: PsType | None = None,
+        require_nonconst: bool = False,
     ):
         self._require_nonconst = require_nonconst
         self._deferred_exprs: list[PsExpression] = []
@@ -171,8 +173,10 @@ class TypeContext:
                         )
 
                 case PsSymbolExpr(symb):
-                    assert symb.dtype is not None
-                    if not self._compatible(symb.dtype):
+                    if symb.dtype is None:
+                        #   Symbols are not forced to constness
+                        symb.dtype = deconstify(self._target_type)
+                    elif not self._compatible(symb.dtype):
                         raise TypificationError(
                             f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
                             f"    Symbol type: {symb.dtype}\n"
@@ -262,7 +266,11 @@ class Typifier:
 
     The following general rules apply:
 
-     - The context's `default_dtype` is applied to all untyped symbols
+     - The context's ``default_dtype`` is applied to all untyped symbols encountered inside a right-hand side expression
+     - If an untyped symbol is encountered on an assignment's left-hand side, it will first be attempted to infer its
+       type from the right-hand side. If that fails, the context's ``default_dtype`` will be applied.
+     - It is an error if an untyped symbol occurs in the same type context as a typed symbol or constant
+       with a non-default data type.
      - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
        left-hand side
 
@@ -280,7 +288,12 @@ class Typifier:
 
     def __call__(self, node: NodeT) -> NodeT:
         if isinstance(node, PsExpression):
-            self.visit_expr(node, TypeContext())
+            tc = TypeContext()
+            self.visit_expr(node, tc)
+
+            if tc.target_type is None:
+                #   no type could be inferred -> take the default
+                tc.apply_dtype(self._ctx.default_dtype)
         else:
             self.visit(node)
         return node
@@ -304,20 +317,44 @@ class Typifier:
                     self.visit(s)
 
             case PsDeclaration(lhs, rhs):
+                #   Only if the LHS is an untyped symbol, infer its type from the RHS
+                infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
+
                 tc = TypeContext()
-                #   LHS defines target type; type context carries it to RHS
-                self.visit_expr(lhs, tc)
-                assert tc.target_type is not None
+
+                if infer_lhs:
+                    tc.infer_dtype(lhs)
+                else:
+                    self.visit_expr(lhs, tc)
+                    assert tc.target_type is not None
+
                 self.visit_expr(rhs, tc)
 
+                if infer_lhs and tc.target_type is None:
+                    #   no type has been inferred -> use the default dtype
+                    tc.apply_dtype(self._ctx.default_dtype)
+
             case PsAssignment(lhs, rhs):
+                infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
+
                 tc_lhs = TypeContext(require_nonconst=True)
-                self.visit_expr(lhs, tc_lhs)
-                assert tc_lhs.target_type is not None
 
-                tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False)
+                if infer_lhs:
+                    tc_lhs.infer_dtype(lhs)
+                else:
+                    self.visit_expr(lhs, tc_lhs)
+                    assert tc_lhs.target_type is not None
+
+                tc_rhs = TypeContext(target_type=tc_lhs.target_type)
                 self.visit_expr(rhs, tc_rhs)
 
+                if infer_lhs:
+                    if tc_rhs.target_type is None:
+                        tc_rhs.apply_dtype(self._ctx.default_dtype)
+                    
+                    assert tc_rhs.target_type is not None
+                    tc_lhs.apply_dtype(deconstify(tc_rhs.target_type))
+
             case PsConditional(cond, branch_true, branch_false):
                 cond_tc = TypeContext(PsBoolType())
                 self.visit_expr(cond, cond_tc)
@@ -330,6 +367,7 @@ class Typifier:
             case PsLoop(ctr, start, stop, step, body):
                 if ctr.symbol.dtype is None:
                     ctr.symbol.apply_dtype(self._ctx.index_dtype)
+                    ctr.dtype = ctr.symbol.get_dtype()
 
                 tc_index = TypeContext(ctr.symbol.dtype)
                 self.visit_expr(start, tc_index)
@@ -355,11 +393,10 @@ class Typifier:
         either ``apply_dtype`` or ``infer_dtype``.
         """
         match expr:
-            case PsSymbolExpr(_):
-                if expr.symbol.dtype is None:
-                    expr.symbol.dtype = self._ctx.default_dtype
-
-                tc.apply_dtype(expr.symbol.dtype, expr)
+            case PsSymbolExpr(symb):
+                if symb.dtype is None:
+                    symb.dtype = self._ctx.default_dtype
+                tc.apply_dtype(symb.dtype, expr)
 
             case PsConstantExpr(c):
                 if c.dtype is not None:
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 5c2631e1eba1eaa84eb8f7ba6442ec020a191245..d3da7e8881d631266d58ee6cc0d4d3612a2900a1 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from typing import cast
 
-from pystencils import Assignment, TypedSymbol, Field, FieldType
+from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
 
 from pystencils.backend.ast.structural import (
     PsDeclaration,
@@ -16,6 +16,7 @@ from pystencils.backend.ast.structural import (
 from pystencils.backend.ast.expressions import (
     PsConstantExpr,
     PsSymbolExpr,
+    PsSubscript,
     PsBinOp,
     PsAnd,
     PsOr,
@@ -27,12 +28,12 @@ from pystencils.backend.ast.expressions import (
     PsGt,
     PsLt,
     PsCall,
-    PsTernary
+    PsTernary,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import CFunction
 from pystencils.types import constify, create_type, create_numeric_type
-from pystencils.types.quick import Fp, Int, Bool
+from pystencils.types.quick import Fp, Int, Bool, Arr
 from pystencils.backend.kernelcreation.context import KernelCreationContext
 from pystencils.backend.kernelcreation.freeze import FreezeExpressions
 from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
@@ -186,7 +187,7 @@ def test_typify_structs():
         fasm = typify(fasm)
 
 
-def test_contextual_typing():
+def test_default_typing():
     ctx = KernelCreationContext()
     freeze = FreezeExpressions(ctx)
     typify = Typifier(ctx)
@@ -213,6 +214,45 @@ def test_contextual_typing():
     check(expr)
 
 
+def test_lhs_inference():
+    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    q = TypedSymbol("q", np.float32)
+    w = TypedSymbol("w", np.float16)
+
+    #   Type of the LHS is propagated to untyped RHS symbols
+
+    asm = Assignment(x, 3 - q)
+    fasm = typify(freeze(asm))
+
+    assert ctx.get_symbol("x").dtype == Fp(32)
+    assert fasm.lhs.dtype == constify(Fp(32))
+
+    asm = Assignment(y, 3 - w)
+    fasm = typify(freeze(asm))
+
+    assert ctx.get_symbol("y").dtype == Fp(16)
+    assert fasm.lhs.dtype == constify(Fp(16))
+
+    fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
+    fasm = typify(fasm)
+
+    assert ctx.get_symbol("z").dtype == Fp(16)
+    assert fasm.lhs.dtype == Fp(16)
+
+    fasm = PsDeclaration(
+        PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
+    )
+    fasm = typify(fasm)
+
+    assert ctx.get_symbol("r").dtype == Bool()
+    assert fasm.lhs.dtype == constify(Bool())
+    assert fasm.rhs.dtype == constify(Bool())
+
+
 def test_erronous_typing():
     ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
     freeze = FreezeExpressions(ctx)
@@ -227,16 +267,43 @@ def test_erronous_typing():
     with pytest.raises(TypificationError):
         typify(expr)
 
+    #   Conflict between LHS and RHS symbols
     asm = Assignment(q, 3 - w)
     fasm = freeze(asm)
     with pytest.raises(TypificationError):
         typify(fasm)
 
+    #   Do not propagate types back from LHS symbols to RHS symbols
     asm = Assignment(q, 3 - x)
     fasm = freeze(asm)
     with pytest.raises(TypificationError):
         typify(fasm)
 
+    asm = AddAugmentedAssignment(z, 3 - q)
+    fasm = freeze(asm)
+    with pytest.raises(TypificationError):
+        typify(fasm)
+
+
+def test_invalid_indices():
+    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
+    typify = Typifier(ctx)
+
+    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64))))
+    x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"]
+
+    #   Using default-typed symbols as array indices is illegal when the default type is a float
+
+    fasm = PsAssignment(PsSubscript(arr, x + y), z)
+
+    with pytest.raises(TypificationError):
+        typify(fasm)
+
+    fasm = PsAssignment(z, PsSubscript(arr, x + y))
+
+    with pytest.raises(TypificationError):
+        typify(fasm)
+
 
 def test_typify_integer_binops():
     ctx = KernelCreationContext()
@@ -366,7 +433,7 @@ def test_invalid_conditions():
     with pytest.raises(TypificationError):
         typify(cond)
 
-    
+
 def test_typify_ternary():
     ctx = KernelCreationContext()
     typify = Typifier(ctx)
diff --git a/tests/symbolics/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py
index e18ffc56a4b0a95c30ff2e9e2d4affa5567654ac..1dbc88cf45c1d5bd9735cb018bb80b06b5be5f37 100644
--- a/tests/symbolics/test_conditional_field_access.py
+++ b/tests/symbolics/test_conditional_field_access.py
@@ -51,9 +51,6 @@ def add_fixed_constant_boundary_handling(assignments, with_cse):
 @pytest.mark.parametrize('dtype', ('float64', 'float32'))
 @pytest.mark.parametrize('with_cse', (False, 'with_cse'))
 def test_boundary_check(dtype, with_cse):
-    if with_cse:
-        pytest.xfail("Doesn't typify correctly yet.")
-
     f, g = ps.fields(f"f, g : {dtype}[2D]")
     stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)