From d9a260ef114bb617fab8eceefb15b8d9f23732a0 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 7 Mar 2024 12:52:33 +0100 Subject: [PATCH] constant elimination, part one: idempotence, dominance, and folding --- src/pystencils/backend/ast/__init__.py | 2 + src/pystencils/backend/ast/expressions.py | 29 +++- src/pystencils/backend/ast/util.py | 30 +++- .../backend/kernelcreation/typification.py | 2 +- .../backend/transformations/__init__.py | 7 +- .../transformations/eliminate_constants.py | 161 ++++++++++++++++++ src/pystencils/kernelcreation.py | 5 +- .../test_constant_elimination.py | 65 +++++++ 8 files changed, 292 insertions(+), 9 deletions(-) create mode 100644 src/pystencils/backend/transformations/eliminate_constants.py create mode 100644 tests/nbackend/transformations/test_constant_elimination.py diff --git a/src/pystencils/backend/ast/__init__.py b/src/pystencils/backend/ast/__init__.py index 3cb4e2940..2ed5c03c2 100644 --- a/src/pystencils/backend/ast/__init__.py +++ b/src/pystencils/backend/ast/__init__.py @@ -1,6 +1,8 @@ +from .astnode import PsAstNode from .iteration import dfs_preorder, dfs_postorder __all__ = [ + "PsAstNode", "dfs_preorder", "dfs_postorder", ] diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index eab309a1d..a2c13548a 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Sequence, overload +from typing import Sequence, overload, Callable, Any +import operator from ..symbols import PsSymbol from ..constants import PsConstant @@ -369,9 +370,15 @@ class PsUnOp(PsExpression): idx = [0][idx] self._operand = failing_cast(PsExpression, c) + @property + def python_operator(self) -> None | Callable[[Any], Any]: + return None + class PsNeg(PsUnOp): - pass + @property + def python_operator(self): + return 
operator.neg class PsDeref(PsUnOp): @@ -450,18 +457,30 @@ class PsBinOp(PsExpression): opname = self.__class__.__name__ return f"{opname}({repr(self._op1)}, {repr(self._op2)})" + @property + def python_operator(self) -> None | Callable[[Any, Any], Any]: + return None + class PsAdd(PsBinOp): - pass + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.add class PsSub(PsBinOp): - pass + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.sub class PsMul(PsBinOp): - pass + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.mul class PsDiv(PsBinOp): + # python_operator not implemented because can't unambiguously decide + # between intdiv and truediv pass diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index c3d93ed4c..a93dc74e2 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -1,7 +1,35 @@ -from typing import Any +from __future__ import annotations +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from .astnode import PsAstNode def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any: if not isinstance(obj, target): raise TypeError(f"Casting {obj} to {target} failed.") return obj + + +class EqWrapper: + """Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`. + + Useful in dictionaries when the goal is to keep track of subtrees according to their + structure, e.g. in elimination of constants or common subexpressions. 
+ """ + + def __init__(self, node: PsAstNode): + self._node = node + + @property + def n(self): + return self._node + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EqWrapper): + return False + + return self._node.structurally_equal(other._node) + + def __hash__(self) -> int: + return hash(self._node) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 48e4e6aa0..ae6cc8b4b 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -178,7 +178,7 @@ class Typifier: case PsArrayAccess(_, idx): tc.apply_and_check(expr, expr.dtype) - + index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index c630e580b..fc261f174 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -1,4 +1,9 @@ +from .eliminate_constants import EliminateConstants from .erase_anonymous_structs import EraseAnonymousStructTypes from .vector_intrinsics import MaterializeVectorIntrinsics -__all__ = ["EraseAnonymousStructTypes", "MaterializeVectorIntrinsics"] +__all__ = [ + "EliminateConstants", + "EraseAnonymousStructTypes", + "MaterializeVectorIntrinsics", +] diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py new file mode 100644 index 000000000..a1bfa332e --- /dev/null +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -0,0 +1,161 @@ +from typing import cast + +from ..kernelcreation.context import KernelCreationContext + +from ..ast import PsAstNode +from ..ast.expressions import ( + PsExpression, + PsConstantExpr, + PsSymbolExpr, + PsBinOp, + PsAdd, + PsSub, + PsMul, + PsDiv, +) + +from ..constants import 
PsConstant +from ...types import PsIntegerType, PsIeeeFloatType + + +__all__ = ["EliminateConstants"] + + +class ECContext: + def __init__(self): + pass + + +class EliminateConstants: + """Eliminate constant expressions in various ways. + + - Constant folding: Nontrivial constant integer (and optionally floating point) expressions + are evaluated and replaced by their result + - Idempotence elimination: Idempotent operations (e.g. addition of zero, multiplication with one) + are replaced by their result + - Dominance elimination: Multiplication by zero is replaced by zero + - Constant extraction: Optionally, nontrivial constant expressions are extracted and listed at the beginning of + the outermost block. + """ + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + + self._fold_integers = True + self._fold_floats = False + self._extract_constant_exprs = True + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self.visit(node) + + def visit(self, node: PsAstNode) -> PsAstNode: + match node: + case PsExpression(): + transformed_expr, _ = self.visit_expr(node) + return transformed_expr + case _: + node.children = [self.visit(c) for c in node.children] + return node + + def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]: + """Transformation of expressions. 
+ + Returns: + (transformed_expr, is_const): The transformed expression, and a flag indicating whether it is constant + """ + # Return constants as they are + if isinstance(expr, PsConstantExpr): + return expr, True + + # Shortcut symbols + if isinstance(expr, PsSymbolExpr): + return expr, False + + subtree_results = [ + self.visit_expr(cast(PsExpression, c)) for c in expr.children + ] + expr.children = [r[0] for r in subtree_results] + subtree_constness = [r[1] for r in subtree_results] + + # Eliminate idempotence and dominance + match expr: + # Additive idempotence: Addition and subtraction of zero + case PsAdd(PsConstantExpr(c), other_op) if c.value == 0: + return other_op, all(subtree_constness) + + case PsAdd(other_op, PsConstantExpr(c)) if c.value == 0: + return other_op, all(subtree_constness) + + case PsSub(other_op, PsConstantExpr(c)) if c.value == 0: + return other_op, all(subtree_constness) + + # Additive idempotence: Subtraction from zero + case PsSub(PsConstantExpr(c), other_op) if c.value == 0: + other_transformed, is_const = self.visit_expr(-other_op) + return other_transformed, is_const + + # Multiplicative idempotence: Multiplication with and division by one + case PsMul(PsConstantExpr(c), other_op) if c.value == 1: + return other_op, all(subtree_constness) + + case PsMul(other_op, PsConstantExpr(c)) if c.value == 1: + return other_op, all(subtree_constness) + + case PsDiv(other_op, PsConstantExpr(c)) if c.value == 1: + return other_op, all(subtree_constness) + + # Multiplicative dominance: 0 * x = 0 + case PsMul(PsConstantExpr(c), other_op) if c.value == 0: + return PsConstantExpr(c), True + + case PsMul(other_op, PsConstantExpr(c)) if c.value == 0: + return PsConstantExpr(c), True + + # end match: no idempotence or dominance encountered + + # Detect constant expressions + if all(subtree_constness): + # Fold binary expressions where possible + if isinstance(expr, PsBinOp): + op1_transformed = expr.operand1 + op2_transformed = expr.operand2 + + if 
isinstance(op1_transformed, PsConstantExpr) and isinstance( + op2_transformed, PsConstantExpr + ): + v1 = op1_transformed.constant.value + v2 = op2_transformed.constant.value + + # assume they are of equal type + dtype = op1_transformed.constant.dtype + + is_int = isinstance(dtype, PsIntegerType) + is_float = isinstance(dtype, PsIeeeFloatType) + + if (self._fold_integers and is_int) or ( + self._fold_floats and is_float + ): + py_operator = expr.python_operator + + folded = None + if py_operator is not None: + folded = PsConstant( + py_operator(v1, v2), + dtype, + ) + elif isinstance(expr, PsDiv): + if isinstance(dtype, PsIntegerType): + folded = PsConstant(v1 // v2, dtype) + elif isinstance(dtype, PsIeeeFloatType): + folded = PsConstant(v1 / v2, dtype) + + if folded is not None: + return PsConstantExpr(folded), True + + expr.operand1 = op1_transformed + expr.operand2 = op2_transformed + return expr, True + # end if: no constant expressions encountered + + # Any other expressions are not considered constant even if their arguments are + return expr, False diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 535586cd8..93b09e393 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -25,7 +25,7 @@ from .backend.kernelcreation.iteration_space import ( ) from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from .backend.transformations import EraseAnonymousStructTypes +from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants from .sympyextensions import AssignmentCollection, Assignment @@ -80,6 +80,9 @@ def create_kernel( raise NotImplementedError("Target platform not implemented") kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) + + # Simplifying transformations + kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast)) kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast)) # 7. 
Apply optimizations diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py new file mode 100644 index 000000000..da89ab044 --- /dev/null +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -0,0 +1,65 @@ +from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr +from pystencils.backend.symbols import PsSymbol +from pystencils.backend.constants import PsConstant +from pystencils.backend.transformations import EliminateConstants + +from pystencils.types.quick import Int, Fp + +x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] + +f3p5 = PsExpression.make(PsConstant(3.5, Fp(32))) +f42 = PsExpression.make(PsConstant(42, Fp(32))) + +f0 = PsExpression.make(PsConstant(0.0, Fp(32))) +f1 = PsExpression.make(PsConstant(1.0, Fp(32))) + +i0 = PsExpression.make(PsConstant(0, Int(32))) +i1 = PsExpression.make(PsConstant(1, Int(32))) + +i3 = PsExpression.make(PsConstant(3, Int(32))) +i12 = PsExpression.make(PsConstant(12, Int(32))) + + +def test_idempotence(): + elim = EliminateConstants() + + expr = f42 * (f1 + f0) - f0 + result = elim(expr) + assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42) + + expr = (x + f0) * f3p5 + (f1 * y + f0) * f42 + result = elim(expr) + assert result.structurally_equal(x * f3p5 + y * f42) + + expr = (f3p5 * f1) + (f42 * f1) + result = elim(expr) + # do not fold floats by default + assert expr.structurally_equal(f3p5 + f42) + + expr = f1 * x + f0 + (f0 + f0 + f1 + f0) * y + result = elim(expr) + assert result.structurally_equal(x + y) + + +def test_int_folding(): + elim = EliminateConstants() + + expr = (i1 * x + i1 * i3) + i1 * i12 + result = elim(expr) + assert result.structurally_equal((x + i3) + i12) + + expr = (i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1) + result = elim(expr) + assert result.structurally_equal(i12) + + +def test_zero_dominance(): + elim = EliminateConstants() + + expr = (f0 * 
x) + (y * f0) + f1 + result = elim(expr) + assert result.structurally_equal(f1) + + expr = (i3 + i12 * (x + y) + x / (i3 * y)) * i0 + result = elim(expr) + assert result.structurally_equal(i0) -- GitLab