From 6d048af1ba83a11f2714d96910e5a1b6a544fb54 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 28 Feb 2024 09:42:42 +0100
Subject: [PATCH] various minor fixes and refactorings

---
 src/pystencils/backend/arrays.py               |  2 ++
 src/pystencils/backend/ast/analysis.py         |  4 +++-
 src/pystencils/backend/emission.py             |  2 +-
 .../backend/kernelcreation/context.py          | 18 ++++++++++++++++++
 .../backend/platforms/generic_cpu.py           |  3 ---
 src/pystencils/backend/platforms/platform.py   |  4 ----
 src/pystencils/backend/platforms/x86.py        | 16 +++++++++++++---
 .../transformations/erase_anonymous_structs.py |  7 +++++++
 src/pystencils/kernelcreation.py               |  7 ++++---
 9 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py
index 586da3799..be159bcae 100644
--- a/src/pystencils/backend/arrays.py
+++ b/src/pystencils/backend/arrays.py
@@ -156,6 +156,7 @@ class PsArrayAssocSymbol(PsSymbol, ABC):
     Instances of this class represent pointers and indexing information bound
     to a particular array.
     """
+
     __match_args__ = ("name", "dtype", "array")
 
     def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
@@ -214,6 +215,7 @@ class PsArrayStrideSymbol(PsArrayAssocSymbol):
     Do not instantiate this class yourself, but only use its instances
     as provided by `PsLinearizedArray.strides`.
     """
+
     __match_args__ = ("array", "coordinate", "dtype")
 
     def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py
index 718a9397e..4bd174485 100644
--- a/src/pystencils/backend/ast/analysis.py
+++ b/src/pystencils/backend/ast/analysis.py
@@ -70,7 +70,9 @@ class UndefinedSymbolsCollector:
                 return {symb}
             case _:
                 return reduce(
-                    set.union, (self.visit_expr(cast(PsExpression, c)) for c in expr.children), set()
+                    set.union,
+                    (self.visit_expr(cast(PsExpression, c)) for c in expr.children),
+                    set(),
                 )
 
     def declared_variables(self, node: PsAstNode) -> set[PsSymbol]:
diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index 8de626e1f..829ffb53a 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -85,7 +85,7 @@ class Ops(Enum):
 class PrinterCtx:
     def __init__(self) -> None:
         self.operator_stack = [Ops.Weakest]
-        self.branch_stack: list[LR] = []
+        self.branch_stack = [LR.Middle]
         self.indent_level = 0
 
     def push_op(self, operator: Ops, branch: LR):
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 3bde2a135..ba6574090 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from typing import Iterable, Iterator
 from itertools import chain
 from types import EllipsisType
 
@@ -24,6 +25,14 @@ class FieldsInKernel:
         self.custom_fields: set[Field] = set()
         self.buffer_fields: set[Field] = set()
 
+    def __iter__(self) -> Iterator:
+        return chain(
+            self.domain_fields,
+            self.index_fields,
+            self.custom_fields,
+            self.buffer_fields,
+        )
+
 
 class KernelCreationContext:
     """Manages the translation process from the SymPy frontend to the backend AST, and collects
@@ -80,6 +89,7 @@ class KernelCreationContext:
         return tuple(self._constraints)
 
     #   Symbols
+
     def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol:
         if name not in self._symbols:
             symb = PsSymbol(name, None)
@@ -109,6 +119,10 @@ class KernelCreationContext:
 
         self._symbols[old.name] = new
 
+    @property
+    def symbols(self) -> Iterable[PsSymbol]:
+        return self._symbols.values()
+
     #   Fields and Arrays
 
     @property
@@ -214,6 +228,10 @@ class KernelCreationContext:
             if isinstance(symb, PsSymbol):
                 self.add_symbol(symb)
 
+    @property
+    def arrays(self) -> Iterable[PsLinearizedArray]:
+        return self._field_arrays.values()
+
     def get_array(self, field: Field) -> PsLinearizedArray:
         """Retrieve the underlying array for a given field.
 
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index b3a49cb65..8f8c0fab8 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -38,9 +38,6 @@ class GenericCpu(Platform):
         else:
             assert False, "unreachable code"
 
-    def optimize(self, kernel: PsBlock) -> PsBlock:
-        return kernel
-
     #   Internals
 
     def _create_domain_loops(
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index 7c6d3a2ee..3fedf7c01 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -28,7 +28,3 @@ class Platform(ABC):
         self, block: PsBlock, ispace: IterationSpace
     ) -> PsBlock:
         pass
-
-    @abstractmethod
-    def optimize(self, kernel: PsBlock) -> PsBlock:
-        pass
diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py
index f0e42bccb..7fa92c16d 100644
--- a/src/pystencils/backend/platforms/x86.py
+++ b/src/pystencils/backend/platforms/x86.py
@@ -3,7 +3,12 @@ from enum import Enum
 from functools import cache
 from typing import Sequence
 
-from ..ast.expressions import PsExpression, PsVectorArrayAccess, PsAddressOf, PsSubscript
+from ..ast.expressions import (
+    PsExpression,
+    PsVectorArrayAccess,
+    PsAddressOf,
+    PsSubscript,
+)
 from ..transformations.vector_intrinsics import IntrinsicOps
 from ..types import PsCustomType, PsVectorType
 from ..constants import PsConstant
@@ -135,14 +140,19 @@ class X86VectorCpu(GenericVectorCpu):
     def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression:
         if acc.stride == 1:
             load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
-            return load_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)))
+            return load_func(
+                PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index))
+            )
         else:
             raise NotImplementedError("Gather loads not implemented yet.")
 
     def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression:
         if acc.stride == 1:
             store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
-            return store_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), arg)
+            return store_func(
+                PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)),
+                arg,
+            )
         else:
             raise NotImplementedError("Scatter stores not implemented yet.")
 
diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py
index 8b039a1dc..ebaeecdd7 100644
--- a/src/pystencils/backend/transformations/erase_anonymous_structs.py
+++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py
@@ -32,6 +32,13 @@ class EraseAnonymousStructTypes:
     def __call__(self, node: PsAstNode) -> PsAstNode:
         self._substitutions = dict()
 
+        #   Check if AST traversal is even necessary
+        if not any(
+            (isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous)
+            for arr in self._ctx.arrays
+        ):
+            return node
+
         node = self.visit(node)
 
         for old, new in self._substitutions.items():
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 293578982..770bcf8d9 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -1,6 +1,9 @@
+from typing import cast
+
 from .enums import Target
 from .config import CreateKernelConfig
 from .backend.ast import PsKernelFunction
+from .backend.ast.structural import PsBlock
 from .backend.kernelcreation import (
     KernelCreationContext,
     KernelAnalysis,
@@ -15,7 +18,6 @@ from .backend.kernelcreation.iteration_space import (
 from .backend.ast.analysis import collect_required_headers
 from .backend.transformations import EraseAnonymousStructTypes
 
-from .enums import Target
 from .sympyextensions import AssignmentCollection, Assignment
 
 
@@ -66,13 +68,12 @@ def create_kernel(
             raise NotImplementedError("Target platform not implemented")
 
     kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
-    kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast)
+    kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
 
     #   7. Apply optimizations
     #     - Vectorization
     #     - OpenMP
     #     - Loop Splitting, Tiling, Blocking
-    kernel_ast = platform.optimize(kernel_ast)
 
     assert config.jit is not None
     req_headers = collect_required_headers(kernel_ast) | platform.required_headers
-- 
GitLab