From d2520fd3a955ab107c23e81b2a6c69c54296f21a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 10 Jan 2024 15:23:48 +0100
Subject: [PATCH] refactored branching structure, satisfied mypy.

---
 pystencils/nbackend/ast/__init__.py        |  23 ++-
 pystencils/nbackend/ast/nodes.py           | 204 +++++++++++++++------
 pystencils/nbackend/ast/transformations.py |   2 +-
 pystencils/nbackend/c_printer.py           |   6 +-
 4 files changed, 168 insertions(+), 67 deletions(-)

diff --git a/pystencils/nbackend/ast/__init__.py b/pystencils/nbackend/ast/__init__.py
index d8d4c33b8..567840b2d 100644
--- a/pystencils/nbackend/ast/__init__.py
+++ b/pystencils/nbackend/ast/__init__.py
@@ -1,13 +1,26 @@
 from .nodes import (
-    PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr,
-    PsAssignment, PsDeclaration, PsLoop
+    PsAstNode,
+    PsBlock,
+    PsExpression,
+    PsLvalueExpr,
+    PsSymbolExpr,
+    PsAssignment,
+    PsDeclaration,
+    PsLoop,
 )
 
 from .dispatcher import ast_visitor
 from .transformations import ast_subs
 
 __all__ = [
-    ast_visitor,
-    PsAstNode, PsBlock, PsExpression, PsLvalueExpr, PsSymbolExpr, PsAssignment, PsDeclaration, PsLoop,
-    ast_subs
+    "ast_visitor",
+    "PsAstNode",
+    "PsBlock",
+    "PsExpression",
+    "PsLvalueExpr",
+    "PsSymbolExpr",
+    "PsAssignment",
+    "PsDeclaration",
+    "PsLoop",
+    "ast_subs",
 ]
diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py
index 584455db4..7459de7aa 100644
--- a/pystencils/nbackend/ast/nodes.py
+++ b/pystencils/nbackend/ast/nodes.py
@@ -1,58 +1,85 @@
 from __future__ import annotations
-from typing import Sequence, Generator
+from typing import Sequence, Generator, TypeVar, Iterable, cast
 
-from abc import ABC
+from abc import ABC, abstractmethod
 
 import pymbolic.primitives as pb
 
-from ..typed_expressions import PsTypedVariable, PsLvalue
+from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
 
 
-class PsAstNode(ABC):
-    """Base class for all nodes in the pystencils AST."""
+T = TypeVar("T")
+def failing_cast(target: type, obj: T):
+    if not isinstance(obj, target):
+        raise TypeError(f"Casting {obj} to {target} failed.")
+    return obj
 
-    def __init__(self, *children: Sequence[PsAstNode]):
-        for c in children:
-            if not isinstance(c, PsAstNode):
-                raise TypeError(f"Child {c} was not a PsAstNode.")
-        self._children = list(children)
 
-    @property
+class PsAstNode(ABC):
+    """Base class for all nodes in the pystencils AST.
+    
+    This base class provides a common interface to inspect and update the AST's branching structure.
+    The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by
+    each subclass.
+    Subclasses are also responsible for doing the necessary type checks if they place restrictions on
+    the types of their children.
+    """
+
+    @abstractmethod
+    def num_children(self) -> int:
+        ...
+
+    @abstractmethod
     def children(self) -> Generator[PsAstNode, None, None]:
-        yield from self._children
+        ...
 
-    def child(self, idx: int):
-        return self._children[idx]
+    @abstractmethod
+    def get_child(self, idx: int):
+        ...
 
-    @children.setter
-    def children(self, cs: Sequence[PsAstNode]):
-        if len(cs) != len(self._children):
-            raise ValueError("The number of child nodes must remain the same!")
+    @abstractmethod
+    def set_child(self, idx: int, c: PsAstNode):
+        ...
+
+    def set_children(self, cs: Iterable[PsAstNode]):
+        for i, c in enumerate(cs):
+            self.set_child(i, c)
+
+
+class PsBlock(PsAstNode):
+    def __init__(self, cs: Sequence[PsAstNode]):
         self._children = list(cs)
 
-    def __getitem__(self, idx: int):
+    def num_children(self) -> int:
+        return len(self._children)
+
+    def children(self) -> Generator[PsAstNode, None, None]:
+        yield from self._children
+
+    def get_child(self, idx: int):
         return self._children[idx]
 
-    def __setitem__(self, idx: int, c: PsAstNode):
+    def set_child(self, idx: int, c: PsAstNode):
         self._children[idx] = c
 
+class PsLeafNode(PsAstNode):
+    def num_children(self) -> int:
+        return 0
 
-class PsBlock(PsAstNode):
-
-    @property
     def children(self) -> Generator[PsAstNode, None, None]:
-        yield from self._children  # need to override entire property to override the setter
+        yield from ()
 
-    @children.setter
-    def children(self, cs: Sequence[PsAstNode]):
-        self._children = cs
+    def get_child(self, idx: int):
+        raise IndexError("Child index out of bounds: Leaf nodes have no children.")
 
+    def set_child(self, idx: int, c: PsAstNode):
+        raise IndexError("Child index out of bounds: Leaf nodes have no children.")
 
-class PsExpression(PsAstNode):
+
+class PsExpression(PsLeafNode):
     """Wrapper around pymbolics expressions."""
 
     def __init__(self, expr: pb.Expression):
-        super().__init__()
         self._expr = expr
 
     @property
@@ -68,7 +95,7 @@ class PsLvalueExpr(PsExpression):
     """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
 
     def __init__(self, expr: PsLvalue):
-        if not isinstance(expr, PsLvalue):
+        if not isinstance(expr, (PsTypedVariable, PsArrayAccess)):
             raise TypeError("Expression was not a valid lvalue")
 
         super(PsLvalueExpr, self).__init__(expr)
@@ -78,52 +105,85 @@ class PsSymbolExpr(PsLvalueExpr):
     """Wrapper around PsTypedSymbols"""
 
     def __init__(self, symbol: PsTypedVariable):
-        if not isinstance(symbol, PsTypedVariable):
-            raise TypeError("Not a symbol!")
-
-        super(PsLvalueExpr, self).__init__(symbol)
+        super().__init__(symbol)
 
     @property
-    def symbol(self) -> PsSymbolExpr:
-        return self.expression
+    def symbol(self) -> PsTypedVariable:
+        return cast(PsTypedVariable, self._expr)
 
     @symbol.setter
-    def symbol(self, symbol: PsSymbolExpr):
-        self.expression = symbol
+    def symbol(self, symbol: PsTypedVariable):
+        self._expr = symbol
 
 
 class PsAssignment(PsAstNode):
     def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
-        super(PsAssignment, self).__init__(lhs, rhs)
+        self._lhs = lhs
+        self._rhs = rhs
 
     @property
     def lhs(self) -> PsLvalueExpr:
-        return self._children[0]
+        return self._lhs
 
     @lhs.setter
     def lhs(self, lvalue: PsLvalueExpr):
-        self._children[0] = lvalue
+        self._lhs = lvalue
 
     @property
     def rhs(self) -> PsExpression:
-        return self._children[1]
+        return self._rhs
 
     @rhs.setter
     def rhs(self, expr: PsExpression):
-        self._children[1] = expr
+        self._rhs = expr
+
+    def num_children(self) -> int:
+        return 2
+
+    def children(self) -> Generator[PsAstNode, None, None]:
+        yield from (self._lhs, self._rhs)
+
+    def get_child(self, idx: int):
+        return (self._lhs, self._rhs)[idx]
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = [0, 1][idx]   # trick to normalize index
+        if idx == 0:
+            self._lhs = failing_cast(PsLvalueExpr, c)
+        elif idx == 1:
+            self._rhs = failing_cast(PsExpression, c)
+        else:
+            assert False, "unreachable code"
 
 
 class PsDeclaration(PsAssignment):
     def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
-        super(PsDeclaration, self).__init__(lhs, rhs)
+        super().__init__(lhs, rhs)
 
     @property
-    def lhs(self) -> PsSymbolExpr:
-        return self._children[0]
+    def lhs(self) -> PsLvalueExpr:
+        return self._lhs
 
     @lhs.setter
-    def lhs(self, symbol_node: PsSymbolExpr):
-        self._children[0] = symbol_node
+    def lhs(self, lvalue: PsLvalueExpr):
+        self._lhs = failing_cast(PsSymbolExpr, lvalue)
+
+    @property
+    def declared_symbol(self) -> PsSymbolExpr:
+        return cast(PsSymbolExpr, self._lhs)
+
+    @declared_symbol.setter
+    def declared_symbol(self, lvalue: PsSymbolExpr):
+        self._lhs = lvalue
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = [0, 1][idx]   # trick to normalize index
+        if idx == 0:
+            self._lhs = failing_cast(PsSymbolExpr, c)
+        elif idx == 1:
+            self._rhs = failing_cast(PsExpression, c)
+        else:
+            assert False, "unreachable code"
 
 
 class PsLoop(PsAstNode):
@@ -133,40 +193,68 @@ class PsLoop(PsAstNode):
                  stop: PsExpression,
                  step: PsExpression,
                  body: PsBlock):
-        super(PsLoop, self).__init__(ctr, start, stop, step, body)
+        self._ctr = ctr
+        self._start = start
+        self._stop = stop
+        self._step = step
+        self._body = body
 
     @property
     def counter(self) -> PsSymbolExpr:
-        return self._children[0]
+        return self._ctr
+    
+    @counter.setter
+    def counter(self, expr: PsSymbolExpr):
+        self._ctr = expr
 
     @property
     def start(self) -> PsExpression:
-        return self._children[1]
+        return self._start
 
     @start.setter
     def start(self, expr: PsExpression):
-        self._children[1] = expr
+        self._start = expr
 
     @property
     def stop(self) -> PsExpression:
-        return self._children[2]
+        return self._stop
 
     @stop.setter
     def stop(self, expr: PsExpression):
-        self._children[2] = expr
+        self._stop = expr
 
     @property
     def step(self) -> PsExpression:
-        return self._children[3]
+        return self._step
 
     @step.setter
     def step(self, expr: PsExpression):
-        self._children[3] = expr
+        self._step = expr
 
     @property
     def body(self) -> PsBlock:
-        return self._children[4]
+        return self._body
 
     @body.setter
     def body(self, block: PsBlock):
-        self._children[4] = block
+        self._body = block
+
+    def num_children(self) -> int:
+        return 5
+
+    def children(self) -> Generator[PsAstNode, None, None]:
+        yield from (self._ctr, self._start, self._stop, self._step, self._body)
+
+    def get_child(self, idx: int):
+        return (self._ctr, self._start, self._stop, self._step, self._body)[idx]
+
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = list(range(5))[idx]
+        match idx:
+            case 0: self._ctr = failing_cast(PsSymbolExpr, c)
+            case 1: self._start = failing_cast(PsExpression, c)
+            case 2: self._stop = failing_cast(PsExpression, c)
+            case 3: self._step = failing_cast(PsExpression, c)
+            case 4: self._body = failing_cast(PsBlock, c)
+            case _: assert False, "unreachable code"
diff --git a/pystencils/nbackend/ast/transformations.py b/pystencils/nbackend/ast/transformations.py
index b75b13fab..cef242b89 100644
--- a/pystencils/nbackend/ast/transformations.py
+++ b/pystencils/nbackend/ast/transformations.py
@@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
 
 class PsAstTransformer(ABC):
     def transform_children(self, node: PsAstNode, *args, **kwargs):
-        node.children = [self.visit(c, *args, **kwargs) for c in node.children]
+        node.set_children(self.visit(c, *args, **kwargs) for c in node.children())
 
     @ast_visitor
     def visit(self, node, *args, **kwargs):
diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py
index d920fa9dc..a1f176f83 100644
--- a/pystencils/nbackend/c_printer.py
+++ b/pystencils/nbackend/c_printer.py
@@ -26,11 +26,11 @@ class CPrinter:
 
     @visit.case(PsBlock)
     def block(self, block: PsBlock):
-        if not block.children:
+        if not block.children():
             return self.indent("{ }")
 
         self._current_indent_level += self._indent_width
-        interior = "".join(self.visit(c) for c in block.children)
+        interior = "".join(self.visit(c) for c in block.children())
         self._current_indent_level -= self._indent_width
         return self.indent("{\n") + interior + self.indent("}\n")
 
@@ -40,7 +40,7 @@ class CPrinter:
 
     @visit.case(PsDeclaration)
     def declaration(self, decl: PsDeclaration):
-        lhs_symb = decl.lhs.symbol
+        lhs_symb = decl.declared_symbol.symbol
         lhs_dtype = lhs_symb.dtype
         rhs_code = self.visit(decl.rhs)
 
-- 
GitLab