Skip to content
Snippets Groups Projects
Commit 2cb231b2 authored by Martin Bauer's avatar Martin Bauer
Browse files

FLOPs counting now also counts sqrts, invsqrts and their fast approximations

parent 0f298c63
No related merge requests found
......@@ -6,7 +6,6 @@ from collections import defaultdict, Counter
import sympy as sp
from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple
from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment
......@@ -233,7 +232,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
# find common factor
factors = defaultdict(lambda: 0)
factors = defaultdict(int)
skips = 0
for common_symbol in subexpression_coefficient_dict.keys():
if common_symbol not in expr_coefficients:
......@@ -428,7 +427,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
Returns:
dict with 'adds', 'muls' and 'divs' keys
"""
result = {'adds': 0, 'muls': 0, 'divs': 0}
from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence):
for element in term:
......@@ -480,6 +482,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(t, cast_func):
visit_children = False
visit(t.args[0])
elif t.func is fast_sqrt:
result['fast_sqrts'] += 1
elif t.func is fast_inv_sqrt:
result['fast_inv_sqrts'] += 1
elif t.func is fast_division:
result['fast_div'] += 1
elif t.func is sp.Pow:
if check_type(t.args[0]):
visit_children = False
......@@ -490,6 +498,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result['muls'] -= 1
result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node")
else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate")
......@@ -513,14 +525,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
def count_operations_in_ast(ast) -> Dict[str, int]:
"""Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
from pystencils.astnodes import SympyAssignment
result = {'adds': 0, 'muls': 0, 'divs': 0}
result = defaultdict(int)
def visit(node):
if isinstance(node, SympyAssignment):
r = count_operations(node.rhs)
result['adds'] += r['adds']
result['muls'] += r['muls']
result['divs'] += r['divs']
for k, v in r.items():
result[k] += v
else:
for arg in node.args:
visit(arg)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment