Skip to content
Snippets Groups Projects
Commit d9a260ef authored by Frederik Hennig's avatar Frederik Hennig
Browse files

constant elimination, part one: idempotence, dominance, and folding

parent 645b742b
Branches
Tags
No related merge requests found
Pipeline #63818 failed with stages
in 3 minutes and 12 seconds
from .astnode import PsAstNode
from .iteration import dfs_preorder, dfs_postorder
__all__ = [
"PsAstNode",
"dfs_preorder",
"dfs_postorder",
]
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence, overload
from typing import Sequence, overload, Callable, Any
import operator
from ..symbols import PsSymbol
from ..constants import PsConstant
......@@ -369,9 +370,15 @@ class PsUnOp(PsExpression):
idx = [0][idx]
self._operand = failing_cast(PsExpression, c)
@property
def python_operator(self) -> None | Callable[[Any], Any]:
return None
class PsNeg(PsUnOp):
    """Unary operation whose Python equivalent is `operator.neg`."""

    @property
    def python_operator(self) -> Callable[[Any], Any] | None:
        """Python callable evaluating this operation: `operator.neg`."""
        return operator.neg
class PsDeref(PsUnOp):
......@@ -450,18 +457,30 @@ class PsBinOp(PsExpression):
opname = self.__class__.__name__
return f"{opname}({repr(self._op1)}, {repr(self._op2)})"
@property
def python_operator(self) -> None | Callable[[Any, Any], Any]:
return None
class PsAdd(PsBinOp):
    """Binary operation whose Python equivalent is `operator.add`."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable evaluating this operation: `operator.add`."""
        return operator.add
class PsSub(PsBinOp):
    """Binary operation whose Python equivalent is `operator.sub`."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable evaluating this operation: `operator.sub`."""
        return operator.sub
class PsMul(PsBinOp):
    """Binary operation whose Python equivalent is `operator.mul`."""

    @property
    def python_operator(self) -> Callable[[Any, Any], Any] | None:
        """Python callable evaluating this operation: `operator.mul`."""
        return operator.mul
class PsDiv(PsBinOp):
    #  `python_operator` is deliberately not overridden here: from the node alone
    #  it cannot be decided unambiguously whether integer division or true
    #  division is meant, so the inherited default (`None`) stays in effect.
    pass
from typing import Any
from __future__ import annotations
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from .astnode import PsAstNode
def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
class EqWrapper:
    """Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`.

    Useful in dictionaries when the goal is to keep track of subtrees according to their
    structure, e.g. in elimination of constants or common subexpressions.

    Equality and hashing must agree for dictionary and set lookups to work:
    wrappers that compare equal have to produce the same hash. The hash is
    therefore derived from the node's `repr` instead of from the node object.
    """

    def __init__(self, node: "PsAstNode"):
        self._node = node

    @property
    def n(self):
        """The wrapped AST node."""
        return self._node

    def __eq__(self, other: object) -> bool:
        #   Only wrappers compare equal to wrappers; bare nodes never match
        if not isinstance(other, EqWrapper):
            return False

        return self._node.structurally_equal(other._node)

    def __hash__(self) -> int:
        #   NOTE(review): assumes that structurally equal nodes have equal `repr`;
        #   confirm that every AST node class implements `__repr__` structurally.
        #   TODO: consider a dedicated structural hash if this becomes a bottleneck
        return hash(repr(self._node))
......@@ -178,7 +178,7 @@ class Typifier:
case PsArrayAccess(_, idx):
tc.apply_and_check(expr, expr.dtype)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
......
from .eliminate_constants import EliminateConstants
from .erase_anonymous_structs import EraseAnonymousStructTypes
from .vector_intrinsics import MaterializeVectorIntrinsics
__all__ = ["EraseAnonymousStructTypes", "MaterializeVectorIntrinsics"]
__all__ = [
"EliminateConstants",
"EraseAnonymousStructTypes",
"MaterializeVectorIntrinsics",
]
from typing import cast
from ..kernelcreation.context import KernelCreationContext
from ..ast import PsAstNode
from ..ast.expressions import (
PsExpression,
PsConstantExpr,
PsSymbolExpr,
PsBinOp,
PsAdd,
PsSub,
PsMul,
PsDiv,
)
from ..constants import PsConstant
from ...types import PsIntegerType, PsIeeeFloatType
__all__ = ["EliminateConstants"]
class ECContext:
    """Context object that currently carries no state.

    NOTE(review): presumably a placeholder for state of the constant
    elimination pass; nothing in the visible code uses it yet.
    """

    def __init__(self):
        """Create an empty context."""
class EliminateConstants:
"""Eliminate constant expressions in various ways.
- Constant folding: Nontrivial constant integer (and optionally floating point) expressions
are evaluated and replaced by their result
- Idempotence elimination: Idempotent operations (e.g. addition of zero, multiplication with one)
are replaced by their result
- Dominance elimination: Multiplication by zero is replaced by zero
- Constant extraction: Optionally, nontrivial constant expressions are extracted and listed at the beginning of
the outermost block.
"""
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
self._fold_integers = True
self._fold_floats = False
self._extract_constant_exprs = True
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node)
def visit(self, node: PsAstNode) -> PsAstNode:
match node:
case PsExpression():
transformed_expr, _ = self.visit_expr(node)
return transformed_expr
case _:
node.children = [self.visit(c) for c in node.children]
return node
def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]:
"""Transformation of expressions.
Returns:
(transformed_expr, is_const): The tranformed expression, and a flag indicating whether it is constant
"""
# Return constants as they are
if isinstance(expr, PsConstantExpr):
return expr, True
# Shortcut symbols
if isinstance(expr, PsSymbolExpr):
return expr, False
subtree_results = [
self.visit_expr(cast(PsExpression, c)) for c in expr.children
]
expr.children = [r[0] for r in subtree_results]
subtree_constness = [r[1] for r in subtree_results]
# Eliminate idempotence and dominance
match expr:
# Additive idempotence: Addition and subtraction of zero
case PsAdd(PsConstantExpr(c), other_op) if c.value == 0:
return other_op, all(subtree_constness)
case PsAdd(other_op, PsConstantExpr(c)) if c.value == 0:
return other_op, all(subtree_constness)
case PsSub(other_op, PsConstantExpr(c)) if c.value == 0:
return other_op, all(subtree_constness)
# Additive idempotence: Subtraction from zero
case PsSub(PsConstantExpr(c), other_op) if c.value == 0:
other_transformed, is_const = self.visit_expr(-other_op)
return other_transformed, is_const
# Multiplicative idempotence: Multiplication with and division by one
case PsMul(PsConstantExpr(c), other_op) if c.value == 1:
return other_op, all(subtree_constness)
case PsMul(other_op, PsConstantExpr(c)) if c.value == 1:
return other_op, all(subtree_constness)
case PsDiv(other_op, PsConstantExpr(c)) if c.value == 1:
return other_op, all(subtree_constness)
# Multiplicative dominance: 0 * x = 0
case PsMul(PsConstantExpr(c), other_op) if c.value == 0:
return PsConstantExpr(c), True
case PsMul(other_op, PsConstantExpr(c)) if c.value == 0:
return PsConstantExpr(c), True
# end match: no idempotence or dominance encountered
# Detect constant expressions
if all(subtree_constness):
# Fold binary expressions where possible
if isinstance(expr, PsBinOp):
op1_transformed = expr.operand1
op2_transformed = expr.operand2
if isinstance(op1_transformed, PsConstantExpr) and isinstance(
op2_transformed, PsConstantExpr
):
v1 = op1_transformed.constant.value
v2 = op2_transformed.constant.value
# assume they are of equal type
dtype = op1_transformed.constant.dtype
is_int = isinstance(dtype, PsIntegerType)
is_float = isinstance(dtype, PsIeeeFloatType)
if (self._fold_integers and is_int) or (
self._fold_floats and is_float
):
py_operator = expr.python_operator
folded = None
if py_operator is not None:
folded = PsConstant(
py_operator(v1, v2),
dtype,
)
elif isinstance(expr, PsDiv):
if isinstance(dtype, PsIntegerType):
folded = PsConstant(v1 // v2, dtype)
elif isinstance(dtype, PsIeeeFloatType):
folded = PsConstant(v1 / v2, dtype)
if folded is not None:
return PsConstantExpr(folded), True
expr.operand1 = op1_transformed
expr.operand2 = op2_transformed
return expr, True
# end if: no constant expressions encountered
# Any other expressions are not considered constant even if their arguments are
return expr, False
......@@ -25,7 +25,7 @@ from .backend.kernelcreation.iteration_space import (
)
from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols
from .backend.transformations import EraseAnonymousStructTypes
from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants
from .sympyextensions import AssignmentCollection, Assignment
......@@ -80,6 +80,9 @@ def create_kernel(
raise NotImplementedError("Target platform not implemented")
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
# Simplifying transformations
kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast))
kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
# 7. Apply optimizations
......
from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr
from pystencils.backend.symbols import PsSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateConstants
from pystencils.types.quick import Int, Fp
#   Shared expression fixtures for the tests below.

#   Symbol expressions named "x", "y" and "z"
x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"]

#   Constant expressions created with type `Fp(32)`
f3p5 = PsExpression.make(PsConstant(3.5, Fp(32)))
f42 = PsExpression.make(PsConstant(42, Fp(32)))
f0 = PsExpression.make(PsConstant(0.0, Fp(32)))
f1 = PsExpression.make(PsConstant(1.0, Fp(32)))

#   Constant expressions created with type `Int(32)`
i0 = PsExpression.make(PsConstant(0, Int(32)))
i1 = PsExpression.make(PsConstant(1, Int(32)))
i3 = PsExpression.make(PsConstant(3, Int(32)))
i12 = PsExpression.make(PsConstant(12, Int(32)))
def test_idempotence():
    """Additions of zero and multiplications with one are removed from expressions."""
    #   NOTE(review): constructed without a kernel creation context; confirm that
    #   the constructor accepts being called without arguments.
    elim = EliminateConstants()

    expr = f42 * (f1 + f0) - f0
    result = elim(expr)
    assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42)

    expr = (x + f0) * f3p5 + (f1 * y + f0) * f42
    result = elim(expr)
    assert result.structurally_equal(x * f3p5 + y * f42)

    expr = (f3p5 * f1) + (f42 * f1)
    result = elim(expr)
    #   do not fold floats by default
    #   Check the returned expression, not the input: the transformation may
    #   return a different root node than the one passed in
    assert result.structurally_equal(f3p5 + f42)

    expr = f1 * x + f0 + (f0 + f0 + f1 + f0) * y
    result = elim(expr)
    assert result.structurally_equal(x + y)
def test_int_folding():
    """Constant integer subexpressions are evaluated; symbolic parts are kept."""
    eliminate = EliminateConstants()

    #   Multiplications with one vanish; `x + 3` stays symbolic and is not merged with 12
    simplified = eliminate((i1 * x + i1 * i3) + i1 * i12)
    assert simplified.structurally_equal((x + i3) + i12)

    #   A purely constant integer expression collapses to a single constant: 4 * 3 = 12
    simplified = eliminate((i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1))
    assert simplified.structurally_equal(i12)
def test_zero_dominance():
    """Multiplication by zero replaces the whole product by the zero constant."""
    eliminate = EliminateConstants()

    #   Both products vanish, and adding the remaining zeros to one leaves one
    simplified = eliminate((f0 * x) + (y * f0) + f1)
    assert simplified.structurally_equal(f1)

    #   A zero factor dominates an arbitrarily complex other operand
    simplified = eliminate((i3 + i12 * (x + y) + x / (i3 * y)) * i0)
    assert simplified.structurally_equal(i0)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment