diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py
index c00cf520a8329ae869bdbdec6cb4d48763c9ae0b..f9e660f9dfe210303783312767fe83970f65691b 100644
--- a/src/pystencils/backends/cbackend.py
+++ b/src/pystencils/backends/cbackend.py
@@ -571,6 +571,11 @@ class CustomSympyPrinter(CCodePrinter):
             return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
         elif expr.func == DivFunc:
             return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))'
+        elif isinstance(expr, Fma):
+            a = expr.args[0] * (-1 if expr.instruction[0] == '-' else 1)
+            b = expr.args[1]
+            c = expr.args[2] * (-1 if expr.instruction[-1] == '-' else 1)
+            return f"fma({self._print(a)}, {self._print(b)}, {self._print(c)})"
         else:
             name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
             arg_str = ', '.join(self._print(a) for a in expr.args)
@@ -729,6 +734,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         elif isinstance(expr, fast_inv_sqrt):
             raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
         elif isinstance(expr, Fma):
+            result = self._scalarFallback('_print_Function', expr)
+            if result:
+                return result
             return self.instruction_set[expr.instruction].format(self._print(expr.args[0]), self._print(expr.args[1]),
                                                                  self._print(expr.args[2]), **self._kwargs)
         elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
diff --git a/src/pystencils/fast_approximation.py b/src/pystencils/fast_approximation.py
index de3dc62c3c4cbf9a8a8c5234292dfdd45c38fa21..7e678da412f1e99a00c55384f95fc13e491948a2 100644
--- a/src/pystencils/fast_approximation.py
+++ b/src/pystencils/fast_approximation.py
@@ -121,6 +121,8 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti
 
 
 def insert_fma(term, operators):
+    from pystencils.rng import RNGBase  # late import to avoid cyclic dependency
+
     if '*+' not in operators:
         return term
 
@@ -144,8 +146,11 @@ def insert_fma(term, operators):
         return expr
 
     def visit(expr):
+        # Special treatments for various types that cannot be reconstructed from their args
         if isinstance(expr, ResolvedFieldAccess):
             return expr
+        elif isinstance(expr, RNGBase):
+            return expr
         elif hasattr(expr, 'body'):
             old_parent = expr.body.parent if hasattr(expr.body, 'parent') else None
             expr.body = visit(expr.body)
@@ -154,7 +159,9 @@ def insert_fma(term, operators):
             return expr
         elif isinstance(expr, Block):
             return Block([visit(a) for a in expr.args])
-        elif expr.func == sp.Add:
+
+        # Find patterns of Add and Mul nodes that can be fused
+        if expr.func == sp.Add:
             expr = flatten(expr)
             summands = list(expr.args)
             if '-*+' in operators:
@@ -210,6 +217,7 @@ def insert_fma(term, operators):
                     summands = [visit(s) for s in summands]
                     return sp.Add(fmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:])
             return expr
+        # Find Mul with three factors, one of them -1, which can be fused
         elif expr.func == sp.Mul and -1 in expr.args:
             expr = flatten(expr)
             factors = list(expr.args)
diff --git a/tests/test_vec_fma.py b/tests/test_vec_fma.py
index 357c37de56188f944a64fcff4b7657e278f845bd..25389d640663d86a675b4efa47d92c3112e284e3 100644
--- a/tests/test_vec_fma.py
+++ b/tests/test_vec_fma.py
@@ -11,10 +11,10 @@ supported_instruction_sets = get_supported_instruction_sets() if get_supported_i
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 def test_fmadd(instruction_set, dtype):
-    da = 2 * np.ones((128, 128), dtype=dtype)
-    db = 3 * np.ones((128, 128), dtype=dtype)
-    dc = 5 * np.ones((128, 128), dtype=dtype)
-    dd = np.empty((128, 128), dtype=dtype)
+    da = 2 * np.ones((129, 129), dtype=dtype)
+    db = 3 * np.ones((129, 129), dtype=dtype)
+    dc = 5 * np.ones((129, 129), dtype=dtype)
+    dd = np.empty((129, 129), dtype=dtype)
 
     a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
     update_rule = [ps.Assignment(d.center(), a.center() * b.center() + c.center())]
@@ -33,10 +33,10 @@ def test_fmadd(instruction_set, dtype):
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 def test_fmsub(instruction_set, dtype):
-    da = 2 * np.ones((128, 128), dtype=dtype)
-    db = 3 * np.ones((128, 128), dtype=dtype)
-    dc = 5 * np.ones((128, 128), dtype=dtype)
-    dd = np.empty((128, 128), dtype=dtype)
+    da = 2 * np.ones((129, 129), dtype=dtype)
+    db = 3 * np.ones((129, 129), dtype=dtype)
+    dc = 5 * np.ones((129, 129), dtype=dtype)
+    dd = np.empty((129, 129), dtype=dtype)
 
     a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
     update_rule = [ps.Assignment(d.center(), a.center() * b.center() - c.center())]
@@ -57,10 +57,10 @@ def test_fmsub(instruction_set, dtype):
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 def test_fnmadd(instruction_set, dtype):
-    da = 2 * np.ones((128, 128), dtype=dtype)
-    db = 3 * np.ones((128, 128), dtype=dtype)
-    dc = 5 * np.ones((128, 128), dtype=dtype)
-    dd = np.empty((128, 128), dtype=dtype)
+    da = 2 * np.ones((129, 129), dtype=dtype)
+    db = 3 * np.ones((129, 129), dtype=dtype)
+    dc = 5 * np.ones((129, 129), dtype=dtype)
+    dd = np.empty((129, 129), dtype=dtype)
 
     a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
     update_rule = [ps.Assignment(d.center(), -a.center() * b.center() + c.center())]
@@ -81,10 +81,10 @@ def test_fnmadd(instruction_set, dtype):
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 def test_fnmsub(instruction_set, dtype):
-    da = 2 * np.ones((128, 128), dtype=dtype)
-    db = 3 * np.ones((128, 128), dtype=dtype)
-    dc = 5 * np.ones((128, 128), dtype=dtype)
-    dd = np.empty((128, 128), dtype=dtype)
+    da = 2 * np.ones((129, 129), dtype=dtype)
+    db = 3 * np.ones((129, 129), dtype=dtype)
+    dc = 5 * np.ones((129, 129), dtype=dtype)
+    dd = np.empty((129, 129), dtype=dtype)
 
     a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
     update_rule = [ps.Assignment(d.center(), -a.center() * b.center() - c.center())]
@@ -105,9 +105,9 @@ def test_fnmsub(instruction_set, dtype):
 @pytest.mark.parametrize('dtype', ('float32', 'float64'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 def test_fnm(instruction_set, dtype):
-    da = 2 * np.ones((128, 128), dtype=dtype)
-    db = 3 * np.ones((128, 128), dtype=dtype)
-    dd = np.empty((128, 128), dtype=dtype)
+    da = 2 * np.ones((129, 129), dtype=dtype)
+    db = 3 * np.ones((129, 129), dtype=dtype)
+    dd = np.empty((129, 129), dtype=dtype)
 
     a, b, d = ps.fields(a=da, b=db, d=dd)
     update_rule = [ps.Assignment(d.center(), -a.center() * b.center())]