Skip to content
Snippets Groups Projects
Commit 19d0df07 authored by Frederik Hennig
Browse files

constant elimination, part 2: subexpression extraction

parent d9a260ef
No related merge requests found
Pipeline #63824 failed with stages
in 3 hours, 56 minutes, and 17 seconds
...@@ -14,6 +14,14 @@ class PsBlock(PsAstNode): ...@@ -14,6 +14,14 @@ class PsBlock(PsAstNode):
def __init__(self, cs: Sequence[PsAstNode]): def __init__(self, cs: Sequence[PsAstNode]):
self._statements = list(cs) self._statements = list(cs)
@property
def children(self) -> Sequence[PsAstNode]:
    """The statements of this block, as reported by `get_children`."""
    return self.get_children()

@children.setter
def children(self, cs: Sequence[PsAstNode]):
    """Replace the statements of this block by a fresh list built from `cs`."""
    self._statements = [*cs]
def get_children(self) -> tuple[PsAstNode, ...]: def get_children(self) -> tuple[PsAstNode, ...]:
return tuple(self._statements) return tuple(self._statements)
......
...@@ -11,8 +11,9 @@ def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any: ...@@ -11,8 +11,9 @@ def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
return obj return obj
class EqWrapper: class AstEqWrapper:
"""Wrapper around AST nodes that maps the `__eq__` method onto `structurally_equal`. """Wrapper around AST nodes that computes a hash from the AST's textual representation
and maps the `__eq__` method onto `structurally_equal`.
Useful in dictionaries when the goal is to keep track of subtrees according to their Useful in dictionaries when the goal is to keep track of subtrees according to their
structure, e.g. in elimination of constants or common subexpressions. structure, e.g. in elimination of constants or common subexpressions.
...@@ -26,10 +27,12 @@ class EqWrapper: ...@@ -26,10 +27,12 @@ class EqWrapper:
return self._node return self._node
def __eq__(self, other: object) -> bool:
    """Two wrappers are equal iff their wrapped nodes are `structurally_equal`.

    Comparison with anything that is not an `AstEqWrapper` yields `False`.
    """
    return isinstance(other, AstEqWrapper) and self._node.structurally_equal(
        other._node
    )
def __hash__(self) -> int:
    """Hash of the wrapped node, computed from its `repr` string."""
    # TODO: consider replacing this with smth. more performant
    # TODO: Check that repr is implemented by all AST nodes
    node_repr = repr(self._node)
    return hash(node_repr)
...@@ -150,9 +150,10 @@ class FullIterationSpace(IterationSpace): ...@@ -150,9 +150,10 @@ class FullIterationSpace(IterationSpace):
if isinstance(expr, int): if isinstance(expr, int):
return PsConstantExpr(PsConstant(expr, ctx.index_dtype)) return PsConstantExpr(PsConstant(expr, ctx.index_dtype))
elif isinstance(expr, sp.Expr): elif isinstance(expr, sp.Expr):
return typifier.typify_expression( typed_expr, _ = typifier.typify_expression(
freeze.freeze_expression(expr), ctx.index_dtype freeze.freeze_expression(expr), ctx.index_dtype
) )
return typed_expr
else: else:
raise ValueError(f"Invalid entry in slice: {expr}") raise ValueError(f"Invalid entry in slice: {expr}")
......
...@@ -128,10 +128,14 @@ class Typifier: ...@@ -128,10 +128,14 @@ class Typifier:
def typify_expression(
    self, expr: PsExpression, target_type: PsNumericType | None = None
) -> tuple[PsExpression, PsType]:
    """Process `expr` with `visit_expr` inside a fresh `TypeContext`.

    Args:
        expr: The expression to process; it is passed to `visit_expr` and
            returned as the first tuple entry.
        target_type: Initial target type handed to the `TypeContext`, or
            `None` if no type is prescribed.

    Returns:
        The pair `(expr, type)`, where `type` is the target type held by the
        type context after visiting.

    Raises:
        TypificationError: If the type context holds no target type after
            visiting.
    """
    type_ctx = TypeContext(target_type)
    self.visit_expr(expr, type_ctx)

    deduced_type = type_ctx.target_type
    if deduced_type is not None:
        return expr, deduced_type

    raise TypificationError(f"Unable to determine type for {expr}")
def visit(self, node: PsAstNode) -> None: def visit(self, node: PsAstNode) -> None:
"""Recursive processing of structural nodes""" """Recursive processing of structural nodes"""
......
from typing import cast from typing import cast, Iterable
from collections import defaultdict
from ..kernelcreation.context import KernelCreationContext from ..kernelcreation import KernelCreationContext, Typifier
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsDeclaration
from ..ast.expressions import ( from ..ast.expressions import (
PsExpression, PsExpression,
PsConstantExpr, PsConstantExpr,
...@@ -13,17 +15,63 @@ from ..ast.expressions import ( ...@@ -13,17 +15,63 @@ from ..ast.expressions import (
PsMul, PsMul,
PsDiv, PsDiv,
) )
from ..ast.util import AstEqWrapper
from ..constants import PsConstant from ..constants import PsConstant
from ...types import PsIntegerType, PsIeeeFloatType from ..symbols import PsSymbol
from ...types import PsIntegerType, PsIeeeFloatType, PsTypeError
from ..emission import CAstPrinter
__all__ = ["EliminateConstants"] __all__ = ["EliminateConstants"]
class ECContext:
    """Bookkeeping for constant subexpressions extracted during one elimination run.

    Every structurally distinct expression passed to `extract_expression` is
    associated with exactly one symbol obtained from the kernel creation context.
    The resulting (symbol, expression) pairs are available through `extractions`.
    """

    # Letters substituted for the arithmetic operators when building symbol names
    _OPERATOR_CODES = {"+": "p", "-": "s", "*": "m", "/": "o"}

    # Characters that are copied into symbol names unchanged
    _NAME_CHARS = frozenset("0123456789abcdefghijklmnopqrstuvwxyz")

    def __init__(self, ctx: KernelCreationContext):
        self._ctx = ctx
        self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
        # Names already handed out by this context; guards against two different
        # expressions being mapped onto the same symbol
        self._used_names: set[str] = set()

        self._typifier = Typifier(ctx)
        self._printer = CAstPrinter(0)

    @property
    def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
        """Pairs `(symbol, expression)` of all extractions, in order of extraction."""
        return [
            (symb, cast(PsExpression, w.n))
            for (w, symb) in self._extracted_constants.items()
        ]

    def _get_symb_name(self, expr: PsExpression) -> str:
        """Derive a symbol name from the printed code of `expr`.

        The printed code is lowercased and stripped of all whitespace; digits and
        the letters a-z are kept, `+`, `-`, `*`, `/` become `p`, `s`, `m`, `o`, and
        every other character becomes `_`. The result is prefixed with `__c_`.

        The mapping is lossy, so different expressions may yield the same name.
        """
        code = "".join(self._printer(expr).lower().split())
        sanitized = "".join(
            c if c in self._NAME_CHARS else self._OPERATOR_CODES.get(c, "_")
            for c in code
        )
        return f"__c_{sanitized}"

    @staticmethod
    def _candidate_names(base_name: str, dtype):
        """Yield names to try for a symbol: the base name, the base name with the
        type's C string appended, then the base name with a running number."""
        yield base_name
        yield f"{base_name}_{dtype.c_string()}"

        counter = 0
        while True:
            counter += 1
            yield f"{base_name}_{counter}"

    def _declare_symbol(self, base_name: str, dtype) -> PsSymbol:
        """Obtain a symbol of type `dtype` that no other extraction is bound to.

        Candidate names are tried in the order given by `_candidate_names`. A
        candidate is skipped if this context already handed it out, or if
        requesting it from the kernel creation context raises `PsTypeError`
        (presumably: a symbol of that name with a different type exists).
        """
        names = self._candidate_names(base_name, dtype)
        while True:
            name = next(names)
            if name in self._used_names:
                continue

            try:
                symb = self._ctx.get_symbol(name, dtype)
            except PsTypeError:
                continue

            self._used_names.add(name)
            return symb

    def extract_expression(self, expr: PsExpression) -> PsSymbolExpr:
        """Replace `expr` by a reference to the symbol holding its value.

        The expression is typified first. If a structurally equal expression was
        extracted before, its symbol is reused; otherwise a new symbol is
        requested and recorded.

        Returns:
            A `PsSymbolExpr` referring to the symbol associated with `expr`.
        """
        expr, dtype = self._typifier.typify_expression(expr)
        key = AstEqWrapper(expr)

        symb = self._extracted_constants.get(key)
        if symb is None:
            symb = self._declare_symbol(self._get_symb_name(expr), dtype)
            self._extracted_constants[key] = symb

        return PsSymbolExpr(symb)
class EliminateConstants: class EliminateConstants:
...@@ -38,26 +86,45 @@ class EliminateConstants: ...@@ -38,26 +86,45 @@ class EliminateConstants:
the outermost block. the outermost block.
""" """
def __init__(
    self,
    ctx: KernelCreationContext,
    extract_constant_exprs: bool = False,
):
    """Set up the transformation.

    Args:
        ctx: The kernel creation context the transformation operates in.
        extract_constant_exprs: Whether constant subexpressions of non-constant
            expressions are replaced by symbols.
    """
    self._extract_constant_exprs = extract_constant_exprs
    self._fold_floats = False
    self._fold_integers = True

    self._ctx = ctx
def __call__(self, node: PsAstNode) -> PsAstNode:
    """Apply the transformation to `node` and declare any extracted constants.

    A fresh `ECContext` is used for each call. If expressions were extracted,
    one declaration per extraction is placed in front of the transformed code:
    a `PsBlock` result gets the declarations prepended to its children, any
    other result is wrapped into a new `PsBlock` following the declarations.
    """
    ecc = ECContext(self._ctx)
    transformed = self.visit(node, ecc)

    extracted = ecc.extractions
    if not extracted:
        return transformed

    decls = [
        PsDeclaration(PsExpression.make(symb), expr) for symb, expr in extracted
    ]

    if isinstance(transformed, PsBlock):
        transformed.children = decls + list(transformed.children)
        return transformed

    return PsBlock(decls + [transformed])
def visit(self, node: PsAstNode, ecc: ECContext) -> PsAstNode:
    """Transform `node`, recursing through non-expression nodes.

    Expressions are handed to `visit_expr`, of whose result only the
    transformed expression is returned. Any other node has its children
    replaced by their visited counterparts and is returned itself.
    """
    if isinstance(node, PsExpression):
        return self.visit_expr(node, ecc)[0]

    node.children = [self.visit(child, ecc) for child in node.children]
    return node
def visit_expr(self, expr: PsExpression) -> tuple[PsExpression, bool]: def visit_expr(
self, expr: PsExpression, ecc: ECContext
) -> tuple[PsExpression, bool]:
"""Transformation of expressions. """Transformation of expressions.
Returns: Returns:
...@@ -66,13 +133,13 @@ class EliminateConstants: ...@@ -66,13 +133,13 @@ class EliminateConstants:
# Return constants as they are # Return constants as they are
if isinstance(expr, PsConstantExpr): if isinstance(expr, PsConstantExpr):
return expr, True return expr, True
# Shortcut symbols # Shortcut symbols
if isinstance(expr, PsSymbolExpr): if isinstance(expr, PsSymbolExpr):
return expr, False return expr, False
subtree_results = [ subtree_results = [
self.visit_expr(cast(PsExpression, c)) for c in expr.children self.visit_expr(cast(PsExpression, c), ecc) for c in expr.children
] ]
expr.children = [r[0] for r in subtree_results] expr.children = [r[0] for r in subtree_results]
subtree_constness = [r[1] for r in subtree_results] subtree_constness = [r[1] for r in subtree_results]
...@@ -91,7 +158,7 @@ class EliminateConstants: ...@@ -91,7 +158,7 @@ class EliminateConstants:
# Additive idempotence: Subtraction from zero # Additive idempotence: Subtraction from zero
case PsSub(PsConstantExpr(c), other_op) if c.value == 0: case PsSub(PsConstantExpr(c), other_op) if c.value == 0:
other_transformed, is_const = self.visit_expr(-other_op) other_transformed, is_const = self.visit_expr(-other_op, ecc)
return other_transformed, is_const return other_transformed, is_const
# Multiplicative idempotence: Multiplication with and division by one # Multiplicative idempotence: Multiplication with and division by one
...@@ -155,7 +222,14 @@ class EliminateConstants: ...@@ -155,7 +222,14 @@ class EliminateConstants:
expr.operand1 = op1_transformed expr.operand1 = op1_transformed
expr.operand2 = op2_transformed expr.operand2 = op2_transformed
return expr, True return expr, True
# end if: no constant expressions encountered # end if: this expression is not constant
# If required, extract constant subexpressions
if self._extract_constant_exprs:
for i, (child, is_const) in enumerate(subtree_results):
if is_const and not isinstance(child, PsConstantExpr):
replacement = ecc.extract_expression(child)
expr.set_child(i, replacement)
# Any other expressions are not considered constant even if their arguments are # Any other expressions are not considered constant even if their arguments are
return expr, False return expr, False
...@@ -82,8 +82,11 @@ def create_kernel( ...@@ -82,8 +82,11 @@ def create_kernel(
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
# Simplifying transformations # Simplifying transformations
kernel_ast = cast(PsBlock, EliminateConstants(ctx)(kernel_ast)) elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast)) kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
erase_anons = EraseAnonymousStructTypes(ctx)
kernel_ast = cast(PsBlock, erase_anons(kernel_ast))
# 7. Apply optimizations # 7. Apply optimizations
# - Vectorization # - Vectorization
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment