Skip to content
Snippets Groups Projects

[Fix] Printing of subtraction

Merged Markus Holzer requested to merge holzer/pystencils:FixSubtraction into master
Viewing commit 992cdc48
Prev
Show latest version
3 files
+ 18
31
Preferences
Compare changes
Files
3
@@ -6,6 +6,7 @@ import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
@@ -228,6 +229,15 @@ class TypeAdder:
new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely with the expression itself
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType):