Skip to content
Snippets Groups Projects
Commit d976a971 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

updated freeze. Sketched kernel creation.

parent 0aee9f49
Branches
Tags
No related merge requests found
Pipeline #60370 failed with stages
in 11 minutes and 1 second
...@@ -35,14 +35,15 @@ class IterationSpace(ABC): ...@@ -35,14 +35,15 @@ class IterationSpace(ABC):
spatial indices. spatial indices.
""" """
def __init__(self, spatial_index_variables: tuple[PsTypedVariable, ...]): def __init__(self, spatial_indices: tuple[PsTypedVariable, ...]):
if len(spatial_index_variables) == 0: if len(spatial_indices) == 0:
raise ValueError("Iteration space must be at least one-dimensional.") raise ValueError("Iteration space must be at least one-dimensional.")
self._spatial_index_vars = spatial_index_variables self._spatial_indices = spatial_indices
def get_spatial_index(self, coordinate: int) -> PsTypedVariable: @property
return self._spatial_index_vars[coordinate] def spatial_indices(self) -> tuple[PsTypedVariable, ...]:
return self._spatial_indices
class FullIterationSpace(IterationSpace): class FullIterationSpace(IterationSpace):
...@@ -117,6 +118,7 @@ class KernelCreationContext: ...@@ -117,6 +118,7 @@ class KernelCreationContext:
def __init__(self, index_dtype: PsIntegerType): def __init__(self, index_dtype: PsIntegerType):
self._index_dtype = index_dtype self._index_dtype = index_dtype
self._arrays: dict[Field, PsFieldArrayPair] = dict()
self._constraints: list[PsKernelConstraint] = [] self._constraints: list[PsKernelConstraint] = []
@property @property
...@@ -156,4 +158,9 @@ class KernelCreationContext: ...@@ -156,4 +158,9 @@ class KernelCreationContext:
field=field, array=arr, base_ptr=PsArrayBasePointer("arr_data", arr) field=field, array=arr, base_ptr=PsArrayBasePointer("arr_data", arr)
) )
self._arrays[field] = fa_pair
return fa_pair return fa_pair
def get_array_descriptor(self, field: Field):
    """Return the entry stored in ``self._arrays`` for ``field``.

    This is a plain dictionary lookup, so a ``KeyError`` is raised if no
    entry exists for ``field``.
    """
    return self._arrays[field]
from types import EllipsisType
from ...simp import AssignmentCollection
from ...field import Field
from ...kernel_contrains_check import KernelConstraintsCheck
from ..types.quick import SInt
from ..ast import PsBlock
from .context import KernelCreationContext, FullIterationSpace
from .freeze import FreezeExpressions
# flake8: noqa
def create_domain_kernel(assignments: AssignmentCollection):
    """Sketch of the translation pipeline for a domain kernel.

    This is work in progress: the iteration-space inputs are still
    ``NotImplemented`` placeholders, the steps after freezing exist only as
    comments, and nothing is returned yet.

    Args:
        assignments: The assignment collection to turn into a kernel.
    """
    #   TODO: Assemble configuration

    #   1. Prepare context
    ctx = KernelCreationContext(SInt(64))  # TODO: how to determine index type?

    #   2. Check kernel constraints and collect all fields
    constraints_check = KernelConstraintsCheck()  # TODO: config
    constraints_check.visit(assignments)

    #   Every field that is either written or read takes part in the kernel
    all_fields: set[Field] = (
        constraints_check.fields_written | constraints_check.fields_read
    )

    #   3. Register fields
    for field in all_fields:
        ctx.add_field(field)

    #   All steps up to this point are the same in domain and indexed kernels;
    #   the difference now comes with the iteration space.
    #
    #   Domain kernels create a full iteration space from their iteration slice
    #   which is either explicitly given or computed from ghost layer requirements.
    #   Indexed kernels, on the other hand, have to create a sparse iteration space
    #   from one index list.

    #   4. Create iteration space
    #   determine required ghost layers
    ghost_layers: int = NotImplemented
    #   unify field shapes, add parameter constraints
    common_shape: tuple[int | EllipsisType, ...] = NotImplemented
    #   don't forget custom iteration slice
    #   create from ghost layers and with given shape
    ispace: FullIterationSpace = NotImplemented

    #   5. Freeze assignments
    #   This call is the same for both domain and indexed kernels
    freezer = FreezeExpressions(ctx, ispace)
    kernel_body: PsBlock = freezer(assignments)

    #   6. Typify
    #   Also the same for both types of kernels
    #   determine_types(kernel_body)

    #   Up to this point, all was target-agnostic, but now the target becomes relevant.
    #   Here we might hand off the compilation to a target-specific part of the compiler
    #   (CPU/CUDA/...), since these will likely also apply very different optimizations.

    #   7. Add loops or device indexing
    #   This step translates the iteration space to actual index calculation code and is once again
    #   different in indexed and domain kernels.

    #   8. Apply optimizations
    #   - Vectorization
    #   - OpenMP
    #   - Loop Splitting, Tiling, Blocking

    #   9. Create and return kernel function.
import pymbolic.primitives as pb
from pymbolic.interop.sympy import SympyToPymbolicMapper
from ...field import Field, FieldType
from .context import KernelCreationContext, IterationSpace, SparseIterationSpace
from ..ast.nodes import PsAssignment
from ..types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
from ..typed_expressions import PsTypedVariable
from ..arrays import PsArrayAccess
from ..exceptions import PsInternalCompilerError
class FreezeExpressions(SympyToPymbolicMapper):
    """Mapper that converts assignments, types, typed symbols and field
    accesses into their ``Ps*`` / pymbolic counterparts.

    Field accesses are resolved against the arrays registered in the kernel
    creation context and, for relative accesses, against the spatial indices
    of the iteration space.
    """

    def __init__(self, ctx: KernelCreationContext, ispace: IterationSpace):
        """Store the creation context and the iteration space to freeze against."""
        self._ctx, self._ispace = ctx, ispace

    def map_Assignment(self, expr):  # noqa
        """Translate both sides of an assignment and wrap them in a ``PsAssignment``."""
        frozen_lhs = self.rec(expr.lhs)
        frozen_rhs = self.rec(expr.rhs)
        return PsAssignment(frozen_lhs, frozen_rhs)

    def map_BasicType(self, expr):
        """Translate a basic type into the matching float / unsigned / signed ``Ps`` type.

        Raises:
            NotImplementedError: If the type is neither float, uint nor int.
        """
        # Bit width is derived from the byte size of the numpy dtype.
        bit_width = expr.numpy_dtype.itemsize * 8
        is_const = expr.const

        if expr.is_float():
            return PsIeeeFloatType(bit_width, is_const)
        if expr.is_uint():
            return PsUnsignedIntegerType(bit_width, is_const)
        if expr.is_int():
            return PsSignedIntegerType(bit_width, is_const)

        raise NotImplementedError("Data type not supported.")

    def map_FieldShapeSymbol(self, expr):
        """Translate a field shape symbol into a ``PsTypedVariable`` of the same name."""
        frozen_dtype = self.rec(expr.dtype)
        return PsTypedVariable(expr.name, frozen_dtype)

    def map_TypedSymbol(self, expr):
        """Translate a typed symbol into a ``PsTypedVariable`` of the same name."""
        frozen_dtype = self.rec(expr.dtype)
        return PsTypedVariable(expr.name, frozen_dtype)

    def map_Access(self, access: Field.Access):
        """Translate a field access into a ``PsArrayAccess`` with a linearized index.

        Raises:
            NotImplementedError: For relative accesses to index fields in a
                sparse iteration space, and for field types not handled here.
            PsInternalCompilerError: For relative accesses to index fields
                without a sparse iteration space.
            ValueError: For relative accesses to custom fields.
        """
        field = access.field
        descriptor = self._ctx.get_array_descriptor(field)
        array = descriptor.array
        ptr = descriptor.base_ptr

        offsets: list[pb.Expression] = [self.rec(off) for off in access.offsets]
        indices: list[pb.Expression] = [self.rec(idx) for idx in access.index]

        if not access.is_absolute_access:
            field_type = field.field_type

            if field_type == FieldType.GENERIC:
                #   Add the iteration counters
                offsets = [
                    counter + off
                    for counter, off in zip(self._ispace.spatial_indices, offsets)
                ]
            elif field_type == FieldType.INDEXED:
                if not isinstance(self._ispace, SparseIterationSpace):
                    raise PsInternalCompilerError(
                        "Cannot translate index field access without a sparse iteration space."
                    )
                #   TODO: make sure index (and all offsets?) are zero
                #   TODO: Add sparse iteration counter
                raise NotImplementedError()
            elif field_type == FieldType.CUSTOM:
                raise ValueError("Custom fields support only absolute accesses.")
            else:
                raise NotImplementedError(
                    f"Cannot translate accesses to field type {field_type} yet."
                )

        # One term per coordinate; strict zip rejects a mismatch between the
        # number of coordinates and the number of array strides.
        coordinates = offsets + indices
        terms = tuple(
            coord * stride
            for coord, stride in zip(coordinates, array.strides, strict=True)
        )
        return PsArrayAccess(ptr, pb.Sum(terms))
from pymbolic.interop.sympy import SympyToPymbolicMapper
from pystencils.typing import TypedSymbol
from pystencils.typing.typed_sympy import SHAPE_DTYPE
from .ast.nodes import PsAssignment, PsSymbolExpr
from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
from .typed_expressions import PsTypedVariable
from .arrays import PsArrayBasePointer, PsLinearizedArray, PsArrayAccess
# Three typed symbols ctr_0, ctr_1, ctr_2 of dtype SHAPE_DTYPE.
# PystencilsToPymbolicMapper.map_Access multiplies them with the field strides
# to build the linearized array index.
CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)]
class PystencilsToPymbolicMapper(SympyToPymbolicMapper):
    """Mapper that converts assignments, types, typed symbols and field
    accesses into their ``Ps*`` / pymbolic counterparts.
    """

    def map_Assignment(self, expr):  # noqa
        """Translate both sides of an assignment and wrap them in a ``PsAssignment``."""
        lhs = self.rec(expr.lhs)
        rhs = self.rec(expr.rhs)
        return PsAssignment(lhs, rhs)

    def map_BasicType(self, expr):
        """Translate a basic type into the matching float / unsigned / signed ``Ps`` type.

        Raises:
            NotImplementedError: If the type is neither float, uint nor int.
        """
        # Bit width is derived from the byte size of the numpy dtype.
        width = expr.numpy_dtype.itemsize * 8
        const = expr.const
        if expr.is_float():
            return PsIeeeFloatType(width, const)
        elif expr.is_uint():
            return PsUnsignedIntegerType(width, const)
        elif expr.is_int():
            return PsSignedIntegerType(width, const)
        else:
            # Raise an exception instance; raising a (class, message) tuple
            # would itself fail with a TypeError.
            raise NotImplementedError("Not supported dtype")

    def map_FieldShapeSymbol(self, expr):
        """Translate a field shape symbol into a ``PsTypedVariable`` of the same name."""
        dtype = self.rec(expr.dtype)
        return PsTypedVariable(expr.name, dtype)

    def map_TypedSymbol(self, expr):
        """Translate a typed symbol into a ``PsTypedVariable`` of the same name."""
        dtype = self.rec(expr.dtype)
        return PsTypedVariable(expr.name, dtype)

    def map_Access(self, expr):
        """Translate a field access into a ``PsSymbolExpr`` wrapping a ``PsArrayAccess``.

        A ``PsLinearizedArray`` is built from the field's name, shape, strides
        and the access dtype; the index is the sum of ``ctr_i * stride_i``.
        """
        name = expr.field.name
        shape = tuple(self.rec(s) for s in expr.field.shape)
        strides = tuple(self.rec(s) for s in expr.field.strides)
        dtype = self.rec(expr.dtype)
        array = PsLinearizedArray(name, shape, strides, dtype)

        ptr = PsArrayBasePointer(expr.name, array)
        # zip stops at the shorter input, so at most len(CTR_SYMBOLS) strides
        # contribute to the index.
        index = sum(
            ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)
        )
        index = self.rec(index)

        return PsSymbolExpr(PsArrayAccess(ptr, index))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment