diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index b9a452742aab763b5cb631525b303f4c106e2d64..40be43eaaa2b6b014bf93447afcb8a68ab407eff 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -639,8 +639,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], for child_term, condition in t.args: visit(child_term) visit_children = False - elif isinstance(t, sp.Rel): + elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)): pass + elif isinstance(t, DivFunc): + result["divs"] += 1 else: warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate") diff --git a/pystencils_tests/test_sympyextensions.py b/pystencils_tests/test_sympyextensions.py index 38a138d2b0d4a52d56214a6f2c5c4f0a7dedfd9c..1929cc0666a302de3e10423c40794f219bfa38ae 100644 --- a/pystencils_tests/test_sympyextensions.py +++ b/pystencils_tests/test_sympyextensions.py @@ -15,6 +15,7 @@ from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import kronecker_delta from pystencils import Assignment +from pystencils.functions import DivFunc from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) @@ -163,6 +164,30 @@ def test_count_operations(): assert ops['divs'] == 1 assert ops['sqrts'] == 1 + expr = DivFunc(x, y) + ops = count_operations(expr, only_type=None) + assert ops['divs'] == 1 + + expr = DivFunc(x + z, y + z) + ops = count_operations(expr, only_type=None) + assert ops['adds'] == 2 + assert ops['divs'] == 1 + + expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + ops = count_operations(expr, only_type=None) + assert ops['muls'] == 99 + + expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) + ops = count_operations(expr, only_type=None) + assert ops['divs'] == 1 + assert ops['muls'] == 99 + + expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) + ops = count_operations(expr, only_type=None) + assert ops['adds'] == 1 + assert ops['divs'] == 1 + assert ops['muls'] == 99 + def test_common_denominator(): x = sympy.symbols('x')