diff --git a/pystencils/nbackend/ast/__init__.py b/pystencils/nbackend/ast/__init__.py index d8d4c33b8709b7cc2cf50369b48bb88b2e5bedca..567840b2d61724b48dab7ca0158f0355cfd797cf 100644 --- a/pystencils/nbackend/ast/__init__.py +++ b/pystencils/nbackend/ast/__init__.py @@ -1,13 +1,26 @@ from .nodes import ( - PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, - PsAssignment, PsDeclaration, PsLoop + PsAstNode, + PsBlock, + PsExpression, + PsLvalueExpr, + PsSymbolExpr, + PsAssignment, + PsDeclaration, + PsLoop, ) from .dispatcher import ast_visitor from .transformations import ast_subs __all__ = [ - ast_visitor, - PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAssignment, PsDeclaration, PsLoop, - ast_subs + "ast_visitor", + "PsAstNode", + "PsBlock", + "PsExpression", + "PsLvalueExpr", + "PsSymbolExpr", + "PsAssignment", + "PsDeclaration", + "PsLoop", + "ast_subs", ] diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index 584455db4f44dccd02931923ccd5d799e557df6b..7459de7aadae49948c46a2479f770400e36d8321 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -1,58 +1,85 @@ from __future__ import annotations -from typing import Sequence, Generator +from typing import Sequence, Generator, TypeVar, Iterable, cast -from abc import ABC +from abc import ABC, abstractmethod import pymbolic.primitives as pb -from ..typed_expressions import PsTypedVariable, PsLvalue +from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue -class PsAstNode(ABC): - """Base class for all nodes in the pystencils AST.""" +T = TypeVar("T") +def failing_cast(target: type, obj: T): + if not isinstance(obj, target): + raise TypeError(f"Casting {obj} to {target} failed.") + return obj - def __init__(self, *children: Sequence[PsAstNode]): - for c in children: - if not isinstance(c, PsAstNode): - raise TypeError(f"Child {c} was not a PsAstNode.") - self._children = list(children) - @property +class PsAstNode(ABC): + """Base class for all nodes in the pystencils AST. + + This base class provides a common interface to inspect and update the AST's branching structure. + The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by + each subclass. + Subclasses are also responsible for doing the necessary type checks if they place restrictions on + the types of their children. + """ + + @abstractmethod + def num_children(self) -> int: + ... + + @abstractmethod def children(self) -> Generator[PsAstNode, None, None]: - yield from self._children + ... - def child(self, idx: int): - return self._children[idx] + @abstractmethod + def get_child(self, idx: int): + ... - @children.setter - def children(self, cs: Sequence[PsAstNode]): - if len(cs) != len(self._children): - raise ValueError("The number of child nodes must remain the same!") + @abstractmethod + def set_child(self, idx: int, c: PsAstNode): + ... + + def set_children(self, cs: Iterable[PsAstNode]): + for i, c in enumerate(cs): + self.set_child(i, c) + + +class PsBlock(PsAstNode): + def __init__(self, cs: Sequence[PsAstNode]): self._children = list(cs) - def __getitem__(self, idx: int): + def num_children(self) -> int: + return len(self._children) + + def children(self) -> Generator[PsAstNode, None, None]: + yield from self._children + + def get_child(self, idx: int): return self._children[idx] - def __setitem__(self, idx: int, c: PsAstNode): + def set_child(self, idx: int, c: PsAstNode): self._children[idx] = c +class PsLeafNode(PsAstNode): + def num_children(self) -> int: + return 0 -class PsBlock(PsAstNode): - - @property def children(self) -> Generator[PsAstNode, None, None]: - yield from self._children # need to override entire property to override the setter + yield from () - @children.setter - def children(self, cs: Sequence[PsAstNode]): - self._children = cs + def get_child(self, idx: int): + raise IndexError("Child index out of bounds: Leaf nodes have no children.") + def set_child(self, idx: int, c: PsAstNode): + raise IndexError("Child index out of bounds: Leaf nodes have no children.") -class PsExpression(PsAstNode): + +class PsExpression(PsLeafNode): """Wrapper around pymbolics expressions.""" def __init__(self, expr: pb.Expression): - super().__init__() self._expr = expr @property @@ -68,7 +95,7 @@ class PsLvalueExpr(PsExpression): """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment""" def __init__(self, expr: PsLvalue): - if not isinstance(expr, PsLvalue): + if not isinstance(expr, (PsTypedVariable, PsArrayAccess)): raise TypeError("Expression was not a valid lvalue") super(PsLvalueExpr, self).__init__(expr) @@ -78,52 +105,85 @@ class PsSymbolExpr(PsLvalueExpr): """Wrapper around PsTypedSymbols""" def __init__(self, symbol: PsTypedVariable): - if not isinstance(symbol, PsTypedVariable): - raise TypeError("Not a symbol!") - - super(PsLvalueExpr, self).__init__(symbol) + super().__init__(symbol) @property - def symbol(self) -> PsSymbolExpr: - return self.expression + def symbol(self) -> PsTypedVariable: + return cast(PsTypedVariable, self._expr) @symbol.setter - def symbol(self, symbol: PsSymbolExpr): - self.expression = symbol + def symbol(self, symbol: PsTypedVariable): + self._expr = symbol class PsAssignment(PsAstNode): def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): - super(PsAssignment, self).__init__(lhs, rhs) + self._lhs = lhs + self._rhs = rhs @property def lhs(self) -> PsLvalueExpr: - return self._children[0] + return self._lhs @lhs.setter def lhs(self, lvalue: PsLvalueExpr): - self._children[0] = lvalue + self._lhs = lvalue @property def rhs(self) -> PsExpression: - return self._children[1] + return self._rhs @rhs.setter def rhs(self, expr: PsExpression): - self._children[1] = expr + self._rhs = expr + + def num_children(self) -> int: + return 2 + + def children(self) -> Generator[PsAstNode, None, None]: + yield from (self._lhs, self._rhs) + + def get_child(self, idx: int): + return (self._lhs, self._rhs)[idx] + + def set_child(self, idx: int, c: PsAstNode): + idx = [0, 1][idx] # trick to normalize index + if idx == 0: + self._lhs = failing_cast(PsLvalueExpr, c) + elif idx == 1: + self._rhs = failing_cast(PsExpression, c) + else: + assert False, "unreachable code" class PsDeclaration(PsAssignment): def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): - super(PsDeclaration, self).__init__(lhs, rhs) + super().__init__(lhs, rhs) @property - def lhs(self) -> PsSymbolExpr: - return self._children[0] + def lhs(self) -> PsLvalueExpr: + return self._lhs @lhs.setter - def lhs(self, symbol_node: PsSymbolExpr): - self._children[0] = symbol_node + def lhs(self, lvalue: PsLvalueExpr): + self._lhs = failing_cast(PsSymbolExpr, lvalue) + + @property + def declared_symbol(self) -> PsSymbolExpr: + return cast(PsSymbolExpr, self._lhs) + + @declared_symbol.setter + def declared_symbol(self, lvalue: PsSymbolExpr): + self._lhs = lvalue + + def set_child(self, idx: int, c: PsAstNode): + idx = [0, 1][idx] # trick to normalize index + if idx == 0: + self._lhs = failing_cast(PsSymbolExpr, c) + elif idx == 1: + self._rhs = failing_cast(PsExpression, c) + else: + assert False, "unreachable code" class PsLoop(PsAstNode): @@ -133,40 +193,68 @@ class PsLoop(PsAstNode): stop: PsExpression, step: PsExpression, body: PsBlock): - super(PsLoop, self).__init__(ctr, start, stop, step, body) + self._ctr = ctr + self._start = start + self._stop = stop + self._step = step + self._body = body @property def counter(self) -> PsSymbolExpr: - return self._children[0] + return self._ctr + + @counter.setter + def counter(self, expr: PsSymbolExpr): + self._ctr = expr @property def start(self) -> PsExpression: - return self._children[1] + return self._start @start.setter def start(self, expr: PsExpression): - self._children[1] = expr + self._start = expr @property def stop(self) -> PsExpression: - return self._children[2] + return self._stop @stop.setter def stop(self, expr: PsExpression): - self._children[2] = expr + self._stop = expr @property def step(self) -> PsExpression: - return self._children[3] + return self._step @step.setter def step(self, expr: PsExpression): - self._children[3] = expr + self._step = expr @property def body(self) -> PsBlock: - return self._children[4] + return self._body @body.setter def body(self, block: PsBlock): - self._children[4] = block + self._body = block + + def num_children(self) -> int: + return 5 + + def children(self) -> Generator[PsAstNode, None, None]: + yield from (self._ctr, self._start, self._stop, self._step, self._body) + + def get_child(self, idx: int): + return (self._ctr, self._start, self._stop, self._step, self._body)[idx] + + + def set_child(self, idx: int, c: PsAstNode): + idx = list(range(5))[idx] + match idx: + case 0: self._ctr = failing_cast(PsSymbolExpr, c) + case 1: self._start = failing_cast(PsExpression, c) + case 2: self._stop = failing_cast(PsExpression, c) + case 3: self._step = failing_cast(PsExpression, c) + case 4: self._body = failing_cast(PsBlock, c) + case _: assert False, "unreachable code" diff --git a/pystencils/nbackend/ast/transformations.py b/pystencils/nbackend/ast/transformations.py index b75b13fab04d4cefa08151300843aad023a1290e..cef242b89eaea6b4f468dba97c0cc6d586a147f3 100644 --- a/pystencils/nbackend/ast/transformations.py +++ b/pystencils/nbackend/ast/transformations.py @@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression class PsAstTransformer(ABC): def transform_children(self, node: PsAstNode, *args, **kwargs): - node.children = [self.visit(c, *args, **kwargs) for c in node.children] + node.set_children(self.visit(c, *args, **kwargs) for c in node.children()) @ast_visitor def visit(self, node, *args, **kwargs): diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py index d920fa9dc776c2cd683d5378dbf9ae10aa00ed92..a1f176f8375922578c7d52aa3d501b57e7cb44cd 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/c_printer.py @@ -26,11 +26,11 @@ class CPrinter: @visit.case(PsBlock) def block(self, block: PsBlock): - if not block.children: + if not block.children(): return self.indent("{ }") self._current_indent_level += self._indent_width - interior = "".join(self.visit(c) for c in block.children) + interior = "".join(self.visit(c) for c in block.children()) self._current_indent_level -= self._indent_width return self.indent("{\n") + interior + self.indent("}\n") @@ -40,7 +40,7 @@ class CPrinter: @visit.case(PsDeclaration) def declaration(self, decl: PsDeclaration): - lhs_symb = decl.lhs.symbol + lhs_symb = decl.declared_symbol.symbol lhs_dtype = lhs_symb.dtype rhs_code = self.visit(decl.rhs)