From ed9315aff9ce95b69653f5049bbc33fcc31b38b4 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 19 Jan 2024 21:37:39 +0100
Subject: [PATCH] simplify children members of AST. introduce conditional
 branch.

---
 src/pystencils/nbackend/ast/__init__.py       |   2 +
 src/pystencils/nbackend/ast/collectors.py     |   2 +-
 src/pystencils/nbackend/ast/kernelfunction.py |  13 +-
 src/pystencils/nbackend/ast/nodes.py          | 122 +++++++++++-------
 .../nbackend/ast/transformations.py           |   2 +-
 src/pystencils/nbackend/ast/util.py           |   6 +-
 src/pystencils/nbackend/emission.py           |  45 +++++--
 .../nbackend/jit/cpu_extension_module.py      |   1 -
 8 files changed, 118 insertions(+), 75 deletions(-)

diff --git a/src/pystencils/nbackend/ast/__init__.py b/src/pystencils/nbackend/ast/__init__.py
index daee7214f..c7731f37a 100644
--- a/src/pystencils/nbackend/ast/__init__.py
+++ b/src/pystencils/nbackend/ast/__init__.py
@@ -7,6 +7,7 @@ from .nodes import (
     PsAssignment,
     PsDeclaration,
     PsLoop,
+    PsConditional,
 )
 from .kernelfunction import PsKernelFunction
 
@@ -24,5 +25,6 @@ __all__ = [
     "PsAssignment",
     "PsDeclaration",
     "PsLoop",
+    "PsConditional",
     "ast_subs"
 ]
diff --git a/src/pystencils/nbackend/ast/collectors.py b/src/pystencils/nbackend/ast/collectors.py
index 65bc14d4f..3c995ccf6 100644
--- a/src/pystencils/nbackend/ast/collectors.py
+++ b/src/pystencils/nbackend/ast/collectors.py
@@ -102,7 +102,7 @@ class RequiredHeadersCollector(Collector):
             case PsExpression(expr):
                 return self.rec(expr)
             case node:
-                return reduce(set.union, (self(c) for c in node.children()), set())
+                return reduce(set.union, (self(c) for c in node.children), set())
 
     def map_typed_variable(self, var: PsTypedVariable) -> set[str]:
         return var.dtype.required_headers
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index 2729d25a6..58f2ddf36 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-from typing import Generator
 from dataclasses import dataclass
 
 from pymbolic.mapper.dependency import DependencyMapper
@@ -104,16 +103,8 @@ class PsKernelFunction(PsAstNode):
         """For backward compatibility"""
         return None
 
-    def num_children(self) -> int:
-        return 1
-
-    def children(self) -> Generator[PsAstNode, None, None]:
-        yield from (self._body,)
-
-    def get_child(self, idx: int):
-        if idx not in (0, -1):
-            raise IndexError(f"Child index out of bounds: {idx}")
-        return self._body
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._body,)
 
     def set_child(self, idx: int, c: PsAstNode):
         if idx not in (0, -1):
diff --git a/src/pystencils/nbackend/ast/nodes.py b/src/pystencils/nbackend/ast/nodes.py
index 2944e073d..5a20e5835 100644
--- a/src/pystencils/nbackend/ast/nodes.py
+++ b/src/pystencils/nbackend/ast/nodes.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Sequence, Generator, Iterable, cast, TypeAlias
+from typing import Sequence, Iterable, cast, TypeAlias
+from types import NoneType
 
 from abc import ABC, abstractmethod
 
@@ -12,32 +13,28 @@ class PsAstNode(ABC):
     """Base class for all nodes in the pystencils AST.
 
     This base class provides a common interface to inspect and update the AST's branching structure.
-    The four methods `num_children`, `children`, `get_child` and `set_child` must be implemented by
-    each subclass.
+    The two methods `get_children` and `set_child` must be implemented by each subclass.
     Subclasses are also responsible for doing the necessary type checks if they place restrictions on
     the types of their children.
     """
 
-    @abstractmethod
-    def num_children(self) -> int:
-        ...
+    @property
+    def children(self) -> tuple[PsAstNode, ...]:
+        return self.get_children()
 
-    @abstractmethod
-    def children(self) -> Generator[PsAstNode, None, None]:
-        ...
+    @children.setter
+    def children(self, cs: Iterable[PsAstNode]):
+        for i, c in enumerate(cs):
+            self.set_child(i, c)
 
     @abstractmethod
-    def get_child(self, idx: int):
+    def get_children(self) -> tuple[PsAstNode, ...]:
         ...
 
     @abstractmethod
     def set_child(self, idx: int, c: PsAstNode):
         ...
 
-    def set_children(self, cs: Iterable[PsAstNode]):
-        for i, c in enumerate(cs):
-            self.set_child(i, c)
-
 
 class PsBlock(PsAstNode):
     __match_args__ = ("statements",)
@@ -45,14 +42,8 @@ class PsBlock(PsAstNode):
     def __init__(self, cs: Sequence[PsAstNode]):
         self._statements = list(cs)
 
-    def num_children(self) -> int:
-        return len(self._statements)
-
-    def children(self) -> Generator[PsAstNode, None, None]:
-        yield from self._statements
-
-    def get_child(self, idx: int):
-        return self._statements[idx]
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return tuple(self._statements)
 
     def set_child(self, idx: int, c: PsAstNode):
         self._statements[idx] = c
@@ -67,14 +58,8 @@ class PsBlock(PsAstNode):
 
 
 class PsLeafNode(PsAstNode):
-    def num_children(self) -> int:
-        return 0
-
-    def children(self) -> Generator[PsAstNode, None, None]:
-        yield from ()
-
-    def get_child(self, idx: int):
-        raise IndexError("Child index out of bounds: Leaf nodes have no children.")
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return ()
 
     def set_child(self, idx: int, c: PsAstNode):
         raise IndexError("Child index out of bounds: Leaf nodes have no children.")
@@ -154,14 +139,8 @@ class PsAssignment(PsAstNode):
     def rhs(self, expr: PsExpression):
         self._rhs = expr
 
-    def num_children(self) -> int:
-        return 2
-
-    def children(self) -> Generator[PsAstNode, None, None]:
-        yield from (self._lhs, self._rhs)
-
-    def get_child(self, idx: int):
-        return (self._lhs, self._rhs)[idx]
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._lhs, self._rhs)
 
     def set_child(self, idx: int, c: PsAstNode):
         idx = [0, 1][idx]  # trick to normalize index
@@ -265,14 +244,8 @@ class PsLoop(PsAstNode):
     def body(self, block: PsBlock):
         self._body = block
 
-    def num_children(self) -> int:
-        return 5
-
-    def children(self) -> Generator[PsAstNode, None, None]:
-        yield from (self._ctr, self._start, self._stop, self._step, self._body)
-
-    def get_child(self, idx: int):
-        return (self._ctr, self._start, self._stop, self._step, self._body)[idx]
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._ctr, self._start, self._stop, self._step, self._body)
 
     def set_child(self, idx: int, c: PsAstNode):
         idx = list(range(5))[idx]
@@ -289,3 +262,60 @@ class PsLoop(PsAstNode):
                 self._body = failing_cast(PsBlock, c)
             case _:
                 assert False, "unreachable code"
+
+
+class PsConditional(PsAstNode):
+    """Conditional branch"""
+
+    __match_args__ = ("condition", "branch_true", "branch_false")
+
+    def __init__(
+        self,
+        cond: PsExpression,
+        branch_true: PsBlock,
+        branch_false: PsBlock | None = None,
+    ):
+        self._condition = cond
+        self._branch_true = branch_true
+        self._branch_false = branch_false
+
+    @property
+    def condition(self) -> PsExpression:
+        return self._condition
+
+    @condition.setter
+    def condition(self, expr: PsExpression):
+        self._condition = expr
+
+    @property
+    def branch_true(self) -> PsBlock:
+        return self._branch_true
+
+    @branch_true.setter
+    def branch_true(self, block: PsBlock):
+        self._branch_true = block
+
+    @property
+    def branch_false(self) -> PsBlock | None:
+        return self._branch_false
+
+    @branch_false.setter
+    def branch_false(self, block: PsBlock | None):
+        self._branch_false = block
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._condition, self._branch_true) + (
+            (self._branch_false,) if self._branch_false is not None else ()
+        )
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = list(range(3))[idx]
+        match idx:
+            case 0:
+                self._condition = failing_cast(PsExpression, c)
+            case 1:
+                self._branch_true = failing_cast(PsBlock, c)
+            case 2:
+                self._branch_false = failing_cast((PsBlock, NoneType), c)
+            case _:
+                assert False, "unreachable code"
diff --git a/src/pystencils/nbackend/ast/transformations.py b/src/pystencils/nbackend/ast/transformations.py
index cef242b89..4260e18dc 100644
--- a/src/pystencils/nbackend/ast/transformations.py
+++ b/src/pystencils/nbackend/ast/transformations.py
@@ -12,7 +12,7 @@ from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
 
 class PsAstTransformer(ABC):
     def transform_children(self, node: PsAstNode, *args, **kwargs):
-        node.set_children(self.visit(c, *args, **kwargs) for c in node.children())
+        node.children = tuple(self.visit(c, *args, **kwargs) for c in node.children)
 
     @ast_visitor
     def visit(self, node, *args, **kwargs):
diff --git a/src/pystencils/nbackend/ast/util.py b/src/pystencils/nbackend/ast/util.py
index aa1866baf..c3d93ed4c 100644
--- a/src/pystencils/nbackend/ast/util.py
+++ b/src/pystencils/nbackend/ast/util.py
@@ -1,9 +1,7 @@
-from typing import TypeVar
+from typing import Any
 
-T = TypeVar("T")
 
-
-def failing_cast(target: type, obj: T):
+def failing_cast(target: type | tuple[type, ...], obj: Any) -> Any:
     if not isinstance(obj, target):
         raise TypeError(f"Casting {obj} to {target} failed.")
     return obj
diff --git a/src/pystencils/nbackend/emission.py b/src/pystencils/nbackend/emission.py
index b315b172e..f25070192 100644
--- a/src/pystencils/nbackend/emission.py
+++ b/src/pystencils/nbackend/emission.py
@@ -2,14 +2,23 @@ from __future__ import annotations
 
 from pymbolic.mapper.c_code import CCodeMapper
 
-from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, PsAssignment, PsLoop
+from .ast import (
+    ast_visitor,
+    PsAstNode,
+    PsBlock,
+    PsExpression,
+    PsDeclaration,
+    PsAssignment,
+    PsLoop,
+    PsConditional,
+)
 from .ast.kernelfunction import PsKernelFunction
 
 
 def emit_code(kernel: PsKernelFunction):
     #   TODO: Specialize for different targets
     printer = CPrinter()
-    return printer.print(kernel)    
+    return printer.print(kernel)
 
 
 class CPrinter:
@@ -17,7 +26,6 @@ class CPrinter:
         self._indent_width = indent_width
 
         self._current_indent_level = 0
-        self._inside_expression = False  # controls parentheses in nested arithmetic expressions
 
         self._pb_cmapper = CCodeMapper()
 
@@ -30,7 +38,7 @@ class CPrinter:
     @ast_visitor
     def visit(self, _: PsAstNode) -> str:
         raise ValueError("Cannot print this node.")
-    
+
     @visit.case(PsKernelFunction)
     def function(self, func: PsKernelFunction) -> str:
         params_spec = func.get_parameters()
@@ -41,11 +49,11 @@ class CPrinter:
 
     @visit.case(PsBlock)
     def block(self, block: PsBlock):
-        if not block.children():
+        if not block.children:
             return self.indent("{ }")
 
         self._current_indent_level += self._indent_width
-        interior = "\n".join(self.visit(c) for c in block.children())
+        interior = "\n".join(self.visit(c) for c in block.children)
         self._current_indent_level -= self._indent_width
         return self.indent("{\n") + interior + self.indent("}\n")
 
@@ -77,8 +85,23 @@ class CPrinter:
 
         body_code = self.visit(loop.body)
 
-        code = f"for({ctr_symbol.dtype} {ctr} = {start_code};" + \
-               f" {ctr} < {stop_code};" + \
-               f" {ctr} += {step_code})\n" + \
-               body_code
-        return code
+        code = (
+            f"for({ctr_symbol.dtype} {ctr} = {start_code};"
+            + f" {ctr} < {stop_code};"
+            + f" {ctr} += {step_code})\n"
+            + body_code
+        )
+        return self.indent(code)
+
+    @visit.case(PsConditional)
+    def conditional(self, node: PsConditional):
+        cond_code = self.visit(node.condition)
+        then_code = self.visit(node.branch_true)
+
+        code = f"if({cond_code})\n{then_code}"
+
+        if node.branch_false is not None:
+            else_code = self.visit(node.branch_false)
+            code += f"\nelse\n{else_code}"
+
+        return self.indent(code)
diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py
index 1859ec661..f2c0d3ff9 100644
--- a/src/pystencils/nbackend/jit/cpu_extension_module.py
+++ b/src/pystencils/nbackend/jit/cpu_extension_module.py
@@ -22,7 +22,6 @@ from ..arrays import (
 )
 from ..types import (
     PsAbstractType,
-    PsScalarType,
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
-- 
GitLab