From 19d0df07d21d7e7d4820d0811ed6652aa61d6bec Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 7 Mar 2024 13:56:08 +0100
Subject: [PATCH] constant elimination, part 2: subexpression extraction

---
 src/pystencils/backend/ast/structural.py      |   8 ++
 src/pystencils/backend/ast/util.py            |  11 +-
 .../backend/kernelcreation/iteration_space.py |   3 +-
 .../backend/kernelcreation/typification.py    |   8 +-
 .../transformations/eliminate_constants.py    | 108 +++++++++++++++---
 src/pystencils/kernelcreation.py              |   7 +-
 6 files changed, 119 insertions(+), 26 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 2338ee8c4..e5b88891c 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -14,6 +14,14 @@ class PsBlock(PsAstNode):
     def __init__(self, cs: Sequence[PsAstNode]):
         self._statements = list(cs)
 
+    @property
+    def children(self) -> Sequence[PsAstNode]:
+        return self.get_children()
+
+    @children.setter
+    def children(self, cs: Sequence[PsAstNode]):
+        self._statements = list(cs)
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return tuple(self._statements)
 
diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py
index a93dc74e2..0d3b78629 100644
--- a/src/pystencils/backend/ast/util.py
+++ b/src/pystencils/backend/ast/util.py
@@ -11,8 +11,9 @@ def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
     return obj
 
 
-class EqWrapper:
-    """Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`.
+class AstEqWrapper:
+    """Wrapper around AST nodes that computes a hash from the AST's textual representation
+    and maps the `__eq__` method onto `structurally_equal`.
 
     Useful in dictionaries when the goal is to keep track of subtrees according to their
     structure, e.g. in elimination of constants or common subexpressions.
@@ -26,10 +27,12 @@ class EqWrapper:
         return self._node
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, EqWrapper):
+        if not isinstance(other, AstEqWrapper):
             return False
 
         return self._node.structurally_equal(other._node)
 
     def __hash__(self) -> int:
-        return hash(self._node)
+        #   TODO: consider replacing this with smth. more performant
+        #   TODO: Check that repr is implemented by all AST nodes
+        return hash(repr(self._node))
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index 2215c7e6a..e5b586688 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -150,9 +150,10 @@ class FullIterationSpace(IterationSpace):
             if isinstance(expr, int):
                 return PsConstantExpr(PsConstant(expr, ctx.index_dtype))
             elif isinstance(expr, sp.Expr):
-                return typifier.typify_expression(
+                typed_expr, _ = typifier.typify_expression(
                     freeze.freeze_expression(expr), ctx.index_dtype
                 )
+                return typed_expr
             else:
                 raise ValueError(f"Invalid entry in slice: {expr}")
 
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index ae6cc8b4b..ef04a617d 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -128,10 +128,14 @@ class Typifier:
 
     def typify_expression(
         self, expr: PsExpression, target_type: PsNumericType | None = None
-    ) -> PsExpression:
+    ) -> tuple[PsExpression, PsType]:
         tc = TypeContext(target_type)
         self.visit_expr(expr, tc)
-        return expr
+
+        if tc.target_type is None:
+            raise TypificationError(f"Unable to determine type for {expr}")
+
+        return expr, tc.target_type
 
     def visit(self, node: PsAstNode) -> None:
         """Recursive processing of structural nodes"""
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index a1bfa332e..f7cf21fff 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -1,8 +1,10 @@
-from typing import cast
+from typing import cast, Iterable
+from collections import defaultdict
 
-from ..kernelcreation.context import KernelCreationContext
+from ..kernelcreation import KernelCreationContext, Typifier
 
 from ..ast import PsAstNode
+from ..ast.structural import PsBlock, PsDeclaration
 from ..ast.expressions import (
     PsExpression,
     PsConstantExpr,
@@ -13,17 +15,63 @@ from ..ast.expressions import (
     PsMul,
     PsDiv,
 )
+from ..ast.util import AstEqWrapper
 
 from ..constants import PsConstant
-from ...types import PsIntegerType, PsIeeeFloatType
+from ..symbols import PsSymbol
+from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError
+from ..emission import CAstPrinter
 
 
 __all__ = ["EliminateConstants"]
 
 
 class ECContext:
-    def __init__(self):
-        pass
+    def __init__(self, ctx: KernelCreationContext):
+        self._ctx = ctx
+        self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
+
+        self._typifier = Typifier(ctx)
+        self._printer = CAstPrinter(0)
+
+    @property
+    def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
+        return [
+            (symb, cast(PsExpression, w.n))
+            for (w, symb) in self._extracted_constants.items()
+        ]
+
+    def _get_symb_name(self, expr: PsExpression):
+        code = self._printer(expr)
+        code = code.lower()
+        #   remove spaces
+        code = "".join(code.split())
+
+        def valid_char(c):
+            return (ord("0") <= ord(c) <= ord("9")) or (ord("a") <= ord(c) <= ord("z"))
+
+        charmap = {"+": "p", "-": "s", "*": "m", "/": "o"}
+        charmap = defaultdict(lambda: "_", charmap)  # type: ignore
+
+        code = "".join((c if valid_char(c) else charmap[c]) for c in code)
+        return f"__c_{code}"
+
+    def extract_expression(self, expr: PsExpression) -> PsSymbolExpr:
+        expr, dtype = self._typifier.typify_expression(expr)
+        expr_wrapped = AstEqWrapper(expr)
+
+        if expr_wrapped not in self._extracted_constants:
+            symb_name = self._get_symb_name(expr)
+            try:
+                symb = self._ctx.get_symbol(symb_name, dtype)
+            except PsTypeError:
+                symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype)
+
+            self._extracted_constants[expr_wrapped] = symb
+        else:
+            symb = self._extracted_constants[expr_wrapped]
+
+        return PsSymbolExpr(symb)
 
 
 class EliminateConstants:
@@ -38,26 +86,45 @@ class EliminateConstants:
       the outermost block.
     """
 
-    def __init__(self, ctx: KernelCreationContext):
+    def __init__(
+        self, ctx: KernelCreationContext, extract_constant_exprs: bool = False
+    ):
         self._ctx = ctx
-        
+
         self._fold_integers = True
         self._fold_floats = False
-        self._extract_constant_exprs = True
+        self._extract_constant_exprs = extract_constant_exprs
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
-        return self.visit(node)
+        ecc = ECContext(self._ctx)
 
-    def visit(self, node: PsAstNode) -> PsAstNode:
+        node = self.visit(node, ecc)
+
+        if ecc.extractions:
+            prepend_decls = [
+                PsDeclaration(PsExpression.make(symb), expr)
+                for symb, expr in ecc.extractions
+            ]
+
+            if not isinstance(node, PsBlock):
+                node = PsBlock(prepend_decls + [node])
+            else:
+                node.children = prepend_decls + list(node.children)
+
+        return node
+
+    def visit(self, node: PsAstNode, ecc: ECContext) -> PsAstNode:
         match node:
             case PsExpression():
-                transformed_expr, _ = self.visit_expr(node)
+                transformed_expr, _ = self.visit_expr(node, ecc)
                 return transformed_expr
             case _:
-                node.children = [self.visit(c) for c in node.children]
+                node.children = [self.visit(c, ecc) for c in node.children]
                 return node
 
-    def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]:
+    def visit_expr(
+        self, expr: PsExpression, ecc: ECContext
+    ) -> tuple[PsExpression, bool]:
         """Transformation of expressions.
 
         Returns:
@@ -66,13 +133,13 @@ class EliminateConstants:
         #   Return constants as they are
         if isinstance(expr, PsConstantExpr):
             return expr, True
-        
+
         #   Shortcut symbols
         if isinstance(expr, PsSymbolExpr):
             return expr, False
 
         subtree_results = [
-            self.visit_expr(cast(PsExpression, c)) for c in expr.children
+            self.visit_expr(cast(PsExpression, c), ecc) for c in expr.children
         ]
         expr.children = [r[0] for r in subtree_results]
         subtree_constness = [r[1] for r in subtree_results]
@@ -91,7 +158,7 @@ class EliminateConstants:
 
             #   Additive idempotence: Subtraction from zero
             case PsSub(PsConstantExpr(c), other_op) if c.value == 0:
-                other_transformed, is_const = self.visit_expr(-other_op)
+                other_transformed, is_const = self.visit_expr(-other_op, ecc)
                 return other_transformed, is_const
 
             #   Multiplicative idempotence: Multiplication with and division by one
@@ -155,7 +222,14 @@ class EliminateConstants:
                 expr.operand1 = op1_transformed
                 expr.operand2 = op2_transformed
                 return expr, True
-        # end if: no constant expressions encountered
+        # end if: this expression is not constant
+
+        #   If required, extract constant subexpressions
+        if self._extract_constant_exprs:
+            for i, (child, is_const) in enumerate(subtree_results):
+                if is_const and not isinstance(child, PsConstantExpr):
+                    replacement = ecc.extract_expression(child)
+                    expr.set_child(i, replacement)
 
         #   Any other expressions are not considered constant even if their arguments are
         return expr, False
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 93b09e393..d49cf1bf5 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -82,8 +82,11 @@ def create_kernel(
     kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
 
     #   Simplifying transformations
-    kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast))
-    kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
+    elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
+    kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
+
+    erase_anons = EraseAnonymousStructTypes(ctx)
+    kernel_ast = cast(PsBlock, erase_anons(kernel_ast))
 
     #   7. Apply optimizations
     #     - Vectorization
-- 
GitLab