Skip to content
Snippets Groups Projects
Commit 6d048af1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

various minor fixes and refactorings

parent ee64e7e1
No related merge requests found
Pipeline #63251 failed with stages
in 47 seconds
...@@ -156,6 +156,7 @@ class PsArrayAssocSymbol(PsSymbol, ABC): ...@@ -156,6 +156,7 @@ class PsArrayAssocSymbol(PsSymbol, ABC):
Instances of this class represent pointers and indexing information bound Instances of this class represent pointers and indexing information bound
to a particular array. to a particular array.
""" """
__match_args__ = ("name", "dtype", "array") __match_args__ = ("name", "dtype", "array")
def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
...@@ -214,6 +215,7 @@ class PsArrayStrideSymbol(PsArrayAssocSymbol): ...@@ -214,6 +215,7 @@ class PsArrayStrideSymbol(PsArrayAssocSymbol):
Do not instantiate this class yourself, but only use its instances Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.strides`. as provided by `PsLinearizedArray.strides`.
""" """
__match_args__ = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
......
...@@ -70,7 +70,9 @@ class UndefinedSymbolsCollector: ...@@ -70,7 +70,9 @@ class UndefinedSymbolsCollector:
return {symb} return {symb}
case _: case _:
return reduce( return reduce(
set.union, (self.visit_expr(cast(PsExpression, c)) for c in expr.children), set() set.union,
(self.visit_expr(cast(PsExpression, c)) for c in expr.children),
set(),
) )
def declared_variables(self, node: PsAstNode) -> set[PsSymbol]: def declared_variables(self, node: PsAstNode) -> set[PsSymbol]:
......
...@@ -85,7 +85,7 @@ class Ops(Enum): ...@@ -85,7 +85,7 @@ class Ops(Enum):
class PrinterCtx: class PrinterCtx:
def __init__(self) -> None: def __init__(self) -> None:
self.operator_stack = [Ops.Weakest] self.operator_stack = [Ops.Weakest]
self.branch_stack: list[LR] = [] self.branch_stack = [LR.Middle]
self.indent_level = 0 self.indent_level = 0
def push_op(self, operator: Ops, branch: LR): def push_op(self, operator: Ops, branch: LR):
......
from __future__ import annotations from __future__ import annotations
from typing import Iterable, Iterator
from itertools import chain from itertools import chain
from types import EllipsisType from types import EllipsisType
...@@ -24,6 +25,14 @@ class FieldsInKernel: ...@@ -24,6 +25,14 @@ class FieldsInKernel:
self.custom_fields: set[Field] = set() self.custom_fields: set[Field] = set()
self.buffer_fields: set[Field] = set() self.buffer_fields: set[Field] = set()
def __iter__(self) -> Iterator:
return chain(
self.domain_fields,
self.index_fields,
self.custom_fields,
self.buffer_fields,
)
class KernelCreationContext: class KernelCreationContext:
"""Manages the translation process from the SymPy frontend to the backend AST, and collects """Manages the translation process from the SymPy frontend to the backend AST, and collects
...@@ -80,6 +89,7 @@ class KernelCreationContext: ...@@ -80,6 +89,7 @@ class KernelCreationContext:
return tuple(self._constraints) return tuple(self._constraints)
# Symbols # Symbols
def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol: def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol:
if name not in self._symbols: if name not in self._symbols:
symb = PsSymbol(name, None) symb = PsSymbol(name, None)
...@@ -109,6 +119,10 @@ class KernelCreationContext: ...@@ -109,6 +119,10 @@ class KernelCreationContext:
self._symbols[old.name] = new self._symbols[old.name] = new
@property
def symbols(self) -> Iterable[PsSymbol]:
return self._symbols.values()
# Fields and Arrays # Fields and Arrays
@property @property
...@@ -214,6 +228,10 @@ class KernelCreationContext: ...@@ -214,6 +228,10 @@ class KernelCreationContext:
if isinstance(symb, PsSymbol): if isinstance(symb, PsSymbol):
self.add_symbol(symb) self.add_symbol(symb)
@property
def arrays(self) -> Iterable[PsLinearizedArray]:
return self._field_arrays.values()
def get_array(self, field: Field) -> PsLinearizedArray: def get_array(self, field: Field) -> PsLinearizedArray:
"""Retrieve the underlying array for a given field. """Retrieve the underlying array for a given field.
......
...@@ -38,9 +38,6 @@ class GenericCpu(Platform): ...@@ -38,9 +38,6 @@ class GenericCpu(Platform):
else: else:
assert False, "unreachable code" assert False, "unreachable code"
def optimize(self, kernel: PsBlock) -> PsBlock:
return kernel
# Internals # Internals
def _create_domain_loops( def _create_domain_loops(
......
...@@ -28,7 +28,3 @@ class Platform(ABC): ...@@ -28,7 +28,3 @@ class Platform(ABC):
self, block: PsBlock, ispace: IterationSpace self, block: PsBlock, ispace: IterationSpace
) -> PsBlock: ) -> PsBlock:
pass pass
@abstractmethod
def optimize(self, kernel: PsBlock) -> PsBlock:
pass
...@@ -3,7 +3,12 @@ from enum import Enum ...@@ -3,7 +3,12 @@ from enum import Enum
from functools import cache from functools import cache
from typing import Sequence from typing import Sequence
from ..ast.expressions import PsExpression, PsVectorArrayAccess, PsAddressOf, PsSubscript from ..ast.expressions import (
PsExpression,
PsVectorArrayAccess,
PsAddressOf,
PsSubscript,
)
from ..transformations.vector_intrinsics import IntrinsicOps from ..transformations.vector_intrinsics import IntrinsicOps
from ..types import PsCustomType, PsVectorType from ..types import PsCustomType, PsVectorType
from ..constants import PsConstant from ..constants import PsConstant
...@@ -135,14 +140,19 @@ class X86VectorCpu(GenericVectorCpu): ...@@ -135,14 +140,19 @@ class X86VectorCpu(GenericVectorCpu):
def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression:
if acc.stride == 1: if acc.stride == 1:
load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
return load_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index))) return load_func(
PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index))
)
else: else:
raise NotImplementedError("Gather loads not implemented yet.") raise NotImplementedError("Gather loads not implemented yet.")
def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression:
if acc.stride == 1: if acc.stride == 1:
store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
return store_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), arg) return store_func(
PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)),
arg,
)
else: else:
raise NotImplementedError("Scatter stores not implemented yet.") raise NotImplementedError("Scatter stores not implemented yet.")
......
...@@ -32,6 +32,13 @@ class EraseAnonymousStructTypes: ...@@ -32,6 +32,13 @@ class EraseAnonymousStructTypes:
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
self._substitutions = dict() self._substitutions = dict()
# Check if AST traversal is even necessary
if not any(
(isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous)
for arr in self._ctx.arrays
):
return node
node = self.visit(node) node = self.visit(node)
for old, new in self._substitutions.items(): for old, new in self._substitutions.items():
......
from typing import cast
from .enums import Target from .enums import Target
from .config import CreateKernelConfig from .config import CreateKernelConfig
from .backend.ast import PsKernelFunction from .backend.ast import PsKernelFunction
from .backend.ast.structural import PsBlock
from .backend.kernelcreation import ( from .backend.kernelcreation import (
KernelCreationContext, KernelCreationContext,
KernelAnalysis, KernelAnalysis,
...@@ -15,7 +18,6 @@ from .backend.kernelcreation.iteration_space import ( ...@@ -15,7 +18,6 @@ from .backend.kernelcreation.iteration_space import (
from .backend.ast.analysis import collect_required_headers from .backend.ast.analysis import collect_required_headers
from .backend.transformations import EraseAnonymousStructTypes from .backend.transformations import EraseAnonymousStructTypes
from .enums import Target
from .sympyextensions import AssignmentCollection, Assignment from .sympyextensions import AssignmentCollection, Assignment
...@@ -66,13 +68,12 @@ def create_kernel( ...@@ -66,13 +68,12 @@ def create_kernel(
raise NotImplementedError("Target platform not implemented") raise NotImplementedError("Target platform not implemented")
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast) kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
# 7. Apply optimizations # 7. Apply optimizations
# - Vectorization # - Vectorization
# - OpenMP # - OpenMP
# - Loop Splitting, Tiling, Blocking # - Loop Splitting, Tiling, Blocking
kernel_ast = platform.optimize(kernel_ast)
assert config.jit is not None assert config.jit is not None
req_headers = collect_required_headers(kernel_ast) | platform.required_headers req_headers = collect_required_headers(kernel_ast) | platform.required_headers
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment