From 2cb231b248da7ea05ae9f101a6613647669348eb Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 2 Apr 2019 15:16:38 +0200 Subject: [PATCH] FLOPs counting now also counts sqrts, invsqrts and their fast approximations --- pystencils/sympyextensions.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 22a2035..e40c108 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -6,7 +6,6 @@ from collections import defaultdict, Counter import sympy as sp from sympy.functions import Abs from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple - from pystencils.data_types import get_type_of_expression, get_base_type, cast_func from pystencils.assignment import Assignment @@ -233,7 +232,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients)) if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match): # find common factor - factors = defaultdict(lambda: 0) + factors = defaultdict(int) skips = 0 for common_symbol in subexpression_coefficient_dict.keys(): if common_symbol not in expr_coefficients: @@ -428,7 +427,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], Returns: dict with 'adds', 'muls' and 'divs' keys """ - result = {'adds': 0, 'muls': 0, 'divs': 0} + from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division + + result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0, + 'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0} if isinstance(term, Sequence): for element in term: @@ -480,6 +482,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], elif isinstance(t, cast_func): visit_children = False visit(t.args[0]) + elif t.func is fast_sqrt: + result['fast_sqrts'] += 1 + elif t.func is fast_inv_sqrt: + result['fast_inv_sqrts'] += 1 + elif t.func is fast_division: + result['fast_div'] += 1 elif t.func is sp.Pow: if check_type(t.args[0]): visit_children = False @@ -490,6 +498,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], result['muls'] -= 1 result['divs'] += 1 result['muls'] += (-int(t.exp)) - 1 + elif sp.nsimplify(t.exp) == sp.Rational(1, 2): + result['sqrts'] += 1 + else: + warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node") else: warnings.warn("Counting operations: only integer exponents are supported in Pow, " "counting will be inaccurate") @@ -513,14 +525,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], def count_operations_in_ast(ast) -> Dict[str, int]: """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`""" from pystencils.astnodes import SympyAssignment - result = {'adds': 0, 'muls': 0, 'divs': 0} + result = defaultdict(int) def visit(node): if isinstance(node, SympyAssignment): r = count_operations(node.rhs) - result['adds'] += r['adds'] - result['muls'] += r['muls'] - result['divs'] += r['divs'] + for k, v in r.items(): + result[k] += v else: for arg in node.args: visit(arg) -- GitLab