From ed9315aff9ce95b69653f5049bbc33fcc31b38b4 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 19 Jan 2024 21:37:39 +0100 Subject: [PATCH] simplify children members of AST. introduce conditional branch. --- src/pystencils/nbackend/ast/__init__.py | 2 + src/pystencils/nbackend/ast/collectors.py | 2 +- src/pystencils/nbackend/ast/kernelfunction.py | 13 +- src/pystencils/nbackend/ast/nodes.py | 122 +++++++++++------- .../nbackend/ast/transformations.py | 2 +- src/pystencils/nbackend/ast/util.py | 6 +- src/pystencils/nbackend/emission.py | 45 +++++-- .../nbackend/jit/cpu_extension_module.py | 1 - 8 files changed, 118 insertions(+), 75 deletions(-) diff --git a/src/pystencils/nbackend/ast/__init__.py b/src/pystencils/nbackend/ast/__init__.py index daee7214f..c7731f37a 100644 --- a/src/pystencils/nbackend/ast/__init__.py +++ b/src/pystencils/nbackend/ast/__init__.py @@ -7,6 +7,7 @@ from .nodes import ( PsAssignment, PsDeclaration, PsLoop, + PsConditional, ) from .kernelfunction import PsKernelFunction @@ -24,5 +25,6 @@ __all__ = [ "PsAssignment", "PsDeclaration", "PsLoop", + "PsConditional", "ast_subs" ] diff --git a/src/pystencils/nbackend/ast/collectors.py b/src/pystencils/nbackend/ast/collectors.py index 65bc14d4f..3c995ccf6 100644 --- a/src/pystencils/nbackend/ast/collectors.py +++ b/src/pystencils/nbackend/ast/collectors.py @@ -102,7 +102,7 @@ class RequiredHeadersCollector(Collector): case PsExpression(expr): return self.rec(expr) case node: - return reduce(set.union, (self(c) for c in node.children()), set()) + return reduce(set.union, (self(c) for c in node.children), set()) def map_typed_variable(self, var: PsTypedVariable) -> set[str]: return var.dtype.required_headers diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py index 2729d25a6..58f2ddf36 100644 --- a/src/pystencils/nbackend/ast/kernelfunction.py +++ b/src/pystencils/nbackend/ast/kernelfunction.py @@ -1,6 +1,5 @@ from __future__ import annotations -from typing import Generator from dataclasses import dataclass from pymbolic.mapper.dependency import DependencyMapper @@ -104,16 +103,8 @@ class PsKernelFunction(PsAstNode): """For backward compatibility""" return None - def num_children(self) -> int: - return 1 - - def children(self) -> Generator[PsAstNode, None, None]: - yield from (self._body,) - - def get_child(self, idx: int): - if idx not in (0, -1): - raise IndexError(f"Child index out of bounds: {idx}") - return self._body + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._body,) def set_child(self, idx: int, c: PsAstNode): if idx not in (0, -1): diff --git a/src/pystencils/nbackend/ast/nodes.py b/src/pystencils/nbackend/ast/nodes.py index 2944e073d..5a20e5835 100644 --- a/src/pystencils/nbackend/ast/nodes.py +++ b/src/pystencils/nbackend/ast/nodes.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import Sequence, Generator, Iterable, cast, TypeAlias +from typing import Sequence, Iterable, cast, TypeAlias +from types import NoneType from abc import ABC, abstractmethod @@ -12,32 +13,28 @@ class PsAstNode(ABC): """Base class for all nodes in the pystencils AST. This base class provides a common interface to inspect and update the AST's branching structure. - The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by - each subclass. + The two methods `get_children` and `set_child` must be implemented by each subclass. Subclasses are also responsible for doing the necessary type checks if they place restrictions on the types of their children. """ - @abstractmethod - def num_children(self) -> int: - ... + @property + def children(self) -> tuple[PsAstNode, ...]: + return self.get_children() - @abstractmethod - def children(self) -> Generator[PsAstNode, None, None]: - ... + @children.setter + def children(self, cs: Iterable[PsAstNode]): + for i, c in enumerate(cs): + self.set_child(i, c) @abstractmethod - def get_child(self, idx: int): + def get_children(self) -> tuple[PsAstNode, ...]: ... @abstractmethod def set_child(self, idx: int, c: PsAstNode): ... - def set_children(self, cs: Iterable[PsAstNode]): - for i, c in enumerate(cs): - self.set_child(i, c) - class PsBlock(PsAstNode): __match_args__ = ("statements",) @@ -45,14 +42,8 @@ class PsBlock(PsAstNode): def __init__(self, cs: Sequence[PsAstNode]): self._statements = list(cs) - def num_children(self) -> int: - return len(self._statements) - - def children(self) -> Generator[PsAstNode, None, None]: - yield from self._statements - - def get_child(self, idx: int): - return self._statements[idx] + def get_children(self) -> tuple[PsAstNode, ...]: + return tuple(self._statements) def set_child(self, idx: int, c: PsAstNode): self._statements[idx] = c @@ -67,14 +58,8 @@ class PsBlock(PsAstNode): class PsLeafNode(PsAstNode): - def num_children(self) -> int: - return 0 - - def children(self) -> Generator[PsAstNode, None, None]: - yield from () - - def get_child(self, idx: int): - raise IndexError("Child index out of bounds: Leaf nodes have no children.") + def get_children(self) -> tuple[PsAstNode, ...]: + return () def set_child(self, idx: int, c: PsAstNode): raise IndexError("Child index out of bounds: Leaf nodes have no children.") @@ -154,14 +139,8 @@ class PsAssignment(PsAstNode): def rhs(self, expr: PsExpression): self._rhs = expr - def num_children(self) -> int: - return 2 - - def children(self) -> Generator[PsAstNode, None, None]: - yield from (self._lhs, self._rhs) - - def get_child(self, idx: int): - return (self._lhs, self._rhs)[idx] + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._lhs, self._rhs) def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] # trick to normalize index @@ -265,14 +244,8 @@ class PsLoop(PsAstNode): def body(self, block: PsBlock): self._body = block - def num_children(self) -> int: - return 5 - - def children(self) -> Generator[PsAstNode, None, None]: - yield from (self._ctr, self._start, self._stop, self._step, self._body) - - def get_child(self, idx: int): - return (self._ctr, self._start, self._stop, self._step, self._body)[idx] + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._ctr, self._start, self._stop, self._step, self._body) def set_child(self, idx: int, c: PsAstNode): idx = list(range(5))[idx] @@ -289,3 +262,60 @@ class PsLoop(PsAstNode): self._body = failing_cast(PsBlock, c) case _: assert False, "unreachable code" + + +class PsConditional(PsAstNode): + """Conditional branch""" + + __match_args__ = ("condition", "branch_true", "branch_false") + + def __init__( + self, + cond: PsExpression, + branch_true: PsBlock, + branch_false: PsBlock | None = None, + ): + self._condition = cond + self._branch_true = branch_true + self._branch_false = branch_false + + @property + def condition(self) -> PsExpression: + return self._condition + + @condition.setter + def condition(self, expr: PsExpression): + self._condition = expr + + @property + def branch_true(self) -> PsBlock: + return self._branch_true + + @branch_true.setter + def branch_true(self, block: PsBlock): + self._branch_true = block + + @property + def branch_false(self) -> PsBlock | None: + return self._branch_false + + @branch_false.setter + def branch_false(self, block: PsBlock | None): + self._branch_false = block + + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._condition, self._branch_true) + ( + (self._branch_false,) if self._branch_false is not None else () + ) + + def set_child(self, idx: int, c: PsAstNode): + idx = list(range(3))[idx] + match idx: + case 0: + self._condition = failing_cast(PsExpression, c) + case 1: + self._branch_true = failing_cast(PsBlock, c) + case 2: + self._branch_false = failing_cast((PsBlock, NoneType), c) + case _: + assert False, "unreachable code" diff --git a/src/pystencils/nbackend/ast/transformations.py b/src/pystencils/nbackend/ast/transformations.py index cef242b89..4260e18dc 100644 --- a/src/pystencils/nbackend/ast/transformations.py +++ b/src/pystencils/nbackend/ast/transformations.py @@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression class PsAstTransformer(ABC): def transform_children(self, node: PsAstNode, *args, **kwargs): - node.set_children(self.visit(c, *args, **kwargs) for c in node.children()) + node.children = tuple(self.visit(c, *args, **kwargs) for c in node.children) @ast_visitor def visit(self, node, *args, **kwargs): diff --git a/src/pystencils/nbackend/ast/util.py b/src/pystencils/nbackend/ast/util.py index aa1866baf..c3d93ed4c 100644 --- a/src/pystencils/nbackend/ast/util.py +++ b/src/pystencils/nbackend/ast/util.py @@ -1,9 +1,7 @@ -from typing import TypeVar +from typing import Any -T = TypeVar("T") - -def failing_cast(target: type, obj: T): +def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any: if not isinstance(obj, target): raise TypeError(f"Casting {obj} to {target} failed.") return obj diff --git a/src/pystencils/nbackend/emission.py b/src/pystencils/nbackend/emission.py index b315b172e..f25070192 100644 --- a/src/pystencils/nbackend/emission.py +++ b/src/pystencils/nbackend/emission.py @@ -2,14 +2,23 @@ from __future__ import annotations from pymbolic.mapper.c_code import CCodeMapper -from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop +from .ast import ( + ast_visitor, + PsAstNode, + PsBlock, + PsExpression, + PsDeclaration, + PsAssignment, + PsLoop, + PsConditional, +) from .ast.kernelfunction import PsKernelFunction def emit_code(kernel: PsKernelFunction): # TODO: Specialize for different targets printer = CPrinter() - return printer.print(kernel) + return printer.print(kernel) class CPrinter: @@ -17,7 +26,6 @@ class CPrinter: self._indent_width = indent_width self._current_indent_level = 0 - self._inside_expression = False # controls parentheses in nested arithmetic expressions self._pb_cmapper = CCodeMapper() @@ -30,7 +38,7 @@ class CPrinter: @ast_visitor def visit(self, _: PsAstNode) -> str: raise ValueError("Cannot print this node.") - + @visit.case(PsKernelFunction) def function(self, func: PsKernelFunction) -> str: params_spec = func.get_parameters() @@ -41,11 +49,11 @@ class CPrinter: @visit.case(PsBlock) def block(self, block: PsBlock): - if not block.children(): + if not block.children: return self.indent("{ }") self._current_indent_level += self._indent_width - interior = "\n".join(self.visit(c) for c in block.children()) + interior = "\n".join(self.visit(c) for c in block.children) self._current_indent_level -= self._indent_width return self.indent("{\n") + interior + self.indent("}\n") @@ -77,8 +85,23 @@ class CPrinter: body_code = self.visit(loop.body) - code = f"for({ctr_symbol.dtype} {ctr} = {start_code};" + \ - f" {ctr} < {stop_code};" + \ - f" {ctr} += {step_code})\n" + \ - body_code - return code + code = ( + f"for({ctr_symbol.dtype} {ctr} = {start_code};" + + f" {ctr} < {stop_code};" + + f" {ctr} += {step_code})\n" + + body_code + ) + return self.indent(code) + + @visit.case(PsConditional) + def conditional(self, node: PsConditional): + cond_code = self.visit(node.condition) + then_code = self.visit(node.branch_true) + + code = f"if({cond_code})\n{then_code}" + + if node.branch_false is not None: + else_code = self.visit(node.branch_false) + code += f"\nelse\n{else_code}" + + return self.indent(code) diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py index 1859ec661..f2c0d3ff9 100644 --- a/src/pystencils/nbackend/jit/cpu_extension_module.py +++ b/src/pystencils/nbackend/jit/cpu_extension_module.py @@ -22,7 +22,6 @@ from ..arrays import ( ) from ..types import ( PsAbstractType, - PsScalarType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, -- GitLab