diff --git a/sympyextensions.py b/sympyextensions.py index ec27fd238aba33c133d040325c4c5171d0e6a79d..8320a752d1eb69db6c2ff46eff9e14478b1d060e 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -7,7 +7,7 @@ import sympy as sp from sympy.functions import Abs from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple -from pystencils.data_types import get_type_of_expression, get_base_type +from pystencils.data_types import get_type_of_expression, get_base_type, cast_func from pystencils.assignment import Assignment T = TypeVar('T') @@ -428,7 +428,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], for a in t.args: if a == 1 or a == -1: result['muls'] -= 1 - elif t.func is sp.Float: + elif isinstance(t, sp.Float) or isinstance(t, sp.Rational): pass elif isinstance(t, sp.Symbol): visit_children = False @@ -436,6 +436,9 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], visit_children = False elif t.is_integer: pass + elif t.func is cast_func: + visit_children = False + visit(t.args[0]) elif t.func is sp.Pow: if check_type(t.args[0]): visit_children = False