Skip to content
Snippets Groups Projects
Commit d2520fd3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactored branching structure, satisfied mypy.

parent 7b552412
Branches
Tags
No related merge requests found
Pipeline #59671 failed with stages
in 2 minutes and 58 seconds
from .nodes import (
PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr,
PsAssignment, PsDeclaration, PsLoop
PsAstNode,
PsBlock,
PsExpression,
PsLvalueExpr,
PsSymbolExpr,
PsAssignment,
PsDeclaration,
PsLoop,
)
from .dispatcher import ast_visitor
from .transformations import ast_subs
__all__ = [
ast_visitor,
PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAssignment, PsDeclaration, PsLoop,
ast_subs
"ast_visitor",
"PsAstNode",
"PsBlock",
"PsExpression",
"PsLvalueExpr",
"PsSymbolExpr",
"PsAssignment",
"PsDeclaration",
"PsLoop",
"ast_subs",
]
from __future__ import annotations
from typing import Sequence, Generator
from typing import Sequence, Generator, TypeVar, Iterable, cast
from abc import ABC
from abc import ABC, abstractmethod
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsLvalue
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
class PsAstNode(ABC):
"""Base class for all nodes in the pystencils AST."""
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
return obj
def __init__(self, *children: Sequence[PsAstNode]):
for c in children:
if not isinstance(c, PsAstNode):
raise TypeError(f"Child {c} was not a PsAstNode.")
self._children = list(children)
@property
class PsAstNode(ABC):
"""Base class for all nodes in the pystencils AST.
This base class provides a common interface to inspect and update the AST's branching structure.
The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by
each subclass.
Subclasses are also responsible for doing the necessary type checks if they place restrictions on
the types of their children.
"""
@abstractmethod
def num_children(self) -> int:
...
@abstractmethod
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children
...
def child(self, idx: int):
return self._children[idx]
@abstractmethod
def get_child(self, idx: int):
...
@children.setter
def children(self, cs: Sequence[PsAstNode]):
if len(cs) != len(self._children):
raise ValueError("The number of child nodes must remain the same!")
@abstractmethod
def set_child(self, idx: int, c: PsAstNode):
...
def set_children(self, cs: Iterable[PsAstNode]):
for i, c in enumerate(cs):
self.set_child(i, c)
class PsBlock(PsAstNode):
def __init__(self, cs: Sequence[PsAstNode]):
self._children = list(cs)
def __getitem__(self, idx: int):
def num_children(self) -> int:
return len(self._children)
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children
def get_child(self, idx: int):
return self._children[idx]
def __setitem__(self, idx: int, c: PsAstNode):
def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = c
class PsLeafNode(PsAstNode):
def num_children(self) -> int:
return 0
class PsBlock(PsAstNode):
@property
def children(self) -> Generator[PsAstNode, None, None]:
yield from self._children # need to override entire property to override the setter
yield from ()
@children.setter
def children(self, cs: Sequence[PsAstNode]):
self._children = cs
def get_child(self, idx: int):
raise IndexError("Child index out of bounds: Leaf nodes have no children.")
def set_child(self, idx: int, c: PsAstNode):
raise IndexError("Child index out of bounds: Leaf nodes have no children.")
class PsExpression(PsAstNode):
class PsExpression(PsLeafNode):
"""Wrapper around pymbolics expressions."""
def __init__(self, expr: pb.Expression):
super().__init__()
self._expr = expr
@property
......@@ -68,7 +95,7 @@ class PsLvalueExpr(PsExpression):
"""Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
def __init__(self, expr: PsLvalue):
if not isinstance(expr, PsLvalue):
if not isinstance(expr, (PsTypedVariable, PsArrayAccess)):
raise TypeError("Expression was not a valid lvalue")
super(PsLvalueExpr, self).__init__(expr)
......@@ -78,52 +105,85 @@ class PsSymbolExpr(PsLvalueExpr):
"""Wrapper around PsTypedSymbols"""
def __init__(self, symbol: PsTypedVariable):
if not isinstance(symbol, PsTypedVariable):
raise TypeError("Not a symbol!")
super(PsLvalueExpr, self).__init__(symbol)
super().__init__(symbol)
@property
def symbol(self) -> PsSymbolExpr:
return self.expression
def symbol(self) -> PsTypedVariable:
return cast(PsTypedVariable, self._expr)
@symbol.setter
def symbol(self, symbol: PsSymbolExpr):
self.expression = symbol
def symbol(self, symbol: PsTypedVariable):
self._expr = symbol
class PsAssignment(PsAstNode):
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
super(PsAssignment, self).__init__(lhs, rhs)
self._lhs = lhs
self._rhs = rhs
@property
def lhs(self) -> PsLvalueExpr:
return self._children[0]
return self._lhs
@lhs.setter
def lhs(self, lvalue: PsLvalueExpr):
self._children[0] = lvalue
self._lhs = lvalue
@property
def rhs(self) -> PsExpression:
return self._children[1]
return self._rhs
@rhs.setter
def rhs(self, expr: PsExpression):
self._children[1] = expr
self._rhs = expr
def num_children(self) -> int:
return 2
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._lhs, self._rhs)
def get_child(self, idx: int):
return (self._lhs, self._rhs)[idx]
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
if idx == 0:
self._lhs = failing_cast(PsLvalueExpr, c)
elif idx == 1:
self._rhs = failing_cast(PsExpression, c)
else:
assert False, "unreachable code"
class PsDeclaration(PsAssignment):
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super(PsDeclaration, self).__init__(lhs, rhs)
super().__init__(lhs, rhs)
@property
def lhs(self) -> PsSymbolExpr:
return self._children[0]
def lhs(self) -> PsLvalueExpr:
return self._lhs
@lhs.setter
def lhs(self, symbol_node: PsSymbolExpr):
self._children[0] = symbol_node
def lhs(self, lvalue: PsLvalueExpr):
self._lhs = failing_cast(PsSymbolExpr, lvalue)
@property
def declared_symbol(self) -> PsSymbolExpr:
return cast(PsSymbolExpr, self._lhs)
@declared_symbol.setter
def declared_symbol(self, lvalue: PsSymbolExpr):
self._lhs = lvalue
def set_child(self, idx: int, c: PsAstNode):
idx = [0, 1][idx] # trick to normalize index
if idx == 0:
self._lhs = failing_cast(PsSymbolExpr, c)
elif idx == 1:
self._rhs = failing_cast(PsExpression, c)
else:
assert False, "unreachable code"
class PsLoop(PsAstNode):
......@@ -133,40 +193,68 @@ class PsLoop(PsAstNode):
stop: PsExpression,
step: PsExpression,
body: PsBlock):
super(PsLoop, self).__init__(ctr, start, stop, step, body)
self._ctr = ctr
self._start = start
self._stop = stop
self._step = step
self._body = body
@property
def counter(self) -> PsSymbolExpr:
return self._children[0]
return self._ctr
@counter.setter
def counter(self, expr: PsSymbolExpr):
self._ctr = expr
@property
def start(self) -> PsExpression:
return self._children[1]
return self._start
@start.setter
def start(self, expr: PsExpression):
self._children[1] = expr
self._start = expr
@property
def stop(self) -> PsExpression:
return self._children[2]
return self._stop
@stop.setter
def stop(self, expr: PsExpression):
self._children[2] = expr
self._stop = expr
@property
def step(self) -> PsExpression:
return self._children[3]
return self._step
@step.setter
def step(self, expr: PsExpression):
self._children[3] = expr
self._step = expr
@property
def body(self) -> PsBlock:
return self._children[4]
return self._body
@body.setter
def body(self, block: PsBlock):
self._children[4] = block
self._body = block
def num_children(self) -> int:
return 5
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._ctr, self._start, self._stop, self._step, self._body)
def get_child(self, idx: int):
return (self._ctr, self._start, self._stop, self._step, self._body)[idx]
def set_child(self, idx: int, c: PsAstNode):
idx = list(range(5))[idx]
match idx:
case 0: self._ctr = failing_cast(PsSymbolExpr, c)
case 1: self._start = failing_cast(PsExpression, c)
case 2: self._stop = failing_cast(PsExpression, c)
case 3: self._step = failing_cast(PsExpression, c)
case 4: self._body = failing_cast(PsBlock, c)
case _: assert False, "unreachable code"
......@@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
class PsAstTransformer(ABC):
def transform_children(self, node: PsAstNode, *args, **kwargs):
node.children = [self.visit(c, *args, **kwargs) for c in node.children]
node.set_children(self.visit(c, *args, **kwargs) for c in node.children())
@ast_visitor
def visit(self, node, *args, **kwargs):
......
......@@ -26,11 +26,11 @@ class CPrinter:
@visit.case(PsBlock)
def block(self, block: PsBlock):
if not block.children:
if not block.children():
return self.indent("{ }")
self._current_indent_level += self._indent_width
interior = "".join(self.visit(c) for c in block.children)
interior = "".join(self.visit(c) for c in block.children())
self._current_indent_level -= self._indent_width
return self.indent("{\n") + interior + self.indent("}\n")
......@@ -40,7 +40,7 @@ class CPrinter:
@visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration):
lhs_symb = decl.lhs.symbol
lhs_symb = decl.declared_symbol.symbol
lhs_dtype = lhs_symb.dtype
rhs_code = self.visit(decl.rhs)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment