From ad9cba7c25476c3db3fe5bc7d9efa76afd6c97d9 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Mon, 7 Nov 2022 10:05:49 +0100
Subject: [PATCH] Improve FLOP counting function

---
 pystencils/sympyextensions.py            |  4 +++-
 pystencils_tests/test_sympyextensions.py | 25 ++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index b9a452742..40be43eaa 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -639,8 +639,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
             for child_term, condition in t.args:
                 visit(child_term)
             visit_children = False
-        elif isinstance(t, sp.Rel):
+        elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
             pass
+        elif isinstance(t, DivFunc):
+            result["divs"] += 1
         else:
             warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
 
diff --git a/pystencils_tests/test_sympyextensions.py b/pystencils_tests/test_sympyextensions.py
index 38a138d2b..1929cc066 100644
--- a/pystencils_tests/test_sympyextensions.py
+++ b/pystencils_tests/test_sympyextensions.py
@@ -15,6 +15,7 @@ from pystencils.sympyextensions import scalar_product
 from pystencils.sympyextensions import kronecker_delta
 
 from pystencils import Assignment
+from pystencils.functions import DivFunc
 from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
                                            insert_fast_divisions, insert_fast_sqrts)
 
@@ -163,6 +164,30 @@ def test_count_operations():
     assert ops['divs'] == 1
     assert ops['sqrts'] == 1
 
+    expr = DivFunc(x, y)
+    ops = count_operations(expr, only_type=None)
+    assert ops['divs'] == 1
+
+    expr = DivFunc(x + z, y + z)
+    ops = count_operations(expr, only_type=None)
+    assert ops['adds'] == 2
+    assert ops['divs'] == 1
+
+    expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
+    ops = count_operations(expr, only_type=None)
+    assert ops['muls'] == 99
+
+    expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
+    ops = count_operations(expr, only_type=None)
+    assert ops['divs'] == 1
+    assert ops['muls'] == 99
+
+    expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
+    ops = count_operations(expr, only_type=None)
+    assert ops['adds'] == 1
+    assert ops['divs'] == 1
+    assert ops['muls'] == 99
+
 
 def test_common_denominator():
     x = sympy.symbols('x')
-- 
GitLab