diff --git a/sympyextensions.py b/sympyextensions.py index 8320a752d1eb69db6c2ff46eff9e14478b1d060e..71fc975a8a99ec22411620bd46eb0647198e3e94 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -18,6 +18,20 @@ def prod(seq: Iterable[T]) -> T: return reduce(operator.mul, seq, 1) +def remove_small_floats(expr, threshold): + """Removes all sp.Float objects whose absolute value is smaller than threshold + + >>> expr = sp.sympify("x + 1e-15 * y") + >>> remove_small_floats(expr, 1e-14) + x + """ + if isinstance(expr, sp.Float) and sp.Abs(expr) < threshold: + return 0 + else: + new_args = [remove_small_floats(c, threshold) for c in expr.args] + return expr.func(*new_args) if new_args else expr + + def is_integer_sequence(sequence: Iterable) -> bool: """Checks if all elements of the passed sequence can be cast to integers""" try: