From 2cb231b248da7ea05ae9f101a6613647669348eb Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 2 Apr 2019 15:16:38 +0200
Subject: [PATCH] FLOPs counting now also counts sqrts, invsqrts and their fast
 approximations

---
 pystencils/sympyextensions.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 22a2035e9..e40c10895 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -6,7 +6,6 @@ from collections import defaultdict, Counter
 import sympy as sp
 from sympy.functions import Abs
 from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple
-
 from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
 from pystencils.assignment import Assignment
 
@@ -233,7 +232,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
             intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
             if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
                 # find common factor
-                factors = defaultdict(lambda: 0)
+                factors = defaultdict(int)
                 skips = 0
                 for common_symbol in subexpression_coefficient_dict.keys():
                     if common_symbol not in expr_coefficients:
@@ -428,7 +427,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
     Returns:
         dict with 'adds', 'muls' and 'divs' keys
     """
-    result = {'adds': 0, 'muls': 0, 'divs': 0}
+    from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
+
+    result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
+              'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
 
     if isinstance(term, Sequence):
         for element in term:
@@ -480,6 +482,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
         elif isinstance(t, cast_func):
             visit_children = False
             visit(t.args[0])
+        elif t.func is fast_sqrt:
+            result['fast_sqrts'] += 1
+        elif t.func is fast_inv_sqrt:
+            result['fast_inv_sqrts'] += 1
+        elif t.func is fast_division:
+            result['fast_div'] += 1
         elif t.func is sp.Pow:
             if check_type(t.args[0]):
                 visit_children = False
@@ -490,6 +498,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
                         result['muls'] -= 1
                         result['divs'] += 1
                         result['muls'] += (-int(t.exp)) - 1
+                elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
+                    result['sqrts'] += 1
+                else:
+                    warnings.warn("Cannot handle exponent {} of sp.Pow node".format(t.exp))
             else:
                 warnings.warn("Counting operations: only integer exponents are supported in Pow, "
                               "counting will be inaccurate")
@@ -513,14 +525,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
 def count_operations_in_ast(ast) -> Dict[str, int]:
     """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
     from pystencils.astnodes import SympyAssignment
-    result = {'adds': 0, 'muls': 0, 'divs': 0}
+    result = defaultdict(int)
 
     def visit(node):
         if isinstance(node, SympyAssignment):
             r = count_operations(node.rhs)
-            result['adds'] += r['adds']
-            result['muls'] += r['muls']
-            result['divs'] += r['divs']
+            for k, v in r.items():
+                result[k] += v
         else:
             for arg in node.args:
                 visit(arg)
-- 
GitLab