From 19d0df07d21d7e7d4820d0811ed6652aa61d6bec Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 7 Mar 2024 13:56:08 +0100 Subject: [PATCH] constant elimination, part 2: subexpression extraction --- src/pystencils/backend/ast/structural.py | 8 ++ src/pystencils/backend/ast/util.py | 11 +- .../backend/kernelcreation/iteration_space.py | 3 +- .../backend/kernelcreation/typification.py | 8 +- .../transformations/eliminate_constants.py | 108 +++++++++++++++--- src/pystencils/kernelcreation.py | 7 +- 6 files changed, 119 insertions(+), 26 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 2338ee8c4..e5b88891c 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -14,6 +14,14 @@ class PsBlock(PsAstNode): def __init__(self, cs: Sequence[PsAstNode]): self._statements = list(cs) + @property + def children(self) -> Sequence[PsAstNode]: + return self.get_children() + + @children.setter + def children(self, cs: Sequence[PsAstNode]): + self._statements = list(cs) + def get_children(self) -> tuple[PsAstNode, ...]: return tuple(self._statements) diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index a93dc74e2..0d3b78629 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -11,8 +11,9 @@ def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any: return obj -class EqWrapper: - """Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`. +class AstEqWrapper: + """Wrapper around AST nodes that computes a hash from the AST's textual representation + and maps the `__eq__` method onto `structurally_equal`. Useful in dictionaries when the goal is to keep track of subtrees according to their structure, e.g. in elimination of constants or common subexpressions. @@ -26,10 +27,12 @@ class EqWrapper: return self._node def __eq__(self, other: object) -> bool: - if not isinstance(other, EqWrapper): + if not isinstance(other, AstEqWrapper): return False return self._node.structurally_equal(other._node) def __hash__(self) -> int: - return hash(self._node) + # TODO: consider replacing this with smth. more performant + # TODO: Check that repr is implemented by all AST nodes + return hash(repr(self._node)) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 2215c7e6a..e5b586688 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -150,9 +150,10 @@ class FullIterationSpace(IterationSpace): if isinstance(expr, int): return PsConstantExpr(PsConstant(expr, ctx.index_dtype)) elif isinstance(expr, sp.Expr): - return typifier.typify_expression( + typed_expr, _ = typifier.typify_expression( freeze.freeze_expression(expr), ctx.index_dtype ) + return typed_expr else: raise ValueError(f"Invalid entry in slice: {expr}") diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index ae6cc8b4b..ef04a617d 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -128,10 +128,14 @@ class Typifier: def typify_expression( self, expr: PsExpression, target_type: PsNumericType | None = None - ) -> PsExpression: + ) -> tuple[PsExpression, PsType]: tc = TypeContext(target_type) self.visit_expr(expr, tc) - return expr + + if tc.target_type is None: + raise TypificationError(f"Unable to determine type for {expr}") + + return expr, tc.target_type def visit(self, node: PsAstNode) -> None: """Recursive processing of structural nodes""" diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index a1bfa332e..f7cf21fff 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -1,8 +1,10 @@ -from typing import cast +from typing import cast, Iterable +from collections import defaultdict -from ..kernelcreation.context import KernelCreationContext +from ..kernelcreation import KernelCreationContext, Typifier from ..ast import PsAstNode +from ..ast.structural import PsBlock, PsDeclaration from ..ast.expressions import ( PsExpression, PsConstantExpr, @@ -13,17 +15,63 @@ from ..ast.expressions import ( PsMul, PsDiv, ) +from ..ast.util import AstEqWrapper from ..constants import PsConstant -from ...types import PsIntegerType, PsIeeeFloatType +from ..symbols import PsSymbol +from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError +from ..emission import CAstPrinter __all__ = ["EliminateConstants"] class ECContext: - def __init__(self): - pass + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() + + self._typifier = Typifier(ctx) + self._printer = CAstPrinter(0) + + @property + def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]: + return [ + (symb, cast(PsExpression, w.n)) + for (w, symb) in self._extracted_constants.items() + ] + + def _get_symb_name(self, expr: PsExpression): + code = self._printer(expr) + code = code.lower() + # remove spaces + code = "".join(code.split()) + + def valid_char(c): + return (ord("0") <= ord(c) <= ord("9")) or (ord("a") <= ord(c) <= ord("z")) + + charmap = {"+": "p", "-": "s", "*": "m", "/": "o"} + charmap = defaultdict(lambda: "_", charmap) # type: ignore + + code = "".join((c if valid_char(c) else charmap[c]) for c in code) + return f"__c_{code}" + + def extract_expression(self, expr: PsExpression) -> PsSymbolExpr: + expr, dtype = self._typifier.typify_expression(expr) + expr_wrapped = AstEqWrapper(expr) + + if expr_wrapped not in self._extracted_constants: + symb_name = self._get_symb_name(expr) + try: + symb = self._ctx.get_symbol(symb_name, dtype) + except PsTypeError: + symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype) + + self._extracted_constants[expr_wrapped] = symb + else: + symb = self._extracted_constants[expr_wrapped] + + return PsSymbolExpr(symb) class EliminateConstants: @@ -38,26 +86,45 @@ class EliminateConstants: the outermost block. """ - def __init__(self, ctx: KernelCreationContext): + def __init__( + self, ctx: KernelCreationContext, extract_constant_exprs: bool = False + ): self._ctx = ctx - + self._fold_integers = True self._fold_floats = False - self._extract_constant_exprs = True + self._extract_constant_exprs = extract_constant_exprs def __call__(self, node: PsAstNode) -> PsAstNode: - return self.visit(node) + ecc = ECContext(self._ctx) - def visit(self, node: PsAstNode) -> PsAstNode: + node = self.visit(node, ecc) + + if ecc.extractions: + prepend_decls = [ + PsDeclaration(PsExpression.make(symb), expr) + for symb, expr in ecc.extractions + ] + + if not isinstance(node, PsBlock): + node = PsBlock(prepend_decls + [node]) + else: + node.children = prepend_decls + list(node.children) + + return node + + def visit(self, node: PsAstNode, ecc: ECContext) -> PsAstNode: match node: case PsExpression(): - transformed_expr, _ = self.visit_expr(node) + transformed_expr, _ = self.visit_expr(node, ecc) return transformed_expr case _: - node.children = [self.visit(c) for c in node.children] + node.children = [self.visit(c, ecc) for c in node.children] return node - def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]: + def visit_expr( + self, expr: PsExpression, ecc: ECContext + ) -> tuple[PsExpression, bool]: """Transformation of expressions. Returns: @@ -66,13 +133,13 @@ class EliminateConstants: # Return constants as they are if isinstance(expr, PsConstantExpr): return expr, True - + # Shortcut symbols if isinstance(expr, PsSymbolExpr): return expr, False subtree_results = [ - self.visit_expr(cast(PsExpression, c)) for c in expr.children + self.visit_expr(cast(PsExpression, c), ecc) for c in expr.children ] expr.children = [r[0] for r in subtree_results] subtree_constness = [r[1] for r in subtree_results] @@ -91,7 +158,7 @@ class EliminateConstants: # Additive idempotence: Subtraction from zero case PsSub(PsConstantExpr(c), other_op) if c.value == 0: - other_transformed, is_const = self.visit_expr(-other_op) + other_transformed, is_const = self.visit_expr(-other_op, ecc) return other_transformed, is_const # Multiplicative idempotence: Multiplication with and division by one @@ -155,7 +222,14 @@ class EliminateConstants: expr.operand1 = op1_transformed expr.operand2 = op2_transformed return expr, True - # end if: no constant expressions encountered + # end if: this expression is not constant + + # If required, extract constant subexpressions + if self._extract_constant_exprs: + for i, (child, is_const) in enumerate(subtree_results): + if is_const and not isinstance(child, PsConstantExpr): + replacement = ecc.extract_expression(child) + expr.set_child(i, replacement) # Any other expressions are not considered constant even if their arguments are return expr, False diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 93b09e393..d49cf1bf5 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -82,8 +82,11 @@ def create_kernel( kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) # Simplifying transformations - kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast)) - kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast)) + elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) + kernel_ast = cast(PsBlock, elim_constants(kernel_ast)) + + erase_anons = EraseAnonymousStructTypes(ctx) + kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) # 7. Apply optimizations # - Vectorization -- GitLab