diff --git a/mypy.ini b/mypy.ini
index 8e9fe08334a08d8ab0a272f114d4a719d81398ed..07228fe24009da6ea4f21cb6cdf15a0516041149 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -16,3 +16,6 @@ ignore_missing_imports=true
 
 [mypy-appdirs.*]
 ignore_missing_imports=true
+
+[mypy-islpy.*]
+ignore_missing_imports=true
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index eab3d3722c30756ab39af072e75e9d6d89874447..f098d82df1ce6a748097756aa1616a72e57487b5 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -1,16 +1,44 @@
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
+from ..ast.analysis import collect_undefined_symbols
 from ..ast.structural import PsLoop, PsBlock, PsConditional
-from ..ast.expressions import PsConstantExpr
+from ..ast.expressions import (
+    PsAnd,
+    PsCast,
+    PsConstant,
+    PsConstantExpr,
+    PsDiv,
+    PsEq,
+    PsExpression,
+    PsGe,
+    PsGt,
+    PsIntDiv,
+    PsLe,
+    PsLt,
+    PsMul,
+    PsNe,
+    PsNeg,
+    PsNot,
+    PsOr,
+    PsSub,
+    PsSymbolExpr,
+    PsAdd,
+)
 
 from .eliminate_constants import EliminateConstants
+from ...types import PsBoolType, PsIntegerType
 
 __all__ = ["EliminateBranches"]
 
 
+class IslAnalysisError(Exception):
+    """Indicates that integer set analysis (based on islpy) could not be carried out"""
+
+
 class BranchElimContext:
     def __init__(self) -> None:
         self.enclosing_loops: list[PsLoop] = []
+        self.enclosing_conditions: list[PsExpression] = []
 
 
 class EliminateBranches:
@@ -20,12 +48,16 @@ class EliminateBranches:
     This pass will attempt to evaluate branch conditions within their context in the AST, and replace
     conditionals by either their then- or their else-block if the branch is unequivocal.
 
-    TODO: If islpy is installed, this pass will incorporate information about the iteration regions
-    of enclosing loops into its analysis.
+    If islpy is installed, this pass will incorporate information about the iteration regions
+    of enclosing loops and enclosing conditionals into its analysis.
+
+    Args:
+        use_isl (bool, optional): enable islpy based analysis (default: True)
     """
 
-    def __init__(self, ctx: KernelCreationContext) -> None:
+    def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
         self._ctx = ctx
+        self._use_isl = use_isl
         self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
@@ -41,20 +73,30 @@ class EliminateBranches:
             case PsBlock(statements):
                 statements_new: list[PsAstNode] = []
                 for stmt in statements:
-                    if isinstance(stmt, PsConditional):
-                        result = self.handle_conditional(stmt, ec)
-                        if result is not None:
-                            statements_new.append(result)
-                    else:
-                        statements_new.append(self.visit(stmt, ec))
+                    statements_new.append(self.visit(stmt, ec))
                 node.statements = statements_new
 
             case PsConditional():
                 result = self.handle_conditional(node, ec)
-                if result is None:
-                    return PsBlock([])
-                else:
-                    return result
+
+                match result:
+                    case PsConditional(_, branch_true, branch_false):
+                        ec.enclosing_conditions.append(result.condition)
+                        self.visit(branch_true, ec)
+                        ec.enclosing_conditions.pop()
+
+                        if branch_false is not None:
+                            ec.enclosing_conditions.append(PsNot(result.condition))
+                            self.visit(branch_false, ec)
+                            ec.enclosing_conditions.pop()
+                    case PsBlock():
+                        self.visit(result, ec)
+                    case None:
+                        result = PsBlock([])
+                    case _:
+                        raise AssertionError("unreachable code")
+
+                return result
 
         return node
 
@@ -62,12 +104,137 @@ class EliminateBranches:
         self, conditional: PsConditional, ec: BranchElimContext
     ) -> PsConditional | PsBlock | None:
         condition_simplified = self._elim_constants(conditional.condition)
+        if self._use_isl:
+            try:
+                condition_simplified = self._isl_simplify_condition(
+                    condition_simplified, ec
+                )
+            except IslAnalysisError:
+                # The ISL analysis is a best-effort refinement: if the condition
+                # or its context cannot be expressed as an integer set, keep the
+                # constant-folded condition instead of aborting the whole pass.
+                pass
+
         match condition_simplified:
             case PsConstantExpr(c) if c.value:
                 return conditional.branch_true
             case PsConstantExpr(c) if not c.value:
                 return conditional.branch_false
 
-        # TODO: Analyze condition against counters of enclosing loops using ISL
-
         return conditional
+
+    def _isl_simplify_condition(
+        self, condition: PsExpression, ec: BranchElimContext
+    ) -> PsExpression:
+        """If installed, use ISL to simplify the passed condition to true or
+        false based on enclosing loops and conditionals. If no simplification
+        can be made or ISL is not installed, the original condition is returned.
+
+        Raises:
+            IslAnalysisError: If the condition or its enclosing context cannot
+                be translated to an ISL set description
+        """
+
+        try:
+            import islpy as isl
+        except ImportError:
+            return condition
+
+        def printer(expr: PsExpression):
+            """Translate an expression to ISL's textual constraint syntax"""
+            match expr:
+                case PsSymbolExpr(symbol):
+                    return symbol.name
+
+                case PsConstantExpr(constant):
+                    dtype = constant.get_dtype()
+                    if not isinstance(dtype, (PsIntegerType, PsBoolType)):
+                        raise IslAnalysisError(
+                            "Only scalar integer and bool constant may appear in isl expressions."
+                        )
+                    return str(constant.value)
+
+                case PsAdd(op1, op2):
+                    return f"({printer(op1)} + {printer(op2)})"
+                case PsSub(op1, op2):
+                    return f"({printer(op1)} - {printer(op2)})"
+                case PsMul(op1, op2):
+                    return f"({printer(op1)} * {printer(op2)})"
+                case PsDiv(op1, op2) | PsIntDiv(op1, op2):
+                    return f"({printer(op1)} / {printer(op2)})"
+                case PsAnd(op1, op2):
+                    return f"({printer(op1)} and {printer(op2)})"
+                case PsOr(op1, op2):
+                    return f"({printer(op1)} or {printer(op2)})"
+                case PsEq(op1, op2):
+                    return f"({printer(op1)} = {printer(op2)})"
+                case PsNe(op1, op2):
+                    return f"({printer(op1)} != {printer(op2)})"
+                case PsGt(op1, op2):
+                    return f"({printer(op1)} > {printer(op2)})"
+                case PsGe(op1, op2):
+                    return f"({printer(op1)} >= {printer(op2)})"
+                case PsLt(op1, op2):
+                    return f"({printer(op1)} < {printer(op2)})"
+                case PsLe(op1, op2):
+                    return f"({printer(op1)} <= {printer(op2)})"
+
+                case PsNeg(operand):
+                    return f"(-{printer(operand)})"
+                case PsNot(operand):
+                    return f"(not {printer(operand)})"
+
+                case PsCast(_, operand):
+                    return printer(operand)
+
+                case _:
+                    raise IslAnalysisError(
+                        f"Not supported by isl or don't know how to print {expr}"
+                    )
+
+        dofs = collect_undefined_symbols(condition)
+        outer_conditions = []
+
+        for loop in ec.enclosing_loops:
+            if not (
+                isinstance(loop.step, PsConstantExpr)
+                and loop.step.constant.value == 1
+            ):
+                raise IslAnalysisError(
+                    "Loops with strides != 1 are not yet supported."
+                )
+
+            dofs.add(loop.counter.symbol)
+            dofs.update(collect_undefined_symbols(loop.start))
+            dofs.update(collect_undefined_symbols(loop.stop))
+
+            loop_start_str = printer(loop.start)
+            loop_stop_str = printer(loop.stop)
+            ctr_name = loop.counter.symbol.name
+            outer_conditions.append(
+                f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
+            )
+
+        for cond in ec.enclosing_conditions:
+            dofs.update(collect_undefined_symbols(cond))
+            outer_conditions.append(printer(cond))
+
+        dofs_str = ",".join(dof.name for dof in dofs)
+        outer_conditions_str = " and ".join(outer_conditions)
+        condition_str = printer(condition)
+
+        # isl.Set instead of isl.BasicSet: `or`, `!=` and `not` may yield
+        # disjunctions, which a (convex) basic set cannot represent.
+        outer_set = isl.Set(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
+        inner_set = isl.Set(f"{{ [{dofs_str}] : {condition_str} }}")
+
+        if inner_set.is_empty():
+            return PsExpression.make(PsConstant(False))
+
+        intersection = outer_set.intersect(inner_set)
+        if intersection.is_empty():
+            return PsExpression.make(PsConstant(False))
+        elif intersection.is_equal(outer_set):
+            return PsExpression.make(PsConstant(True))
+        else:
+            return condition
diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py
index 0fb3526d0b53fd40972c4dfeb06cf3a614bc6c10..fae8f158aaa472e02efadbd93365ce042dff0ab1 100644
--- a/tests/nbackend/transformations/test_branch_elimination.py
+++ b/tests/nbackend/transformations/test_branch_elimination.py
@@ -4,12 +4,18 @@ from pystencils.backend.kernelcreation import (
     Typifier,
     AstFactory,
 )
-from pystencils.backend.ast.expressions import PsExpression
+from pystencils.backend.ast.expressions import (
+    PsExpression,
+    PsEq,
+    PsGe,
+    PsGt,
+    PsLe,
+    PsLt,
+)
 from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.transformations import EliminateBranches
 from pystencils.types.quick import Int
-from pystencils.backend.ast.expressions import PsGt
 
 i0 = PsExpression.make(PsConstant(0, Int(32)))
 
@@ -53,3 +59,44 @@ def test_eliminate_nested_conditional():
     result = elim(ast)
 
     assert result.body.statements[0].body.statements[0] == b1
+
+
+def test_isl():
+    import pytest
+
+    # Without islpy the pass leaves these branches untouched, so skip the test
+    pytest.importorskip("islpy")
+
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    typify = Typifier(ctx)
+    elim = EliminateBranches(ctx)
+
+    i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
+    j = PsExpression.make(ctx.get_symbol("j", ctx.index_dtype))
+
+    const_2 = PsExpression.make(PsConstant(2, ctx.index_dtype))
+    const_4 = PsExpression.make(PsConstant(4, ctx.index_dtype))
+
+    a_true = PsBlock([PsComment("a true")])
+    a_false = PsBlock([PsComment("a false")])
+    b_true = PsBlock([PsComment("b true")])
+    b_false = PsBlock([PsComment("b false")])
+    c_true = PsBlock([PsComment("c true")])
+    c_false = PsBlock([PsComment("c false")])
+
+    a = PsConditional(PsLt(i + j, const_2 * const_4), a_true, a_false)
+    b = PsConditional(PsGe(j, const_4), b_true, b_false)
+    c = PsConditional(PsEq(i, const_4), c_true, c_false)
+
+    outer_loop = factory.loop(j.symbol.name, slice(0, 3), PsBlock([a, b, c]))
+    outer_cond = typify(
+        PsConditional(PsLe(i, const_4), PsBlock([outer_loop]), PsBlock([]))
+    )
+    ast = outer_cond
+
+    result = elim(ast)
+
+    assert result.branch_true.statements[0].body.statements[0] == a_true
+    assert result.branch_true.statements[0].body.statements[1] == b_false
+    assert result.branch_true.statements[0].body.statements[2] == c