Skip to content
Snippets Groups Projects
Commit 5836b20e authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'count_ops' into 'master'

count_operations: fix to not count integer expressions for addresses/constants as real operations

See merge request !171
parents 99257435 dab3371d
Branches
Tags
1 merge request!171count_operations: fix to not count integer expressions for addresses/constants as real operations
Pipeline #27029 passed with warnings with stages
in 30 minutes and 42 seconds
......@@ -10,7 +10,8 @@ from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression
from pystencils.data_types import cast_func, get_type_of_expression, PointerType
from pystencils.kernelparameters import FieldPointerSymbol
T = TypeVar('T')
......@@ -445,7 +446,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence):
for element in term:
r = count_operations(element, only_type)
......@@ -455,16 +455,18 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment):
term = term.rhs
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e):
if only_type is None:
return True
if isinstance(e, FieldPointerSymbol) and only_type == "real":
return only_type == "int"
try:
base_type = get_base_type(get_type_of_expression(e))
base_type = get_type_of_expression(e)
except ValueError:
return False
if isinstance(base_type, PointerType):
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True
if only_type == 'real' and (base_type.is_float()):
......@@ -515,6 +517,9 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
result["sqrts"] += 1
result["divs"] += 1
else:
warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else:
......
......@@ -30,7 +30,7 @@ def test_simplification_strategy():
result = strategy(ac)
assert result.operation_count['adds'] == 7
assert result.operation_count['muls'] == 5
assert result.operation_count['muls'] == 4
assert result.operation_count['divs'] == 0
# Trigger display routines, such that they are at least executed
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment