From d9a260ef114bb617fab8eceefb15b8d9f23732a0 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 7 Mar 2024 12:52:33 +0100
Subject: [PATCH] constant elimination, part one: idempotence, dominance, and
 folding

---
 src/pystencils/backend/ast/__init__.py        |   2 +
 src/pystencils/backend/ast/expressions.py     |  29 +++-
 src/pystencils/backend/ast/util.py            |  30 +++-
 .../backend/kernelcreation/typification.py    |   2 +-
 .../backend/transformations/__init__.py       |   7 +-
 .../transformations/eliminate_constants.py    | 161 ++++++++++++++++++
 src/pystencils/kernelcreation.py              |   5 +-
 .../test_constant_elimination.py              |  65 +++++++
 8 files changed, 292 insertions(+), 9 deletions(-)
 create mode 100644 src/pystencils/backend/transformations/eliminate_constants.py
 create mode 100644 tests/nbackend/transformations/test_constant_elimination.py

diff --git a/src/pystencils/backend/ast/__init__.py b/src/pystencils/backend/ast/__init__.py
index 3cb4e2940..2ed5c03c2 100644
--- a/src/pystencils/backend/ast/__init__.py
+++ b/src/pystencils/backend/ast/__init__.py
@@ -1,6 +1,8 @@
+from .astnode import PsAstNode
 from .iteration import dfs_preorder, dfs_postorder
 
 __all__ = [
+    "PsAstNode",
     "dfs_preorder",
     "dfs_postorder",
 ]
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index eab309a1d..a2c13548a 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import Sequence, overload
+from typing import Sequence, overload, Callable, Any
+import operator
 
 from ..symbols import PsSymbol
 from ..constants import PsConstant
@@ -369,9 +370,15 @@ class PsUnOp(PsExpression):
         idx = [0][idx]
         self._operand = failing_cast(PsExpression, c)
 
+    @property
+    def python_operator(self) -> None | Callable[[Any], Any]:
+        return None
+
 
 class PsNeg(PsUnOp):
-    pass
+    @property
+    def python_operator(self):
+        return operator.neg
 
 
 class PsDeref(PsUnOp):
@@ -450,18 +457,30 @@ class PsBinOp(PsExpression):
         opname = self.__class__.__name__
         return f"{opname}({repr(self._op1)}, {repr(self._op2)})"
 
+    @property
+    def python_operator(self) -> None | Callable[[Any, Any], Any]:
+        return None
+
 
 class PsAdd(PsBinOp):
-    pass
+    @property
+    def python_operator(self) -> Callable[[Any, Any], Any] | None:
+        return operator.add
 
 
 class PsSub(PsBinOp):
-    pass
+    @property
+    def python_operator(self) -> Callable[[Any, Any], Any] | None:
+        return operator.sub
 
 
 class PsMul(PsBinOp):
-    pass
+    @property
+    def python_operator(self) -> Callable[[Any, Any], Any] | None:
+        return operator.mul
 
 
 class PsDiv(PsBinOp):
+    #  python_operator not implemented because can't unambigously decide
+    #  between intdiv and truediv
     pass
diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py
index c3d93ed4c..a93dc74e2 100644
--- a/src/pystencils/backend/ast/util.py
+++ b/src/pystencils/backend/ast/util.py
@@ -1,7 +1,35 @@
-from typing import Any
+from __future__ import annotations
+from typing import Any, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .astnode import PsAstNode
 
 
 def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
     if not isinstance(obj, target):
         raise TypeError(f"Casting {obj} to {target} failed.")
     return obj
+
+
+class EqWrapper:
+    """Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`.
+
+    Useful in dictionaries when the goal is to keep track of subtrees according to their
+    structure, e.g. in elimination of constants or common subexpressions.
+    """
+
+    def __init__(self, node: PsAstNode):
+        self._node = node
+
+    @property
+    def n(self):
+        return self._node
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, EqWrapper):
+            return False
+
+        return self._node.structurally_equal(other._node)
+
+    def __hash__(self) -> int:
+        return hash(self._node)
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 48e4e6aa0..ae6cc8b4b 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -178,7 +178,7 @@ class Typifier:
 
             case PsArrayAccess(_, idx):
                 tc.apply_and_check(expr, expr.dtype)
-                
+
                 index_tc = TypeContext()
                 self.visit_expr(idx, index_tc)
                 if index_tc.target_type is None:
diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py
index c630e580b..fc261f174 100644
--- a/src/pystencils/backend/transformations/__init__.py
+++ b/src/pystencils/backend/transformations/__init__.py
@@ -1,4 +1,9 @@
+from .eliminate_constants import EliminateConstants
 from .erase_anonymous_structs import EraseAnonymousStructTypes
 from .vector_intrinsics import MaterializeVectorIntrinsics
 
-__all__ = ["EraseAnonymousStructTypes", "MaterializeVectorIntrinsics"]
+__all__ = [
+    "EliminateConstants",
+    "EraseAnonymousStructTypes",
+    "MaterializeVectorIntrinsics",
+]
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
new file mode 100644
index 000000000..a1bfa332e
--- /dev/null
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -0,0 +1,161 @@
+from typing import cast
+
+from ..kernelcreation.context import KernelCreationContext
+
+from ..ast import PsAstNode
+from ..ast.expressions import (
+    PsExpression,
+    PsConstantExpr,
+    PsSymbolExpr,
+    PsBinOp,
+    PsAdd,
+    PsSub,
+    PsMul,
+    PsDiv,
+)
+
+from ..constants import PsConstant
+from ...types import PsIntegerType, PsIeeeFloatType
+
+
+__all__ = ["EliminateConstants"]
+
+
+class ECContext:
+    def __init__(self):
+        pass
+
+
+class EliminateConstants:
+    """Eliminate constant expressions in various ways.
+
+    - Constant folding: Nontrivial constant integer (and optionally floating point) expressions
+      are evaluated and replaced by their result
+    - Idempotence elimination: Idempotent operations (e.g. addition of zero, multiplication with one)
+      are replaced by their result
+    - Dominance elimination: Multiplication by zero is replaced by zero
+    - Constant extraction: Optionally, nontrivial constant expressions are extracted and listed at the beginning of
+      the outermost block.
+    """
+
+    def __init__(self, ctx: KernelCreationContext):
+        self._ctx = ctx
+        
+        self._fold_integers = True
+        self._fold_floats = False
+        self._extract_constant_exprs = True
+
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        return self.visit(node)
+
+    def visit(self, node: PsAstNode) -> PsAstNode:
+        match node:
+            case PsExpression():
+                transformed_expr, _ = self.visit_expr(node)
+                return transformed_expr
+            case _:
+                node.children = [self.visit(c) for c in node.children]
+                return node
+
+    def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]:
+        """Transformation of expressions.
+
+        Returns:
+            (transformed_expr, is_const): The tranformed expression, and a flag indicating whether it is constant
+        """
+        #   Return constants as they are
+        if isinstance(expr, PsConstantExpr):
+            return expr, True
+        
+        #   Shortcut symbols
+        if isinstance(expr, PsSymbolExpr):
+            return expr, False
+
+        subtree_results = [
+            self.visit_expr(cast(PsExpression, c)) for c in expr.children
+        ]
+        expr.children = [r[0] for r in subtree_results]
+        subtree_constness = [r[1] for r in subtree_results]
+
+        #   Eliminate idempotence and dominance
+        match expr:
+            #   Additive idempotence: Addition and subtraction of zero
+            case PsAdd(PsConstantExpr(c), other_op) if c.value == 0:
+                return other_op, all(subtree_constness)
+
+            case PsAdd(other_op, PsConstantExpr(c)) if c.value == 0:
+                return other_op, all(subtree_constness)
+
+            case PsSub(other_op, PsConstantExpr(c)) if c.value == 0:
+                return other_op, all(subtree_constness)
+
+            #   Additive idempotence: Subtraction from zero
+            case PsSub(PsConstantExpr(c), other_op) if c.value == 0:
+                other_transformed, is_const = self.visit_expr(-other_op)
+                return other_transformed, is_const
+
+            #   Multiplicative idempotence: Multiplication with and division by one
+            case PsMul(PsConstantExpr(c), other_op) if c.value == 1:
+                return other_op, all(subtree_constness)
+
+            case PsMul(other_op, PsConstantExpr(c)) if c.value == 1:
+                return other_op, all(subtree_constness)
+
+            case PsDiv(other_op, PsConstantExpr(c)) if c.value == 1:
+                return other_op, all(subtree_constness)
+
+            #   Multiplicative dominance: 0 * x = 0
+            case PsMul(PsConstantExpr(c), other_op) if c.value == 0:
+                return PsConstantExpr(c), True
+
+            case PsMul(other_op, PsConstantExpr(c)) if c.value == 0:
+                return PsConstantExpr(c), True
+
+        # end match: no idempotence or dominance encountered
+
+        #   Detect constant expressions
+        if all(subtree_constness):
+            #   Fold binary expressions where possible
+            if isinstance(expr, PsBinOp):
+                op1_transformed = expr.operand1
+                op2_transformed = expr.operand2
+
+                if isinstance(op1_transformed, PsConstantExpr) and isinstance(
+                    op2_transformed, PsConstantExpr
+                ):
+                    v1 = op1_transformed.constant.value
+                    v2 = op2_transformed.constant.value
+
+                    # assume they are of equal type
+                    dtype = op1_transformed.constant.dtype
+
+                    is_int = isinstance(dtype, PsIntegerType)
+                    is_float = isinstance(dtype, PsIeeeFloatType)
+
+                    if (self._fold_integers and is_int) or (
+                        self._fold_floats and is_float
+                    ):
+                        py_operator = expr.python_operator
+
+                        folded = None
+                        if py_operator is not None:
+                            folded = PsConstant(
+                                py_operator(v1, v2),
+                                dtype,
+                            )
+                        elif isinstance(expr, PsDiv):
+                            if isinstance(dtype, PsIntegerType):
+                                folded = PsConstant(v1 // v2, dtype)
+                            elif isinstance(dtype, PsIeeeFloatType):
+                                folded = PsConstant(v1 / v2, dtype)
+
+                        if folded is not None:
+                            return PsConstantExpr(folded), True
+
+                expr.operand1 = op1_transformed
+                expr.operand2 = op2_transformed
+                return expr, True
+        # end if: no constant expressions encountered
+
+        #   Any other expressions are not considered constant even if their arguments are
+        return expr, False
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 535586cd8..93b09e393 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -25,7 +25,7 @@ from .backend.kernelcreation.iteration_space import (
 )
 
 from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols
-from .backend.transformations import EraseAnonymousStructTypes
+from .backend.transformations import EraseAnonymousStructTypes, EliminateConstants
 
 from .sympyextensions import AssignmentCollection, Assignment
 
@@ -80,6 +80,9 @@ def create_kernel(
             raise NotImplementedError("Target platform not implemented")
 
     kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
+
+    #   Simplifying transformations
+    kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast))
     kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
 
     #   7. Apply optimizations
diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py
new file mode 100644
index 000000000..da89ab044
--- /dev/null
+++ b/tests/nbackend/transformations/test_constant_elimination.py
@@ -0,0 +1,65 @@
+from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr
+from pystencils.backend.symbols import PsSymbol
+from pystencils.backend.constants import PsConstant
+from pystencils.backend.transformations import EliminateConstants
+
+from pystencils.types.quick import Int, Fp
+
+x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"]
+
+f3p5 = PsExpression.make(PsConstant(3.5, Fp(32)))
+f42 = PsExpression.make(PsConstant(42, Fp(32)))
+
+f0 = PsExpression.make(PsConstant(0.0, Fp(32)))
+f1 = PsExpression.make(PsConstant(1.0, Fp(32)))
+
+i0 = PsExpression.make(PsConstant(0, Int(32)))
+i1 = PsExpression.make(PsConstant(1, Int(32)))
+
+i3 = PsExpression.make(PsConstant(3, Int(32)))
+i12 = PsExpression.make(PsConstant(12, Int(32)))
+
+
+def test_idempotence():
+    elim = EliminateConstants()
+
+    expr = f42 * (f1 + f0) - f0
+    result = elim(expr)
+    assert isinstance(result, PsConstantExpr) and result.structurally_equal(f42)
+
+    expr = (x + f0) * f3p5 + (f1 * y + f0) * f42
+    result = elim(expr)
+    assert result.structurally_equal(x * f3p5 + y * f42)
+
+    expr = (f3p5 * f1) + (f42 * f1)
+    result = elim(expr)
+    #   do not fold floats by default
+    assert expr.structurally_equal(f3p5 + f42)
+
+    expr = f1 * x + f0 + (f0 + f0 + f1 + f0) * y
+    result = elim(expr)
+    assert result.structurally_equal(x + y)
+
+
+def test_int_folding():
+    elim = EliminateConstants()
+
+    expr = (i1 * x + i1 * i3) + i1 * i12
+    result = elim(expr)
+    assert result.structurally_equal((x + i3) + i12)
+
+    expr = (i1 + i1 + i1 + i0 + i0 + i1) * (i1 + i1 + i1)
+    result = elim(expr)
+    assert result.structurally_equal(i12)
+
+
+def test_zero_dominance():
+    elim = EliminateConstants()
+
+    expr = (f0 * x) + (y * f0) + f1
+    result = elim(expr)
+    assert result.structurally_equal(f1)
+
+    expr = (i3 + i12 * (x + y) + x / (i3 * y)) * i0
+    result = elim(expr)
+    assert result.structurally_equal(i0)
-- 
GitLab