Skip to content
Snippets Groups Projects
Commit 205ef895 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'bauerd/isl' into 'backend-rework'

Eliminate branches: implement isl analysis and recurse into conditionals

See merge request !390
parents 02965644 aebb23ee
1 merge request!390Eliminate branches: implement isl analysis and recurse into conditionals
Pipeline #66734 passed with stages
in 1 minute and 24 seconds
...@@ -16,3 +16,6 @@ ignore_missing_imports=true ...@@ -16,3 +16,6 @@ ignore_missing_imports=true
[mypy-appdirs.*] [mypy-appdirs.*]
ignore_missing_imports=true ignore_missing_imports=true
[mypy-islpy.*]
ignore_missing_imports=true
from ..kernelcreation import KernelCreationContext from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.analysis import collect_undefined_symbols
from ..ast.structural import PsLoop, PsBlock, PsConditional from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import PsConstantExpr from ..ast.expressions import (
PsAnd,
PsCast,
PsConstant,
PsConstantExpr,
PsDiv,
PsEq,
PsExpression,
PsGe,
PsGt,
PsIntDiv,
PsLe,
PsLt,
PsMul,
PsNe,
PsNeg,
PsNot,
PsOr,
PsSub,
PsSymbolExpr,
PsAdd,
)
from .eliminate_constants import EliminateConstants from .eliminate_constants import EliminateConstants
from ...types import PsBoolType, PsIntegerType
__all__ = ["EliminateBranches"] __all__ = ["EliminateBranches"]
class IslAnalysisError(Exception):
"""Indicates a fatal error during integer set analysis (based on islpy)"""
class BranchElimContext: class BranchElimContext:
def __init__(self) -> None: def __init__(self) -> None:
self.enclosing_loops: list[PsLoop] = [] self.enclosing_loops: list[PsLoop] = []
self.enclosing_conditions: list[PsExpression] = []
class EliminateBranches: class EliminateBranches:
...@@ -20,12 +48,16 @@ class EliminateBranches: ...@@ -20,12 +48,16 @@ class EliminateBranches:
This pass will attempt to evaluate branch conditions within their context in the AST, and replace This pass will attempt to evaluate branch conditions within their context in the AST, and replace
conditionals by either their then- or their else-block if the branch is unequivocal. conditionals by either their then- or their else-block if the branch is unequivocal.
TODO: If islpy is installed, this pass will incorporate information about the iteration regions If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops into its analysis. of enclosing loops and enclosing conditionals into its analysis.
Args:
use_isl (bool, optional): enable islpy based analysis (default: True)
""" """
def __init__(self, ctx: KernelCreationContext) -> None: def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
self._ctx = ctx self._ctx = ctx
self._use_isl = use_isl
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
...@@ -41,20 +73,30 @@ class EliminateBranches: ...@@ -41,20 +73,30 @@ class EliminateBranches:
case PsBlock(statements): case PsBlock(statements):
statements_new: list[PsAstNode] = [] statements_new: list[PsAstNode] = []
for stmt in statements: for stmt in statements:
if isinstance(stmt, PsConditional): statements_new.append(self.visit(stmt, ec))
result = self.handle_conditional(stmt, ec)
if result is not None:
statements_new.append(result)
else:
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new node.statements = statements_new
case PsConditional(): case PsConditional():
result = self.handle_conditional(node, ec) result = self.handle_conditional(node, ec)
if result is None:
return PsBlock([]) match result:
else: case PsConditional(_, branch_true, branch_false):
return result ec.enclosing_conditions.append(result.condition)
self.visit(branch_true, ec)
ec.enclosing_conditions.pop()
if branch_false is not None:
ec.enclosing_conditions.append(PsNot(result.condition))
self.visit(branch_false, ec)
ec.enclosing_conditions.pop()
case PsBlock():
self.visit(result, ec)
case None:
result = PsBlock([])
case _:
assert False, "unreachable code"
return result
return node return node
...@@ -62,12 +104,124 @@ class EliminateBranches: ...@@ -62,12 +104,124 @@ class EliminateBranches:
self, conditional: PsConditional, ec: BranchElimContext self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None: ) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition) condition_simplified = self._elim_constants(conditional.condition)
if self._use_isl:
condition_simplified = self._isl_simplify_condition(
condition_simplified, ec
)
match condition_simplified: match condition_simplified:
case PsConstantExpr(c) if c.value: case PsConstantExpr(c) if c.value:
return conditional.branch_true return conditional.branch_true
case PsConstantExpr(c) if not c.value: case PsConstantExpr(c) if not c.value:
return conditional.branch_false return conditional.branch_false
# TODO: Analyze condition against counters of enclosing loops using ISL
return conditional return conditional
def _isl_simplify_condition(
self, condition: PsExpression, ec: BranchElimContext
) -> PsExpression:
"""If installed, use ISL to simplify the passed condition to true or
false based on enclosing loops and conditionals. If no simplification
can be made or ISL is not installed, the original condition is returned.
"""
try:
import islpy as isl
except ImportError:
return condition
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
"Only scalar integer and bool constant may appear in isl expressions."
)
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
case _:
raise IslAnalysisError(
f"Not supported by isl or don't know how to print {expr}"
)
dofs = collect_undefined_symbols(condition)
outer_conditions = []
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
)
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition
...@@ -4,12 +4,18 @@ from pystencils.backend.kernelcreation import ( ...@@ -4,12 +4,18 @@ from pystencils.backend.kernelcreation import (
Typifier, Typifier,
AstFactory, AstFactory,
) )
from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.expressions import (
PsExpression,
PsEq,
PsGe,
PsGt,
PsLe,
PsLt,
)
from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateBranches from pystencils.backend.transformations import EliminateBranches
from pystencils.types.quick import Int from pystencils.types.quick import Int
from pystencils.backend.ast.expressions import PsGt
i0 = PsExpression.make(PsConstant(0, Int(32))) i0 = PsExpression.make(PsConstant(0, Int(32)))
...@@ -53,3 +59,39 @@ def test_eliminate_nested_conditional(): ...@@ -53,3 +59,39 @@ def test_eliminate_nested_conditional():
result = elim(ast) result = elim(ast)
assert result.body.statements[0].body.statements[0] == b1 assert result.body.statements[0].body.statements[0] == b1
def test_isl():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
j = PsExpression.make(ctx.get_symbol("j", ctx.index_dtype))
const_2 = PsExpression.make(PsConstant(2, ctx.index_dtype))
const_4 = PsExpression.make(PsConstant(4, ctx.index_dtype))
a_true = PsBlock([PsComment("a true")])
a_false = PsBlock([PsComment("a false")])
b_true = PsBlock([PsComment("b true")])
b_false = PsBlock([PsComment("b false")])
c_true = PsBlock([PsComment("c true")])
c_false = PsBlock([PsComment("c false")])
a = PsConditional(PsLt(i + j, const_2 * const_4), a_true, a_false)
b = PsConditional(PsGe(j, const_4), b_true, b_false)
c = PsConditional(PsEq(i, const_4), c_true, c_false)
outer_loop = factory.loop(j.symbol.name, slice(0, 3), PsBlock([a, b, c]))
outer_cond = typify(
PsConditional(PsLe(i, const_4), PsBlock([outer_loop]), PsBlock([]))
)
ast = outer_cond
result = elim(ast)
assert result.branch_true.statements[0].body.statements[0] == a_true
assert result.branch_true.statements[0].body.statements[1] == b_false
assert result.branch_true.statements[0].body.statements[2] == c
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment