Commit 2cb231b2 authored by Martin Bauer's avatar Martin Bauer
Browse files

FLOPs counting now also counts sqrts, invsqrts and their fast approximations

parent 0f298c63
......@@ -6,7 +6,6 @@ from collections import defaultdict, Counter
import sympy as sp
from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple
from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment
......@@ -233,7 +232,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
# find common factor
factors = defaultdict(lambda: 0)
factors = defaultdict(int)
skips = 0
for common_symbol in subexpression_coefficient_dict.keys():
if common_symbol not in expr_coefficients:
......@@ -428,7 +427,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
Returns:
dict with 'adds', 'muls' and 'divs' keys
"""
result = {'adds': 0, 'muls': 0, 'divs': 0}
from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence):
for element in term:
......@@ -480,6 +482,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(t, cast_func):
visit_children = False
visit(t.args[0])
elif t.func is fast_sqrt:
result['fast_sqrts'] += 1
elif t.func is fast_inv_sqrt:
result['fast_inv_sqrts'] += 1
elif t.func is fast_division:
result['fast_div'] += 1
elif t.func is sp.Pow:
if check_type(t.args[0]):
visit_children = False
......@@ -490,6 +498,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result['muls'] -= 1
result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node")
else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate")
......@@ -513,14 +525,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
def count_operations_in_ast(ast) -> Dict[str, int]:
"""Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
from pystencils.astnodes import SympyAssignment
result = {'adds': 0, 'muls': 0, 'divs': 0}
result = defaultdict(int)
def visit(node):
if isinstance(node, SympyAssignment):
r = count_operations(node.rhs)
result['adds'] += r['adds']
result['muls'] += r['muls']
result['divs'] += r['divs']
for k, v in r.items():
result[k] += v
else:
for arg in node.args:
visit(arg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment