From 6d048af1ba83a11f2714d96910e5a1b6a544fb54 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 28 Feb 2024 09:42:42 +0100 Subject: [PATCH] various minor fixes and refactorings --- src/pystencils/backend/arrays.py | 2 ++ src/pystencils/backend/ast/analysis.py | 4 +++- src/pystencils/backend/emission.py | 2 +- .../backend/kernelcreation/context.py | 18 ++++++++++++++++++ .../backend/platforms/generic_cpu.py | 3 --- src/pystencils/backend/platforms/platform.py | 4 ---- src/pystencils/backend/platforms/x86.py | 16 +++++++++++++--- .../transformations/erase_anonymous_structs.py | 7 +++++++ src/pystencils/kernelcreation.py | 7 ++++--- 9 files changed, 48 insertions(+), 15 deletions(-) diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index 586da3799..be159bcae 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -156,6 +156,7 @@ class PsArrayAssocSymbol(PsSymbol, ABC): Instances of this class represent pointers and indexing information bound to a particular array. """ + __match_args__ = ("name", "dtype", "array") def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): @@ -214,6 +215,7 @@ class PsArrayStrideSymbol(PsArrayAssocSymbol): Do not instantiate this class yourself, but only use its instances as provided by `PsLinearizedArray.strides`. """ + __match_args__ = ("array", "coordinate", "dtype") def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 718a9397e..4bd174485 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -70,7 +70,9 @@ class UndefinedSymbolsCollector: return {symb} case _: return reduce( - set.union, (self.visit_expr(cast(PsExpression, c)) for c in expr.children), set() + set.union, + (self.visit_expr(cast(PsExpression, c)) for c in expr.children), + set(), ) def declared_variables(self, node: PsAstNode) -> set[PsSymbol]: diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 8de626e1f..829ffb53a 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -85,7 +85,7 @@ class Ops(Enum): class PrinterCtx: def __init__(self) -> None: self.operator_stack = [Ops.Weakest] - self.branch_stack: list[LR] = [] + self.branch_stack = [LR.Middle] self.indent_level = 0 def push_op(self, operator: Ops, branch: LR): diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 3bde2a135..ba6574090 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Iterable, Iterator from itertools import chain from types import EllipsisType @@ -24,6 +25,14 @@ class FieldsInKernel: self.custom_fields: set[Field] = set() self.buffer_fields: set[Field] = set() + def __iter__(self) -> Iterator: + return chain( + self.domain_fields, + self.index_fields, + self.custom_fields, + self.buffer_fields, + ) + class KernelCreationContext: """Manages the translation process from the SymPy frontend to the backend AST, and collects @@ -80,6 +89,7 @@ class KernelCreationContext: return tuple(self._constraints) # Symbols + def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol: if name not in self._symbols: symb = PsSymbol(name, None) @@ -109,6 +119,10 @@ class KernelCreationContext: self._symbols[old.name] = new + @property + def symbols(self) -> Iterable[PsSymbol]: + return self._symbols.values() + # Fields and Arrays @property @@ -214,6 +228,10 @@ class KernelCreationContext: if isinstance(symb, PsSymbol): self.add_symbol(symb) + @property + def arrays(self) -> Iterable[PsLinearizedArray]: + return self._field_arrays.values() + def get_array(self, field: Field) -> PsLinearizedArray: """Retrieve the underlying array for a given field. diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index b3a49cb65..8f8c0fab8 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -38,9 +38,6 @@ class GenericCpu(Platform): else: assert False, "unreachable code" - def optimize(self, kernel: PsBlock) -> PsBlock: - return kernel - # Internals def _create_domain_loops( diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 7c6d3a2ee..3fedf7c01 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -28,7 +28,3 @@ class Platform(ABC): self, block: PsBlock, ispace: IterationSpace ) -> PsBlock: pass - - @abstractmethod - def optimize(self, kernel: PsBlock) -> PsBlock: - pass diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index f0e42bccb..7fa92c16d 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -3,7 +3,12 @@ from enum import Enum from functools import cache from typing import Sequence -from ..ast.expressions import PsExpression, PsVectorArrayAccess, PsAddressOf, PsSubscript +from ..ast.expressions import ( + PsExpression, + PsVectorArrayAccess, + PsAddressOf, + PsSubscript, +) from ..transformations.vector_intrinsics import IntrinsicOps from ..types import PsCustomType, PsVectorType from ..constants import PsConstant @@ -135,14 +140,19 @@ class X86VectorCpu(GenericVectorCpu): def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) - return load_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index))) + return load_func( + PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)) + ) else: raise NotImplementedError("Gather loads not implemented yet.") def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) - return store_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), arg) + return store_func( + PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), + arg, + ) else: raise NotImplementedError("Scatter stores not implemented yet.") diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py index 8b039a1dc..ebaeecdd7 100644 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py @@ -32,6 +32,13 @@ class EraseAnonymousStructTypes: def __call__(self, node: PsAstNode) -> PsAstNode: self._substitutions = dict() + # Check if AST traversal is even necessary + if not any( + (isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous) + for arr in self._ctx.arrays + ): + return node + node = self.visit(node) for old, new in self._substitutions.items(): diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 293578982..770bcf8d9 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -1,6 +1,9 @@ +from typing import cast + from .enums import Target from .config import CreateKernelConfig from .backend.ast import PsKernelFunction +from .backend.ast.structural import PsBlock from .backend.kernelcreation import ( KernelCreationContext, KernelAnalysis, @@ -15,7 +18,6 @@ from .backend.kernelcreation.iteration_space import ( from .backend.ast.analysis import collect_required_headers from .backend.transformations import EraseAnonymousStructTypes -from .enums import Target from .sympyextensions import AssignmentCollection, Assignment @@ -66,13 +68,12 @@ def create_kernel( raise NotImplementedError("Target platform not implemented") kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) - kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast) + kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast)) # 7. Apply optimizations # - Vectorization # - OpenMP # - Loop Splitting, Tiling, Blocking - kernel_ast = platform.optimize(kernel_ast) assert config.jit is not None req_headers = collect_required_headers(kernel_ast) | platform.required_headers -- GitLab