diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7c743a3997071a2f1515dd5802ee6f69cb741375..0666d96873d4bdd3d722a7912b6e704b4aee1cf8 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -149,7 +149,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): return self._constant == other._constant def __repr__(self) -> str: - return f"Constant({repr(self._constant)})" + return f"PsConstantExpr({repr(self._constant)})" class PsSubscript(PsLvalue, PsExpression): @@ -385,6 +385,18 @@ class PsCall(PsExpression): return super().structurally_equal(other) and self._function == other._function +class PsNumericOpTrait: + """Trait for operations valid only on numerical types""" + + +class PsIntOpTrait: + """Trait for operations valid only on integer types""" + + +class PsBoolOpTrait: + """Trait for boolean operations""" + + class PsUnOp(PsExpression): __match_args__ = ("operand",) @@ -414,8 +426,12 @@ class PsUnOp(PsExpression): def python_operator(self) -> None | Callable[[Any], Any]: return None + def __repr__(self) -> str: + opname = self.__class__.__name__ + return f"{opname}({repr(self._operand)})" + -class PsNeg(PsUnOp): +class PsNeg(PsUnOp, PsNumericOpTrait): @property def python_operator(self): return operator.neg @@ -503,31 +519,31 @@ class PsBinOp(PsExpression): return None -class PsAdd(PsBinOp): +class PsAdd(PsBinOp, PsNumericOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.add -class PsSub(PsBinOp): +class PsSub(PsBinOp, PsNumericOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.sub -class PsMul(PsBinOp): +class PsMul(PsBinOp, PsNumericOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.mul -class PsDiv(PsBinOp): +class PsDiv(PsBinOp, PsNumericOpTrait): # python_operator not implemented because can't unambigously decide # between intdiv and truediv pass -class PsIntDiv(PsBinOp): +class PsIntDiv(PsBinOp, PsIntOpTrait): """C-like integer division (round to zero).""" # python_operator not implemented because both floordiv and truediv have @@ -535,36 +551,94 @@ class PsIntDiv(PsBinOp): pass -class PsLeftShift(PsBinOp): +class PsLeftShift(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.lshift -class PsRightShift(PsBinOp): +class PsRightShift(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.rshift -class PsBitwiseAnd(PsBinOp): +class PsBitwiseAnd(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.and_ -class PsBitwiseXor(PsBinOp): +class PsBitwiseXor(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.xor -class PsBitwiseOr(PsBinOp): +class PsBitwiseOr(PsBinOp, PsIntOpTrait): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.or_ + + +class PsAnd(PsBinOp, PsBoolOpTrait): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.and_ + + +class PsOr(PsBinOp, PsBoolOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.or_ +class PsNot(PsUnOp, PsBoolOpTrait): + @property + def python_operator(self) -> Callable[[Any], Any] | None: + return operator.not_ + + +class PsRel(PsBinOp): + """Base class for binary relational operators""" + + +class PsEq(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.eq + + +class PsNe(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.ne + + +class PsGe(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.ge + + +class PsLe(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.le + + +class PsGt(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.gt + + +class PsLt(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.lt + + class PsArrayInitList(PsExpression): __match_args__ = ("items",) diff --git a/src/pystencils/backend/ast/logical_expressions.py b/src/pystencils/backend/ast/logical_expressions.py deleted file mode 100644 index 2d739e020c2261a7d0b6fe917172223f1495c0e3..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/ast/logical_expressions.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Callable, Any -import operator - -from .expressions import PsExpression -from .astnode import PsAstNode -from .util import failing_cast - - -class PsLogicalExpression(PsExpression): - __match_args__ = ("operand1", "operand2") - - def __init__(self, op1: PsExpression, op2: PsExpression): - super().__init__() - self._op1 = op1 - self._op2 = op2 - - @property - def operand1(self) -> PsExpression: - return self._op1 - - @operand1.setter - def operand1(self, expr: PsExpression): - self._op1 = expr - - @property - def operand2(self) -> PsExpression: - return self._op2 - - @operand2.setter - def operand2(self, expr: PsExpression): - self._op2 = expr - - def clone(self): - return type(self)(self._op1.clone(), self._op2.clone()) - - def get_children(self) -> tuple[PsAstNode, ...]: - return self._op1, self._op2 - - def set_child(self, idx: int, c: PsAstNode): - idx = [0, 1][idx] - match idx: - case 0: - self._op1 = failing_cast(PsExpression, c) - case 1: - self._op2 = failing_cast(PsExpression, c) - - def __repr__(self) -> str: - opname = self.__class__.__name__ - return f"{opname}({repr(self._op1)}, {repr(self._op2)})" - - @property - def python_operator(self) -> None | Callable[[Any, Any], Any]: - return None - - -class PsAnd(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.and_ - - -class PsEq(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.eq - - -class PsGe(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.ge - - -class PsGt(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.gt - - -class PsLe(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.le - - -class PsLt(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.lt - - -class PsNe(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.ne diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index aa5f853a73301f118e8fea9e5654237392867fc9..588ac410a6118b668eacd08114fdea3c7853ba6f 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -35,6 +35,15 @@ from .ast.expressions import ( PsSubscript, PsSymbolExpr, PsVectorArrayAccess, + PsAnd, + PsOr, + PsNot, + PsEq, + PsNe, + PsGt, + PsLt, + PsGe, + PsLe, ) from .symbols import PsSymbol @@ -67,32 +76,41 @@ class Ops(Enum): See also https://en.cppreference.com/w/cpp/language/operator_precedence """ - Weakest = (17 - 17, LR.Middle) + Call = (2, LR.Left) + Subscript = (2, LR.Left) + Lookup = (2, LR.Left) - BitwiseOr = (17 - 13, LR.Left) + Neg = (3, LR.Right) + Not = (3, LR.Right) + AddressOf = (3, LR.Right) + Deref = (3, LR.Right) + Cast = (3, LR.Right) - BitwiseXor = (17 - 12, LR.Left) + Mul = (5, LR.Left) + Div = (5, LR.Left) + Rem = (5, LR.Left) - BitwiseAnd = (17 - 11, LR.Left) + Add = (6, LR.Left) + Sub = (6, LR.Left) - LeftShift = (17 - 7, LR.Left) - RightShift = (17 - 7, LR.Left) + LeftShift = (7, LR.Left) + RightShift = (7, LR.Left) - Add = (17 - 6, LR.Left) - Sub = (17 - 6, LR.Left) + RelOp = (9, LR.Left) # >=, >, <, <= - Mul = (17 - 5, LR.Left) - Div = (17 - 5, LR.Left) - Rem = (17 - 5, LR.Left) + EqOp = (10, LR.Left) # == and != - Neg = (17 - 3, LR.Right) - AddressOf = (17 - 3, LR.Right) - Deref = (17 - 3, LR.Right) - Cast = (17 - 3, LR.Right) + BitwiseAnd = (11, LR.Left) - Call = (17 - 2, LR.Left) - Subscript = (17 - 2, LR.Left) - Lookup = (17 - 2, LR.Left) + BitwiseXor = (12, LR.Left) + + BitwiseOr = (13, LR.Left) + + LogicAnd = (14, LR.Left) + + LogicOr = (15, LR.Left) + + Weakest = (17, LR.Middle) def __init__(self, pred: int, assoc: LR) -> None: self.precedence = pred @@ -125,7 +143,7 @@ class PrinterCtx: return self.branch_stack[-1] def parenthesize(self, expr: str, next_operator: Ops) -> str: - if next_operator.precedence < self.current_op.precedence: + if next_operator.precedence > self.current_op.precedence: return f"({expr})" elif ( next_operator.precedence == self.current_op.precedence @@ -274,6 +292,13 @@ class CAstPrinter: return pc.parenthesize(f"-{operand_code}", Ops.Neg) + case PsNot(operand): + pc.push_op(Ops.Not, LR.Right) + operand_code = self.visit(operand, pc) + pc.pop_op() + + return pc.parenthesize(f"!{operand_code}", Ops.Not) + case PsDeref(operand): pc.push_op(Ops.Deref, LR.Right) operand_code = self.visit(operand, pc) @@ -339,5 +364,21 @@ class CAstPrinter: return ("^", Ops.BitwiseXor) case PsBitwiseOr(): return ("|", Ops.BitwiseOr) + case PsAnd(): + return ("&&", Ops.LogicAnd) + case PsOr(): + return ("||", Ops.LogicOr) + case PsEq(): + return ("==", Ops.EqOp) + case PsNe(): + return ("!=", Ops.EqOp) + case PsGt(): + return (">", Ops.RelOp) + case PsGe(): + return (">=", Ops.RelOp) + case PsLt(): + return ("<", Ops.RelOp) + case PsLe(): + return ("<=", Ops.RelOp) case _: assert False diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index b9bbe8cce84dca53a02e2297e0e8cb5199a25b26..c2334f54c34d476207eddc5466b2b13bff0d39d8 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -119,17 +119,19 @@ class AstFactory: body, ) - def loop_nest(self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock) -> PsLoop: + def loop_nest( + self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock + ) -> PsLoop: """Create a loop nest from a sequence of slices. **Example:** This snippet creates a 3D loop nest with ten iterations in each dimension:: - + >>> from pystencils import make_slice >>> ctx = KernelCreationContext() >>> factory = AstFactory(ctx) >>> loop = factory.loop_nest(("i", "j", "k"), make_slice[:10,:10,:10], PsBlock([])) - + Args: counters: Sequence of names for the loop counters slices: Sequence of iteration slices; see also `parse_slice` diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index a9f760e9718742bfc16a2bee7c60f14f3f272be3..0f2485fe903af25e3c93777e055dae80b2b4209d 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -3,6 +3,8 @@ from functools import reduce from operator import add, mul, sub, truediv import sympy as sp +import sympy.core.relational +import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment from ...sympyextensions import Assignment, AssignmentCollection, integer_functions @@ -34,6 +36,16 @@ from ..ast.expressions import ( PsRightShift, PsSubscript, PsVectorArrayAccess, + PsRel, + PsEq, + PsNe, + PsLt, + PsGt, + PsLe, + PsGe, + PsAnd, + PsOr, + PsNot, ) from ..constants import PsConstant @@ -46,6 +58,20 @@ class FreezeError(Exception): """Signifies an error during expression freezing.""" +ExprLike = ( + sp.Expr + | sp.Tuple + | sympy.core.relational.Relational + | sympy.logic.boolalg.BooleanFunction +) +_ExprLike = ( + sp.Expr, + sp.Tuple, + sympy.core.relational.Relational, + sympy.logic.boolalg.BooleanFunction, +) + + class FreezeExpressions: """Convert expressions and kernels expressed in the SymPy language to the code generator's internal representation. @@ -65,7 +91,7 @@ class FreezeExpressions: pass @overload - def __call__(self, obj: sp.Expr) -> PsExpression: + def __call__(self, obj: ExprLike) -> PsExpression: pass @overload @@ -77,7 +103,7 @@ class FreezeExpressions: return PsBlock([self.visit(asm) for asm in obj.all_assignments]) elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) - elif isinstance(obj, sp.Expr): + elif isinstance(obj, _ExprLike): return cast(PsExpression, self.visit(obj)) else: raise PsInputError(f"Don't know how to freeze {obj}") @@ -97,8 +123,8 @@ class FreezeExpressions: raise FreezeError(f"Don't know how to freeze expression {node}") - def visit_expr_like(self, obj: Any) -> PsExpression: - if isinstance(obj, sp.Basic): + def visit_expr_or_builtin(self, obj: Any) -> PsExpression: + if isinstance(obj, _ExprLike): return self.visit_expr(obj) elif isinstance(obj, (int, float, bool)): return PsExpression.make(PsConstant(obj)) @@ -106,7 +132,7 @@ class FreezeExpressions: raise FreezeError(f"Don't know how to freeze {obj}") def visit_expr(self, expr: sp.Basic): - if not isinstance(expr, (sp.Expr, sp.Tuple)): + if not isinstance(expr, _ExprLike): raise FreezeError(f"Cannot freeze {expr} to an expression") return cast(PsExpression, self.visit(expr)) @@ -257,7 +283,9 @@ class FreezeExpressions: array = self._ctx.get_array(field) ptr = array.base_pointer - offsets: list[PsExpression] = [self.visit_expr_like(o) for o in access.offsets] + offsets: list[PsExpression] = [ + self.visit_expr_or_builtin(o) for o in access.offsets + ] indices: list[PsExpression] if not access.is_absolute_access: @@ -303,7 +331,7 @@ class FreezeExpressions: ) else: struct_member_name = None - indices = [self.visit_expr_like(i) for i in access.index] + indices = [self.visit_expr_or_builtin(i) for i in access.index] if not indices: # For canonical representation, there must always be at least one index dimension indices = [PsExpression.make(PsConstant(0))] @@ -371,5 +399,35 @@ class FreezeExpressions: args = tuple(self.visit_expr(arg) for arg in expr.args) return PsCall(PsMathFunction(MathFunctions.Max), args) - def map_CastFunc(self, cast_expr: CastFunc): + def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) + + def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: + arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] + match rel.rel_op: # type: ignore + case "==": + return PsEq(arg1, arg2) + case "!=": + return PsNe(arg1, arg2) + case ">=": + return PsGe(arg1, arg2) + case "<=": + return PsLe(arg1, arg2) + case ">": + return PsGt(arg1, arg2) + case "<": + return PsLt(arg1, arg2) + case other: + raise FreezeError(f"Unsupported relation: {other}") + + def map_And(self, conj: sympy.logic.And) -> PsAnd: + arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] + return PsAnd(arg1, arg2) + + def map_Or(self, disj: sympy.logic.Or) -> PsOr: + arg1, arg2 = [self.visit_expr(arg) for arg in disj.args] + return PsOr(arg1, arg2) + + def map_Not(self, neg: sympy.logic.Not) -> PsNot: + arg = self.visit_expr(neg.args[0]) + return PsNot(arg) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 5a093031cb4a498a4a00ff2020df0d69747a2b70..ba215f822ea7372211bf764425d44e44487cc46b 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -125,7 +125,7 @@ class FullIterationSpace(IterationSpace): archetype_field: Field | None = None, ): """Create an iteration space from a sequence of slices, optionally over an archetype field. - + Args: ctx: The kernel creation context iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` @@ -157,6 +157,7 @@ class FullIterationSpace(IterationSpace): ] from .ast_factory import AstFactory + factory = AstFactory(ctx) def to_dim(slic: slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol): diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index bfecec5be0bfccadebe671e2272e771f02aa71fc..1bf3c49807ff52a34bc9ab319f5da67e4fa59ebc 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,25 +22,26 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, + PsComment, ) from ..ast.expressions import ( PsArrayAccess, PsArrayInitList, PsBinOp, - PsBitwiseAnd, - PsBitwiseOr, - PsBitwiseXor, + PsIntOpTrait, + PsNumericOpTrait, + PsBoolOpTrait, PsCall, PsCast, PsDeref, PsAddressOf, PsConstantExpr, - PsIntDiv, - PsLeftShift, PsLookup, - PsRightShift, PsSubscript, PsSymbolExpr, + PsRel, + PsNeg, + PsNot, ) from ..functions import PsMathFunction @@ -167,19 +168,29 @@ class TypeContext: f" Target type: {self._target_type}" ) - case ( - PsIntDiv() - | PsLeftShift() - | PsRightShift() - | PsBitwiseAnd() - | PsBitwiseXor() - | PsBitwiseOr() - ) if not isinstance(self._target_type, PsIntegerType): + case PsNumericOpTrait() if not isinstance( + self._target_type, PsNumericType + ) or isinstance(self._target_type, PsBoolType): + # FIXME: PsBoolType derives from PsNumericType, but is not numeric + raise TypificationError( + f"Numerical operation encountered in non-numerical type context:\n" + f" Expression: {expr}" + f" Type Context: {self._target_type}" + ) + + case PsIntOpTrait() if not isinstance(self._target_type, PsIntegerType): raise TypificationError( f"Integer operation encountered in non-integer type context:\n" f" Expression: {expr}" f" Type Context: {self._target_type}" ) + + case PsBoolOpTrait() if not isinstance(self._target_type, PsBoolType): + raise TypificationError( + f"Boolean operation encountered in non-boolean type context:\n" + f" Expression: {expr}" + f" Type Context: {self._target_type}" + ) # endif expr.dtype = self._target_type @@ -297,7 +308,7 @@ class Typifier: self.visit_expr(rhs, tc_rhs) case PsConditional(cond, branch_true, branch_false): - cond_tc = TypeContext(PsBoolType(const=True)) + cond_tc = TypeContext(PsBoolType()) self.visit_expr(cond, cond_tc) self.visit(branch_true) @@ -316,6 +327,9 @@ class Typifier: self.visit(body) + case PsComment(): + pass + case _: raise NotImplementedError(f"Can't typify {node}") @@ -420,11 +434,33 @@ class Typifier: tc.apply_dtype(member_type, expr) + case PsRel(op1, op2): + args_tc = TypeContext() + self.visit_expr(op1, args_tc) + self.visit_expr(op2, args_tc) + + if args_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of arguments to relation: {expr}" + ) + if not isinstance(args_tc.target_type, PsNumericType): + raise TypificationError( + f"Invalid type in arguments to relation\n" + f" Expression: {expr}\n" + f" Arguments Type: {args_tc.target_type}" + ) + + tc.apply_dtype(PsBoolType(), expr) + case PsBinOp(op1, op2): self.visit_expr(op1, tc) self.visit_expr(op2, tc) tc.infer_dtype(expr) + case PsNeg(op) | PsNot(op): + self.visit_expr(op, tc) + tc.infer_dtype(expr) + case PsCall(function, args): match function: case PsMathFunction(): diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 8b1d8f783d7596ca341c01ffbaec64e038dd3d80..839cd34f4c9fada060a0ac6253b635f7c7812948 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -14,7 +14,7 @@ from ..ast.expressions import ( PsSymbolExpr, PsAdd, ) -from ..ast.logical_expressions import PsLt, PsAnd +from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType from ..symbols import PsSymbol diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index afb1e4fcd52d2f0fd85a008ca24987f085fd7dc6..01b69509991eaa762a093f50f427f6e4050dc34a 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,4 +1,5 @@ from .eliminate_constants import EliminateConstants +from .eliminate_branches import EliminateBranches from .canonicalize_symbols import CanonicalizeSymbols from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .erase_anonymous_structs import EraseAnonymousStructTypes @@ -7,6 +8,7 @@ from .select_intrinsics import MaterializeVectorIntrinsics __all__ = [ "EliminateConstants", + "EliminateBranches", "CanonicalizeSymbols", "HoistLoopInvariantDeclarations", "EraseAnonymousStructTypes", diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py new file mode 100644 index 0000000000000000000000000000000000000000..eab3d3722c30756ab39af072e75e9d6d89874447 --- /dev/null +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -0,0 +1,73 @@ +from ..kernelcreation import KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsLoop, PsBlock, PsConditional +from ..ast.expressions import PsConstantExpr + +from .eliminate_constants import EliminateConstants + +__all__ = ["EliminateBranches"] + + +class BranchElimContext: + def __init__(self) -> None: + self.enclosing_loops: list[PsLoop] = [] + + +class EliminateBranches: + """Replace conditional branches by their then- or else-branch if their condition can be unequivocally + evaluated. + + This pass will attempt to evaluate branch conditions within their context in the AST, and replace + conditionals by either their then- or their else-block if the branch is unequivocal. + + TODO: If islpy is installed, this pass will incorporate information about the iteration regions + of enclosing loops into its analysis. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self.visit(node, BranchElimContext()) + + def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: + match node: + case PsLoop(_, _, _, _, body): + ec.enclosing_loops.append(node) + self.visit(body, ec) + ec.enclosing_loops.pop() + + case PsBlock(statements): + statements_new: list[PsAstNode] = [] + for stmt in statements: + if isinstance(stmt, PsConditional): + result = self.handle_conditional(stmt, ec) + if result is not None: + statements_new.append(result) + else: + statements_new.append(self.visit(stmt, ec)) + node.statements = statements_new + + case PsConditional(): + result = self.handle_conditional(node, ec) + if result is None: + return PsBlock([]) + else: + return result + + return node + + def handle_conditional( + self, conditional: PsConditional, ec: BranchElimContext + ) -> PsConditional | PsBlock | None: + condition_simplified = self._elim_constants(conditional.condition) + match condition_simplified: + case PsConstantExpr(c) if c.value: + return conditional.branch_true + case PsConstantExpr(c) if not c.value: + return conditional.branch_false + + # TODO: Analyze condition against counters of enclosing loops using ISL + + return conditional diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 22ad740faca992dc9f520f160a406912af4385af..7678dbd8c6ce783585fb7095b201e9f92e65e485 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -14,12 +14,31 @@ from ..ast.expressions import ( PsSub, PsMul, PsDiv, + PsAnd, + PsOr, + PsRel, + PsNeg, + PsNot, + PsCall, + PsEq, + PsGe, + PsLe, + PsLt, + PsGt, + PsNe, ) from ..ast.util import AstEqWrapper from ..constants import PsConstant from ..symbols import PsSymbol -from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError +from ..functions import PsMathFunction +from ...types import ( + PsIntegerType, + PsIeeeFloatType, + PsNumericType, + PsBoolType, + PsTypeError, +) __all__ = ["EliminateConstants"] @@ -30,8 +49,6 @@ class ECContext: self._ctx = ctx self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() - self._typifier = Typifier(ctx) - from ..emission import CAstPrinter self._printer = CAstPrinter(0) @@ -59,7 +76,7 @@ class ECContext: return f"__c_{code}" def extract_expression(self, expr: PsExpression) -> PsSymbolExpr: - expr, dtype = self._typifier.typify_expression(expr) + dtype = expr.get_dtype() expr_wrapped = AstEqWrapper(expr) if expr_wrapped not in self._extracted_constants: @@ -92,8 +109,10 @@ class EliminateConstants: self, ctx: KernelCreationContext, extract_constant_exprs: bool = False ): self._ctx = ctx + self._typify = Typifier(ctx) self._fold_integers = True + self._fold_relations = True self._fold_floats = False self._extract_constant_exprs = extract_constant_exprs @@ -146,7 +165,7 @@ class EliminateConstants: expr.children = [r[0] for r in subtree_results] subtree_constness = [r[1] for r in subtree_results] - # Eliminate idempotence and dominance + # Eliminate idempotence, dominance, and trivial relations match expr: # Additive idempotence: Addition and subtraction of zero case PsAdd(PsConstantExpr(c), other_op) if c.value == 0: @@ -180,52 +199,110 @@ class EliminateConstants: case PsMul(other_op, PsConstantExpr(c)) if c.value == 0: return PsConstantExpr(c), True + # Logical idempotence + case PsAnd(PsConstantExpr(c), other_op) if c.value: + return other_op, all(subtree_constness) + + case PsAnd(other_op, PsConstantExpr(c)) if c.value: + return other_op, all(subtree_constness) + + case PsOr(PsConstantExpr(c), other_op) if not c.value: + return other_op, all(subtree_constness) + + case PsOr(other_op, PsConstantExpr(c)) if not c.value: + return other_op, all(subtree_constness) + + # Logical dominance + case PsAnd(PsConstantExpr(c), other_op) if not c.value: + return PsConstantExpr(c), True + + case PsAnd(other_op, PsConstantExpr(c)) if not c.value: + return PsConstantExpr(c), True + + case PsOr(PsConstantExpr(c), other_op) if c.value: + return PsConstantExpr(c), True + + case PsOr(other_op, PsConstantExpr(c)) if c.value: + return PsConstantExpr(c), True + + # Trivial comparisons + case ( + PsEq(op1, op2) | PsGe(op1, op2) | PsLe(op1, op2) + ) if op1.structurally_equal(op2): + true = self._typify(PsConstantExpr(PsConstant(True, PsBoolType()))) + return true, True + + case ( + PsNe(op1, op2) | PsGt(op1, op2) | PsLt(op1, op2) + ) if op1.structurally_equal(op2): + false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType()))) + return false, True + # end match: no idempotence or dominance encountered # Detect constant expressions if all(subtree_constness): - # Fold binary expressions where possible - if isinstance(expr, PsBinOp): - op1_transformed = expr.operand1 - op2_transformed = expr.operand2 - - if isinstance(op1_transformed, PsConstantExpr) and isinstance( - op2_transformed, PsConstantExpr - ): - v1 = op1_transformed.constant.value - v2 = op2_transformed.constant.value + dtype = expr.get_dtype() + assert isinstance(dtype, PsNumericType) + + is_int = isinstance(dtype, PsIntegerType) + is_float = isinstance(dtype, PsIeeeFloatType) + is_bool = isinstance(dtype, PsBoolType) + is_rel = isinstance(expr, PsRel) + + do_fold = ( + is_bool + or (self._fold_integers and is_int) + or (self._fold_floats and is_float) + or (self._fold_relations and is_rel) + ) + + folded: PsConstant | None + + match expr: + case PsNeg(operand) | PsNot(operand): + if isinstance(operand, PsConstantExpr): + val = operand.constant.value + py_operator = expr.python_operator - # assume they are of equal type - dtype = op1_transformed.constant.dtype + if do_fold and py_operator is not None: + folded = PsConstant(py_operator(val), dtype) + return self._typify(PsConstantExpr(folded)), True - is_int = isinstance(dtype, PsIntegerType) - is_float = isinstance(dtype, PsIeeeFloatType) + return expr, True - if (self._fold_integers and is_int) or ( - self._fold_floats and is_float + case PsBinOp(op1, op2): + if isinstance(op1, PsConstantExpr) and isinstance( + op2, PsConstantExpr ): - py_operator = expr.python_operator - - folded = None - if py_operator is not None: - folded = PsConstant( - py_operator(v1, v2), - dtype, - ) - elif isinstance(expr, PsDiv): - if isinstance(dtype, PsIntegerType): - pass - # TODO: C integer division! - # folded = PsConstant(v1 // v2, dtype) - elif isinstance(dtype, PsIeeeFloatType): - folded = PsConstant(v1 / v2, dtype) - - if folded is not None: - return PsConstantExpr(folded), True - - expr.operand1 = op1_transformed - expr.operand2 = op2_transformed - return expr, True + v1 = op1.constant.value + v2 = op2.constant.value + + if do_fold: + py_operator = expr.python_operator + + folded = None + if py_operator is not None: + folded = PsConstant( + py_operator(v1, v2), + dtype, + ) + elif isinstance(expr, PsDiv): + if is_int: + pass + # TODO: C integer division! + # folded = PsConstant(v1 // v2, dtype) + elif isinstance(dtype, PsIeeeFloatType): + folded = PsConstant(v1 / v2, dtype) + + if folded is not None: + return self._typify(PsConstantExpr(folded)), True + + return expr, True + + case PsCall(PsMathFunction(), _): + # TODO: Some math functions (min/max) might be safely folded + return expr, True # end if: this expression is not constant # If required, extract constant subexpressions diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 269435257bcd8dbb486290fe1d3f35aee8e21319..99aac49df90ba694ce9501eec389bd22f00a4070 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,4 +1,5 @@ import sympy as sp +import pytest from pystencils import Assignment, fields @@ -15,8 +16,16 @@ from pystencils.backend.ast.expressions import ( PsExpression, PsIntDiv, PsLeftShift, - PsMul, PsRightShift, + PsAnd, + PsOr, + PsNot, + PsEq, + PsNe, + PsLt, + PsLe, + PsGt, + PsGe ) from pystencils.backend.constants import PsConstant from pystencils.backend.kernelcreation import ( @@ -33,7 +42,6 @@ from pystencils.sympyextensions.integer_functions import ( bitwise_xor, int_div, int_power_of_2, - modulo_floor, ) @@ -145,3 +153,44 @@ def test_freeze_integer_functions(): for fasm, correct in zip(fasms, should): assert fasm.structurally_equal(correct) + + +def test_freeze_booleans(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + x, y, z = sp.symbols("x, y, z") + + expr1 = freeze(sp.Not(sp.And(x, y))) + assert expr1.structurally_equal(PsNot(PsAnd(x2, y2))) + + expr2 = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) + assert expr2.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) + + +@pytest.mark.parametrize("rel_pair", [ + (sp.Eq, PsEq), + (sp.Ne, PsNe), + (sp.Lt, PsLt), + (sp.Gt, PsGt), + (sp.Le, PsLe), + (sp.Ge, PsGe) +]) +def test_freeze_relations(rel_pair): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + sp_op, ps_op = rel_pair + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + x, y, z = sp.symbols("x, y, z") + + expr1 = freeze(sp_op(x, y + z)) + assert expr1.structurally_equal(ps_op(x2, y2 + z2)) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index adca2245b02feeac286a4507c27b7fb570620af8..60d0d6e7424bdfea730cafe18995afdb7dc253df 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -6,11 +6,30 @@ from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType -from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression -from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp +from pystencils.backend.ast.structural import ( + PsDeclaration, + PsAssignment, + PsExpression, + PsConditional, + PsBlock, +) +from pystencils.backend.ast.expressions import ( + PsConstantExpr, + PsSymbolExpr, + PsBinOp, + PsAnd, + PsOr, + PsNot, + PsEq, + PsNe, + PsGe, + PsLe, + PsGt, + PsLt, +) from pystencils.backend.constants import PsConstant from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp +from pystencils.types.quick import Fp, Bool from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -291,4 +310,55 @@ def test_typify_constant_clones(): assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None -test_lhs_constness() \ No newline at end of file +def test_typify_bools_and_relations(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + true = PsConstantExpr(PsConstant(True, Bool())) + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + + expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q)))) + expr = typify(expr) + + assert expr.dtype == Bool(const=True) + + +def test_bool_in_numerical_context(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + true = PsConstantExpr(PsConstant(True, Bool())) + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + + expr = true + (p - q) + with pytest.raises(TypificationError): + typify(expr) + + +@pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe]) +def test_typify_conditionals(rel): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + + cond = PsConditional(rel(x, y), PsBlock([])) + cond = typify(cond) + assert cond.condition.dtype == Bool(const=True) + + +def test_invalid_conditions(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + + cond = PsConditional(x + y, PsBlock([])) + with pytest.raises(TypificationError): + typify(cond) + + cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) + with pytest.raises(TypificationError): + typify(cond) diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index c8294c6dd4ac1a8683a5f779cb52d3e628cbc708..1fc6821d7b530a8b8e10b0298b641219bd31a53a 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -6,7 +6,7 @@ from pystencils.backend.kernelfunction import KernelFunction from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer -from pystencils.types.quick import Fp, SInt, UInt +from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter @@ -99,3 +99,61 @@ def test_printing_integer_functions(): ) code = cprint(expr) assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i" + + +def test_logical_precedence(): + from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr + + p, q, r = [PsExpression.make(PsSymbol(x, Bool())) for x in "pqr"] + true = PsExpression.make(PsConstant(True, Bool())) + false = PsExpression.make(PsConstant(False, Bool())) + cprint = CAstPrinter() + + expr = PsNot(PsAnd(p, PsOr(q, r))) + code = cprint(expr) + assert code == "!(p && (q || r))" + + expr = PsAnd(PsAnd(p, q), PsAnd(q, r)) + code = cprint(expr) + assert code == "p && q && (q && r)" + + expr = PsOr(PsAnd(true, p), PsOr(PsAnd(false, PsNot(q)), PsAnd(r, p))) + code = cprint(expr) + assert code == "true && p || (false && !q || r && p)" + + expr = PsAnd(PsOr(PsNot(p), PsNot(q)), PsNot(PsOr(true, false))) + code = cprint(expr) + assert code == "(!p || !q) && !(true || false)" + + +def test_relations_precedence(): + from pystencils.backend.ast.expressions import ( + PsNot, + PsAnd, + PsOr, + PsEq, + PsNe, + PsLt, + PsGt, + PsLe, + PsGe, + ) + + x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"] + cprint = CAstPrinter() + + expr = PsAnd(PsEq(x, y), PsLe(y, z)) + code = cprint(expr) + assert code == "x == y && y <= z" + + expr = PsOr(PsLt(x, y), PsLt(y, z)) + code = cprint(expr) + assert code == "x < y || y < z" + + expr = PsAnd(PsNot(PsGe(x, y)), PsNot(PsLe(y, z))) + code = cprint(expr) + assert code == "!(x >= y) && !(y <= z)" + + expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z))) + code = cprint(expr) + assert code == "x != y || !(y > z)" diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb3526d0b53fd40972c4dfeb06cf3a614bc6c10 --- /dev/null +++ b/tests/nbackend/transformations/test_branch_elimination.py @@ -0,0 +1,55 @@ +from pystencils import make_slice +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + Typifier, + AstFactory, +) +from pystencils.backend.ast.expressions import PsExpression +from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment +from pystencils.backend.constants import PsConstant +from pystencils.backend.transformations import EliminateBranches +from pystencils.types.quick import Int +from pystencils.backend.ast.expressions import PsGt + + +i0 = PsExpression.make(PsConstant(0, Int(32))) +i1 = PsExpression.make(PsConstant(1, Int(32))) + + +def test_eliminate_conditional(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + b1 = PsBlock([PsComment("Branch One")]) + + b2 = PsBlock([PsComment("Branch Two")]) + + cond = typify(PsConditional(PsGt(i1, i0), b1, b2)) + result = elim(cond) + assert result == b1 + + cond = typify(PsConditional(PsGt(-i1, i0), b1, b2)) + result = elim(cond) + assert result == b2 + + cond = typify(PsConditional(PsGt(-i1, i0), b1)) + result = elim(cond) + assert result.structurally_equal(PsBlock([])) + + +def test_eliminate_nested_conditional(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + b1 = PsBlock([PsComment("Branch One")]) + + b2 = PsBlock([PsComment("Branch Two")]) + + cond = typify(PsConditional(PsGt(i1, i0), b1, b2)) + ast = factory.loop_nest(("i", "j"), make_slice[:10, :10], PsBlock([cond])) + + result = elim(ast) + assert result.body.statements[0].body.statements[0] == b1 diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index b2ac6fc5ab1d9b8344279a201352886660ac8bfb..48df23ee193e24ff6736d74c22317f04dddc056c 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,12 +1,22 @@ -from pystencils.backend.kernelcreation import KernelCreationContext +from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.transformations import EliminateConstants -from pystencils.types.quick import Int, Fp +from pystencils.backend.ast.expressions import ( + PsAnd, + PsOr, + PsNot, + PsEq, + PsGt, +) -x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] +from pystencils.types.quick import Int, Fp, Bool + +x, y, z = [PsExpression.make(PsSymbol(name, Fp(32))) for name in "xyz"] +p, q, r = [PsExpression.make(PsSymbol(name, Int(32))) for name in "pqr"] +a, b, c = [PsExpression.make(PsSymbol(name, Bool())) for name in "abc"] f3p5 = PsExpression.make(PsConstant(3.5, Fp(32))) f42 = PsExpression.make(PsConstant(42, Fp(32))) @@ -18,52 +28,103 @@ i0 = PsExpression.make(PsConstant(0, Int(32))) i1 = PsExpression.make(PsConstant(1, Int(32))) i3 = PsExpression.make(PsConstant(3, Int(32))) +im3 = PsExpression.make(PsConstant(-3, Int(32))) i12 = PsExpression.make(PsConstant(12, Int(32))) +true = PsExpression.make(PsConstant(True, Bool())) +false = PsExpression.make(PsConstant(False, Bool())) + def test_idempotence(): ctx = KernelCreationContext() + typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = f42 * (f1 + f0) - f0 + expr = typify(f42 * (f1 + f0) - f0) result = elim(expr) assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42) - expr = (x + f0) * f3p5 + (f1 * y + f0) * f42 + expr = typify((x + f0) * f3p5 + (f1 * y + f0) * f42) result = elim(expr) assert result.structurally_equal(x * f3p5 + y * f42) - expr = (f3p5 * f1) + (f42 * f1) + expr = typify((f3p5 * f1) + (f42 * f1)) result = elim(expr) # do not fold floats by default assert expr.structurally_equal(f3p5 + f42) - expr = f1 * x + f0 + (f0 + f0 + f1 + f0) * y + expr = typify(f1 * x + f0 + (f0 + f0 + f1 + f0) * y) result = elim(expr) assert result.structurally_equal(x + y) def test_int_folding(): ctx = KernelCreationContext() + typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = (i1 * x + i1 * i3) + i1 * i12 + expr = typify((i1 * p + i1 * -i3) + i1 * i12) result = elim(expr) - assert result.structurally_equal((x + i3) + i12) + assert result.structurally_equal((p + im3) + i12) - expr = (i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1) + expr = typify((i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1)) result = elim(expr) assert result.structurally_equal(i12) def test_zero_dominance(): ctx = KernelCreationContext() + typify = Typifier(ctx) elim = EliminateConstants(ctx) - expr = (f0 * x) + (y * f0) + f1 + expr = typify((f0 * x) + (y * f0) + f1) result = elim(expr) assert result.structurally_equal(f1) - expr = (i3 + i12 * (x + y) + x / (i3 * y)) * i0 + expr = typify((i3 + i12 * (p + q) + p / (i3 * q)) * i0) result = elim(expr) assert result.structurally_equal(i0) + + +def test_boolean_folding(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx) + + expr = typify(PsNot(PsAnd(false, PsOr(true, a)))) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsOr(PsAnd(a, b), PsNot(false))) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsAnd(c, PsAnd(true, PsAnd(a, PsOr(false, b))))) + result = elim(expr) + assert result.structurally_equal(PsAnd(c, PsAnd(a, b))) + + +def test_relations_folding(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx) + + expr = typify(PsGt(p * i0, - i1)) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsEq(i1 + i1 + i1, i3)) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsEq(- i1, - i3)) + result = elim(expr) + assert result.structurally_equal(false) + + expr = typify(PsEq(x + y, f1 * (x + y))) + result = elim(expr) + assert result.structurally_equal(true) + + expr = typify(PsGt(x + y, f1 * (x + y))) + result = elim(expr) + assert result.structurally_equal(false)