From 8bebd100116fcc192c22031d26d2eeb016e13be5 Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Sun, 8 Sep 2024 14:33:21 +0200 Subject: [PATCH] Fix FMA insertion failures in remainder loops and with RNG --- src/pystencils/backends/cbackend.py | 8 ++++++ src/pystencils/fast_approximation.py | 10 +++++++- tests/test_vec_fma.py | 38 ++++++++++++++-------------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py index c00cf520a..f9e660f9d 100644 --- a/src/pystencils/backends/cbackend.py +++ b/src/pystencils/backends/cbackend.py @@ -571,6 +571,11 @@ class CustomSympyPrinter(CCodePrinter): return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))" elif expr.func == DivFunc: return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))' + elif isinstance(expr, Fma): + a = expr.args[0] * (-1 if expr.instruction[0] == '-' else 1) + b = expr.args[1] + c = expr.args[2] * (-1 if expr.instruction[-1] == '-' else 1) + return f"fma({self._print(a)}, {self._print(b)}, {self._print(c)})" else: name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__ arg_str = ', '.join(self._print(a) for a in expr.args) @@ -729,6 +734,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif isinstance(expr, fast_inv_sqrt): raise ValueError("fast_inv_sqrt is only supported for Taget.GPU") elif isinstance(expr, Fma): + result = self._scalarFallback('_print_Function', expr) + if result: + return result return self.instruction_set[expr.instruction].format(self._print(expr.args[0]), self._print(expr.args[1]), self._print(expr.args[2]), **self._kwargs) elif isinstance(expr, vec_any) or isinstance(expr, vec_all): diff --git a/src/pystencils/fast_approximation.py b/src/pystencils/fast_approximation.py index de3dc62c3..7e678da41 100644 --- a/src/pystencils/fast_approximation.py +++ b/src/pystencils/fast_approximation.py @@ -121,6 +121,8 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti def insert_fma(term, operators): + from pystencils.rng import RNGBase # late import to avoid cyclic dependency + if '*+' not in operators: return term @@ -144,8 +146,11 @@ def insert_fma(term, operators): return expr def visit(expr): + # Special treatments for various types that cannot be reconstructed from their args if isinstance(expr, ResolvedFieldAccess): return expr + elif isinstance(expr, RNGBase): + return expr elif hasattr(expr, 'body'): old_parent = expr.body.parent if hasattr(expr.body, 'parent') else None expr.body = visit(expr.body) @@ -154,7 +159,9 @@ def insert_fma(term, operators): return expr elif isinstance(expr, Block): return Block([visit(a) for a in expr.args]) - elif expr.func == sp.Add: + + # Find patterns of Add and Mul nodes that can be fused + if expr.func == sp.Add: expr = flatten(expr) summands = list(expr.args) if '-*+' in operators: @@ -210,6 +217,7 @@ def insert_fma(term, operators): summands = [visit(s) for s in summands] return sp.Add(fmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:]) return expr + # Find Mul with three factors, one of them -1, which can be fused elif expr.func == sp.Mul and -1 in expr.args: expr = flatten(expr) factors = list(expr.args) diff --git a/tests/test_vec_fma.py b/tests/test_vec_fma.py index 357c37de5..25389d640 100644 --- a/tests/test_vec_fma.py +++ b/tests/test_vec_fma.py @@ -11,10 +11,10 @@ supported_instruction_sets = get_supported_instruction_sets() if get_supported_i @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) def test_fmadd(instruction_set, dtype): - da = 2 * np.ones((128, 128), dtype=dtype) - db = 3 * np.ones((128, 128), dtype=dtype) - dc = 5 * np.ones((128, 128), dtype=dtype) - dd = np.empty((128, 128), dtype=dtype) + da = 2 * np.ones((129, 129), dtype=dtype) + db = 3 * np.ones((129, 129), dtype=dtype) + dc = 5 * np.ones((129, 129), dtype=dtype) + dd = np.empty((129, 129), dtype=dtype) a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) update_rule = [ps.Assignment(d.center(), a.center() * b.center() + c.center())] @@ -33,10 +33,10 @@ def test_fmadd(instruction_set, dtype): @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) def test_fmsub(instruction_set, dtype): - da = 2 * np.ones((128, 128), dtype=dtype) - db = 3 * np.ones((128, 128), dtype=dtype) - dc = 5 * np.ones((128, 128), dtype=dtype) - dd = np.empty((128, 128), dtype=dtype) + da = 2 * np.ones((129, 129), dtype=dtype) + db = 3 * np.ones((129, 129), dtype=dtype) + dc = 5 * np.ones((129, 129), dtype=dtype) + dd = np.empty((129, 129), dtype=dtype) a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) update_rule = [ps.Assignment(d.center(), a.center() * b.center() - c.center())] @@ -57,10 +57,10 @@ def test_fmsub(instruction_set, dtype): @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) def test_fnmadd(instruction_set, dtype): - da = 2 * np.ones((128, 128), dtype=dtype) - db = 3 * np.ones((128, 128), dtype=dtype) - dc = 5 * np.ones((128, 128), dtype=dtype) - dd = np.empty((128, 128), dtype=dtype) + da = 2 * np.ones((129, 129), dtype=dtype) + db = 3 * np.ones((129, 129), dtype=dtype) + dc = 5 * np.ones((129, 129), dtype=dtype) + dd = np.empty((129, 129), dtype=dtype) a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) update_rule = [ps.Assignment(d.center(), -a.center() * b.center() + c.center())] @@ -81,10 +81,10 @@ def test_fnmadd(instruction_set, dtype): @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) def test_fnmsub(instruction_set, dtype): - da = 2 * np.ones((128, 128), dtype=dtype) - db = 3 * np.ones((128, 128), dtype=dtype) - dc = 5 * np.ones((128, 128), dtype=dtype) - dd = np.empty((128, 128), dtype=dtype) + da = 2 * np.ones((129, 129), dtype=dtype) + db = 3 * np.ones((129, 129), dtype=dtype) + dc = 5 * np.ones((129, 129), dtype=dtype) + dd = np.empty((129, 129), dtype=dtype) a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) update_rule = [ps.Assignment(d.center(), -a.center() * b.center() - c.center())] @@ -105,9 +105,9 @@ def test_fnmsub(instruction_set, dtype): @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) def test_fnm(instruction_set, dtype): - da = 2 * np.ones((128, 128), dtype=dtype) - db = 3 * np.ones((128, 128), dtype=dtype) - dd = np.empty((128, 128), dtype=dtype) + da = 2 * np.ones((129, 129), dtype=dtype) + db = 3 * np.ones((129, 129), dtype=dtype) + dd = np.empty((129, 129), dtype=dtype) a, b, d = ps.fields(a=da, b=db, d=dd) update_rule = [ps.Assignment(d.center(), -a.center() * b.center())] -- GitLab