diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 4063f7b539ab387d1f950a75e735f3c6201b5ef2..5e6adfa4ff73d0a813f8668d634a27a698180aea 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -421,6 +421,54 @@ class PsCall(PsExpression): return super().structurally_equal(other) and self._function == other._function +class PsTernary(PsExpression): + """Ternary operator.""" + + __match_args__ = ("condition", "case_then", "case_else") + + def __init__( + self, cond: PsExpression, then: PsExpression, els: PsExpression + ) -> None: + super().__init__() + self._cond = cond + self._then = then + self._else = els + + @property + def condition(self) -> PsExpression: + return self._cond + + @property + def case_then(self) -> PsExpression: + return self._then + + @property + def case_else(self) -> PsExpression: + return self._else + + def clone(self) -> PsExpression: + return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone()) + + def get_children(self) -> tuple[PsExpression, ...]: + return (self._cond, self._then, self._else) + + def set_child(self, idx: int, c: PsAstNode): + idx = range(3)[idx] + match idx: + case 0: + self._cond = failing_cast(PsExpression, c) + case 1: + self._then = failing_cast(PsExpression, c) + case 2: + self._else = failing_cast(PsExpression, c) + + def __str__(self) -> str: + return f"PsTernary({self._cond}, {self._then}, {self._else})" + + def __repr__(self) -> str: + return f"PsTernary({repr(self._cond)}, {repr(self._then)}, {repr(self._else)})" + + class PsNumericOpTrait: """Trait for operations valid only on numerical types""" @@ -582,9 +630,21 @@ class PsDiv(PsBinOp, PsNumericOpTrait): class PsIntDiv(PsBinOp, PsIntOpTrait): """C-like integer division (round to zero).""" - # python_operator not implemented because both floordiv and truediv have - # different semantics. - pass + @property + def python_operator(self) -> Callable[[Any, Any], Any]: + from .util import c_intdiv + + return c_intdiv + + +class PsRem(PsBinOp, PsIntOpTrait): + """C-style integer division remainder""" + + @property + def python_operator(self) -> Callable[[Any, Any], Any]: + from .util import c_rem + + return c_rem class PsLeftShift(PsBinOp, PsIntOpTrait): diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index 0d3b78629fa9ee41d753893b1b6b4198cc75ae51..2fdf6078d0cf285062656a739d2bdcea7736e1c2 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -36,3 +36,14 @@ class AstEqWrapper: # TODO: consider replacing this with smth. more performant # TODO: Check that repr is implemented by all AST nodes return hash(repr(self._node)) + + +def c_intdiv(num, denom): + """C-style integer division""" + return int(num / denom) + + +def c_rem(num, denom): + """C-style integer remainder""" + div = c_intdiv(num, denom) + return num - div * denom diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index e8fc2a662b2f49c43fe51d021915b9e44cc59ad3..3b957f3b321e212158eb2eb5d6229037f286b307 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -26,6 +26,7 @@ from .ast.expressions import ( PsConstantExpr, PsDeref, PsDiv, + PsRem, PsIntDiv, PsLeftShift, PsLookup, @@ -37,6 +38,7 @@ from .ast.expressions import ( PsSymbolExpr, PsLiteralExpr, PsVectorArrayAccess, + PsTernary, PsAnd, PsOr, PsNot, @@ -112,6 +114,8 @@ class Ops(Enum): LogicOr = (15, LR.Left) + Ternary = (16, LR.Right) + Weakest = (17, LR.Middle) def __init__(self, pred: int, assoc: LR) -> None: @@ -329,6 +333,19 @@ class CAstPrinter: type_str = target_type.c_string() return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast) + case PsTernary(cond, then, els): + pc.push_op(Ops.Ternary, LR.Left) + cond_code = self.visit(cond, pc) + pc.switch_branch(LR.Middle) + then_code = self.visit(then, pc) + pc.switch_branch(LR.Right) + else_code = self.visit(els, pc) + pc.pop_op() + + return pc.parenthesize( + f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary + ) + case PsArrayInitList(items): pc.push_op(Ops.Weakest, LR.Middle) items_str = ", ".join(self.visit(item, pc) for item in items) @@ -362,6 +379,8 @@ class CAstPrinter: return ("*", Ops.Mul) case PsDiv() | PsIntDiv(): return ("/", Ops.Div) + case PsRem(): + return ("%", Ops.Rem) case PsLeftShift(): return ("<<", Ops.LeftShift) case PsRightShift(): diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 0f2485fe903af25e3c93777e055dae80b2b4209d..ab6261b95361c3832afce8c9203f6c523b12621a 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -36,6 +36,7 @@ from ..ast.expressions import ( PsRightShift, PsSubscript, PsVectorArrayAccess, + PsTernary, PsRel, PsEq, PsNe, @@ -391,6 +392,27 @@ class FreezeExpressions: case _: raise FreezeError(f"Unsupported function: {func}") + def map_Piecewise(self, expr: sp.Piecewise) -> PsTernary: + from sympy.functions.elementary.piecewise import ExprCondPair + + cases: list[ExprCondPair] = cast(list[ExprCondPair], expr.args) + + if cases[-1].cond != sp.true: + raise FreezeError( + "The last case of a `Piecewise` must be the fallback case, its condition must always be `True`." + ) + + conditions = [self.visit_expr(c.cond) for c in cases[:-1]] + subexprs = [self.visit_expr(c.expr) for c in cases] + + last_expr = subexprs.pop() + ternary = PsTernary(conditions.pop(), subexprs.pop(), last_expr) + + while conditions: + ternary = PsTernary(conditions.pop(), subexprs.pop(), ternary) + + return ternary + def map_Min(self, expr: sp.Min) -> PsCall: args = tuple(self.visit_expr(arg) for arg in expr.args) return PsCall(PsMathFunction(MathFunctions.Min), args) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 6adac2a519ffc04505c0e0adac3484d78f30d013..2a3d2774e03160fe2012f68ecb3ded4803353304 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -11,7 +11,7 @@ from ...field import Field, FieldType from ..symbols import PsSymbol from ..constants import PsConstant -from ..ast.expressions import PsExpression, PsConstantExpr +from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem from ..arrays import PsLinearizedArray from ..ast.util import failing_cast from ...types import PsStructType, constify @@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace): return self._archetype_field def actual_iterations(self, dimension: int | None = None) -> PsExpression: + from .typification import Typifier + from ..transformations import EliminateConstants + + typify = Typifier(self._ctx) + fold = EliminateConstants(self._ctx) + if dimension is None: - return reduce( - mul, (self.actual_iterations(d) for d in range(len(self.dimensions))) + return fold( + typify( + reduce( + mul, + ( + self.actual_iterations(d) + for d in range(len(self.dimensions)) + ), + ) + ) ) else: dim = self.dimensions[dimension] one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype)) - return one + (dim.stop - dim.start - one) / dim.step + zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype)) + return fold( + typify( + PsTernary( + PsEq(PsRem((dim.stop - dim.start), dim.step), zero), + (dim.stop - dim.start) / dim.step, + (dim.stop - dim.start) / dim.step + one, + ) + ) + ) def compressed_counter(self) -> PsExpression: """Expression counting the actual number of items processed at the iteration defined by the counter tuple. diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 06e34d4e36219366a25713d28ae758b61fb3d0d6..42885bedc9272469b8bfac151358cdaedf4e880f 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -32,6 +32,7 @@ from ..ast.expressions import ( PsNumericOpTrait, PsBoolOpTrait, PsCall, + PsTernary, PsCast, PsDeref, PsAddressOf, @@ -446,6 +447,14 @@ class Typifier: tc.apply_dtype(member_type, expr) + case PsTernary(cond, then, els): + cond_tc = TypeContext(target_type=PsBoolType()) + self.visit_expr(cond, cond_tc) + + self.visit_expr(then, tc) + self.visit_expr(els, tc) + tc.infer_dtype(expr) + case PsRel(op1, op2): args_tc = TypeContext() self.visit_expr(op1, args_tc) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index ddfa33f08272d59d032e1e657e66baa96fb41d04..15a6d5c5c48f5e27f7bc65880f0ad1eb851b4825 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -15,6 +15,8 @@ from ..ast.expressions import ( PsSub, PsMul, PsDiv, + PsIntDiv, + PsRem, PsAnd, PsOr, PsRel, @@ -27,6 +29,7 @@ from ..ast.expressions import ( PsLt, PsGt, PsNe, + PsTernary, ) from ..ast.util import AstEqWrapper @@ -198,9 +201,16 @@ class EliminateConstants: case PsMul(other_op, PsConstantExpr(c)) if c.value == 1: return other_op, all(subtree_constness) - case PsDiv(other_op, PsConstantExpr(c)) if c.value == 1: + case PsDiv(other_op, PsConstantExpr(c)) | PsIntDiv( + other_op, PsConstantExpr(c) + ) if c.value == 1: return other_op, all(subtree_constness) + # Trivial remainder at division by one + case PsRem(other_op, PsConstantExpr(c)) if c.value == 1: + zero = self._typify(PsConstantExpr(PsConstant(0, c.get_dtype()))) + return zero, True + # Multiplicative dominance: 0 * x = 0 case PsMul(PsConstantExpr(c), other_op) if c.value == 0: return PsConstantExpr(c), True @@ -247,6 +257,13 @@ class EliminateConstants: false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType()))) return false, True + # Trivial ternaries + case PsTernary(PsConstantExpr(c), then, els): + if c.value: + return then, subtree_constness[1] + else: + return els, subtree_constness[2] + # end match: no idempotence or dominance encountered # Detect constant expressions @@ -299,9 +316,8 @@ class EliminateConstants: ) elif isinstance(expr, PsDiv): if is_int: - pass - # TODO: C integer division! - # folded = PsConstant(v1 // v2, dtype) + from ..ast.util import c_intdiv + folded = PsConstant(c_intdiv(v1, v2), dtype) elif isinstance(dtype, PsIeeeFloatType): folded = PsConstant(v1 / v2, dtype) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 99aac49df90ba694ce9501eec389bd22f00a4070..1593be684ca48ec24b730168feaadb3331d4d7dd 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -5,7 +5,6 @@ from pystencils import Assignment, fields from pystencils.backend.ast.structural import ( PsAssignment, - PsBlock, PsDeclaration, ) from pystencils.backend.ast.expressions import ( @@ -14,6 +13,7 @@ from pystencils.backend.ast.expressions import ( PsBitwiseOr, PsBitwiseXor, PsExpression, + PsTernary, PsIntDiv, PsLeftShift, PsRightShift, @@ -33,6 +33,7 @@ from pystencils.backend.kernelcreation import ( FreezeExpressions, FullIterationSpace, ) +from pystencils.backend.kernelcreation.freeze import FreezeError from pystencils.sympyextensions.integer_functions import ( bit_shift_left, @@ -194,3 +195,28 @@ def test_freeze_relations(rel_pair): expr1 = freeze(sp_op(x, y + z)) assert expr1.structurally_equal(ps_op(x2, y2 + z2)) + + +def test_freeze_piecewise(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + p, q, x, y, z = sp.symbols("p, q, x, y, z") + + p2 = PsExpression.make(ctx.get_symbol("p")) + q2 = PsExpression.make(ctx.get_symbol("q")) + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + piecewise = sp.Piecewise((x, p), (y, q), (z, True)) + expr = freeze(piecewise) + + assert isinstance(expr, PsTernary) + + should = PsTernary(p2, x2, PsTernary(q2, y2, z2)) + assert expr.structurally_equal(should) + + piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q))) + with pytest.raises(FreezeError): + freeze(piecewise) diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 7fd6d778ff62f7fb2fcbc24a55af5225fb9f870e..f9646afc26d11bddfb49c5a178096f9d2157d5f6 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -1,13 +1,13 @@ import pytest -from pystencils.field import Field -from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type +from pystencils import make_slice, Field, create_type +from pystencils.sympyextensions.typed_sympy import TypedSymbol +from pystencils.backend.constants import PsConstant from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace - from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.kernelcreation.typification import TypificationError -from pystencils.types import PsTypeError +from pystencils.types.quick import Int def test_slices(): @@ -36,12 +36,12 @@ def test_slices(): op.structurally_equal(PsExpression.make(archetype_arr.shape[0])) for op in dims[0].stop.children ) - + assert isinstance(dims[1].stop, PsAdd) and any( op.structurally_equal(PsExpression.make(archetype_arr.shape[1])) for op in dims[1].stop.children ) - + assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2])) @@ -58,3 +58,28 @@ def test_invalid_slices(): islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) with pytest.raises(TypificationError): FullIterationSpace.create_from_slice(ctx, islice, archetype_field) + + +def test_iteration_count(): + ctx = KernelCreationContext() + + i, j, k = [PsExpression.make(ctx.get_symbol(x, ctx.index_dtype)) for x in "ijk"] + zero = PsExpression.make(PsConstant(0, ctx.index_dtype)) + two = PsExpression.make(PsConstant(2, ctx.index_dtype)) + three = PsExpression.make(PsConstant(3, ctx.index_dtype)) + + ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[three : i-two, 1:8:3] + ) + + iters = [ispace.actual_iterations(coord) for coord in range(2)] + assert iters[0].structurally_equal((i - two) - three) + assert iters[1].structurally_equal(three) + + empty_ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[4:4:1, 4:4:7] + ) + + iters = [empty_ispace.actual_iterations(coord) for coord in range(2)] + assert iters[0].structurally_equal(zero) + assert iters[1].structurally_equal(zero) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 01f68c0a3e637e3139990f9208710e9861243e9d..5c2631e1eba1eaa84eb8f7ba6442ec020a191245 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -27,6 +27,7 @@ from pystencils.backend.ast.expressions import ( PsGt, PsLt, PsCall, + PsTernary ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction @@ -365,6 +366,31 @@ def test_invalid_conditions(): with pytest.raises(TypificationError): typify(cond) + +def test_typify_ternary(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + a, b = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab"] + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + + expr = PsTernary(p, x, y) + expr = typify(expr) + assert expr.dtype == Fp(32, const=True) + + expr = PsTernary(PsAnd(p, q), a, b + a) + expr = typify(expr) + assert expr.dtype == Int(32, const=True) + + expr = PsTernary(PsAnd(p, q), a, x) + with pytest.raises(TypificationError): + typify(expr) + + expr = PsTernary(y, a, b) + with pytest.raises(TypificationError): + typify(expr) + def test_cfunction(): ctx = KernelCreationContext() diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 4c83e6e995f0823f81a4627e93d38256f648d28c..8fb44e748ad828b96e9ac46042527a7218828da5 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -55,17 +55,18 @@ def test_printing_integer_functions(): PsBitwiseOr, PsBitwiseXor, PsIntDiv, + PsRem ) expr = PsBitwiseAnd( PsBitwiseXor( PsBitwiseXor(j, k), PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)), - ), + ) + PsRem(i, k), i, ) code = cprint(expr) - assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i" + assert code == "(j ^ k ^ (i << (j >> k) | i / k)) + i % k & i" def test_logical_precedence(): @@ -124,3 +125,32 @@ def test_relations_precedence(): expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z))) code = cprint(expr) assert code == "x != y || !(y > z)" + + +def test_ternary(): + from pystencils.backend.ast.expressions import PsTernary + from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr + + p, q = [PsExpression.make(PsSymbol(x, Bool())) for x in "pq"] + x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"] + cprint = CAstPrinter() + + expr = PsTernary(p, x, y) + code = cprint(expr) + assert code == "p ? x : y" + + expr = PsTernary(PsAnd(p, q), x + y, z) + code = cprint(expr) + assert code == "p && q ? x + y : z" + + expr = PsTernary(p, PsTernary(q, x, y), z) + code = cprint(expr) + assert code == "p ? (q ? x : y) : z" + + expr = PsTernary(p, x, PsTernary(q, y, z)) + code = cprint(expr) + assert code == "p ? x : q ? y : z" + + expr = PsTernary(PsTernary(p, q, PsOr(p, q)), x, y) + code = cprint(expr) + assert code == "(p ? q : p || q) ? x : y" diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py index fae8f158aaa472e02efadbd93365ce042dff0ab1..539fccc7b0ebf4d1515901449bc4d47dc1f50214 100644 --- a/tests/nbackend/transformations/test_branch_elimination.py +++ b/tests/nbackend/transformations/test_branch_elimination.py @@ -1,3 +1,5 @@ +import pytest + from pystencils import make_slice from pystencils.backend.kernelcreation import ( KernelCreationContext, @@ -62,6 +64,8 @@ def test_eliminate_nested_conditional(): def test_isl(): + pytest.importorskip("islpy") + ctx = KernelCreationContext() factory = AstFactory(ctx) typify = Typifier(ctx) diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 48df23ee193e24ff6736d74c22317f04dddc056c..92bb5c947b4bc2e4b6d50064a9c07874cdda43cf 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -10,6 +10,9 @@ from pystencils.backend.ast.expressions import ( PsNot, PsEq, PsGt, + PsTernary, + PsRem, + PsIntDiv ) from pystencils.types.quick import Int, Fp, Bool @@ -26,8 +29,10 @@ f1 = PsExpression.make(PsConstant(1.0, Fp(32))) i0 = PsExpression.make(PsConstant(0, Int(32))) i1 = PsExpression.make(PsConstant(1, Int(32))) +im1 = PsExpression.make(PsConstant(-1, Int(32))) i3 = PsExpression.make(PsConstant(3, Int(32))) +i4 = PsExpression.make(PsConstant(4, Int(32))) im3 = PsExpression.make(PsConstant(-3, Int(32))) i12 = PsExpression.make(PsConstant(12, Int(32))) @@ -86,6 +91,64 @@ def test_zero_dominance(): assert result.structurally_equal(i0) +def test_divisions(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx) + + expr = typify(f3p5 / f1) + result = elim(expr) + assert result.structurally_equal(f3p5) + + expr = typify(i3 / i1) + result = elim(expr) + assert result.structurally_equal(i3) + + expr = typify(PsRem(i3, i1)) + result = elim(expr) + assert result.structurally_equal(i0) + + expr = typify(PsIntDiv(i12, i3)) + result = elim(expr) + assert result.structurally_equal(i4) + + expr = typify(i12 / i3) + result = elim(expr) + assert result.structurally_equal(i4) + + expr = typify(PsIntDiv(i4, i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsIntDiv(-i4, i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + expr = typify(PsIntDiv(i4, -i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + expr = typify(PsIntDiv(-i4, -i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsRem(i4, i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsRem(-i4, i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + expr = typify(PsRem(i4, -i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsRem(-i4, -i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + def test_boolean_folding(): ctx = KernelCreationContext() typify = Typifier(ctx) @@ -128,3 +191,25 @@ def test_relations_folding(): expr = typify(PsGt(x + y, f1 * (x + y))) result = elim(expr) assert result.structurally_equal(false) + + +def test_ternary_folding(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx) + + expr = typify(PsTernary(true, x, y)) + result = elim(expr) + assert result.structurally_equal(x) + + expr = typify(PsTernary(false, x, y)) + result = elim(expr) + assert result.structurally_equal(y) + + expr = typify(PsTernary(PsGt(i1, i0), PsTernary(PsEq(i1, i12), x, y), z)) + result = elim(expr) + assert result.structurally_equal(y) + + expr = typify(PsTernary(PsGt(x, y), x + f0, y * f1)) + result = elim(expr) + assert result.structurally_equal(PsTernary(PsGt(x, y), x, y))