Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (6)
Showing
with 638 additions and 305 deletions
...@@ -51,7 +51,7 @@ class PsExpression(PsAstNode, ABC): ...@@ -51,7 +51,7 @@ class PsExpression(PsAstNode, ABC):
def get_dtype(self) -> PsType: def get_dtype(self) -> PsType:
if self._dtype is None: if self._dtype is None:
raise PsInternalCompilerError("No dtype set on this expression yet.") raise PsInternalCompilerError(f"No data type set on expression {self}.")
return self._dtype return self._dtype
......
...@@ -17,9 +17,21 @@ def emit_ir(ir: PsAstNode): ...@@ -17,9 +17,21 @@ def emit_ir(ir: PsAstNode):
class IRAstPrinter(BasePrinter): class IRAstPrinter(BasePrinter):
"""Print the IR AST as pseudo-code.
def __init__(self, indent_width=3):
This printer produces a complete pseudocode representation of a pystencils AST.
Other than the `CAstPrinter`, the `IRAstPrinter` is capable of emitting code for
each node defined in `ast <pystencils.backend.ast>`.
It is furthermore configurable w.r.t. the level of detail it should emit.
Args:
indent_width: Number of spaces with which to indent lines in each nested block.
annotate_constants: If ``True`` (the default), annotate all constant literals with their data type.
"""
def __init__(self, indent_width=3, annotate_constants: bool = True):
super().__init__(indent_width) super().__init__(indent_width)
self._annotate_constants = annotate_constants
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node: match node:
...@@ -66,7 +78,10 @@ class IRAstPrinter(BasePrinter): ...@@ -66,7 +78,10 @@ class IRAstPrinter(BasePrinter):
return f"{symb.name}: {self._type_str(symb.dtype)}" return f"{symb.name}: {self._type_str(symb.dtype)}"
def _constant_literal(self, constant: PsConstant) -> str: def _constant_literal(self, constant: PsConstant) -> str:
return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]" if self._annotate_constants:
return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]"
else:
return str(constant.value)
def _type_str(self, dtype: PsType | None): def _type_str(self, dtype: PsType | None):
if dtype is None: if dtype is None:
......
...@@ -12,8 +12,6 @@ from .iteration_space import ( ...@@ -12,8 +12,6 @@ from .iteration_space import (
create_sparse_iteration_space, create_sparse_iteration_space,
) )
from .cpu_optimization import optimize_cpu
__all__ = [ __all__ = [
"KernelCreationContext", "KernelCreationContext",
"KernelAnalysis", "KernelAnalysis",
...@@ -25,5 +23,4 @@ __all__ = [ ...@@ -25,5 +23,4 @@ __all__ = [
"SparseIterationSpace", "SparseIterationSpace",
"create_full_iteration_space", "create_full_iteration_space",
"create_sparse_iteration_space", "create_sparse_iteration_space",
"optimize_cpu",
] ]
from __future__ import annotations
from typing import cast, TYPE_CHECKING
from .context import KernelCreationContext
from ..ast.structural import PsBlock
from ...config import CpuOptimConfig, OpenMpConfig
if TYPE_CHECKING:
from ..platforms import GenericCpu
def optimize_cpu(
ctx: KernelCreationContext,
platform: GenericCpu,
kernel_ast: PsBlock,
cfg: CpuOptimConfig | None,
) -> PsBlock:
"""Carry out CPU-specific optimizations according to the given configuration."""
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
hoist_invariants = HoistLoopInvariantDeclarations(ctx)
kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
if cfg is None:
return kernel_ast
if cfg.loop_blocking:
raise NotImplementedError("Loop blocking not implemented yet.")
if cfg.vectorize is not False:
raise NotImplementedError("Vectorization not implemented yet")
if cfg.openmp is not False:
from ..transformations import AddOpenMP
params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig()
add_omp = AddOpenMP(ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
return kernel_ast
...@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer ...@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..ast.util import failing_cast from ..ast.util import failing_cast
from ...types import PsStructType, constify from ...types import PsStructType
from ..exceptions import PsInputError, KernelConstraintsError from ..exceptions import PsInputError, KernelConstraintsError
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -209,17 +209,23 @@ class FullIterationSpace(IterationSpace): ...@@ -209,17 +209,23 @@ class FullIterationSpace(IterationSpace):
@property @property
def archetype_field(self) -> Field | None: def archetype_field(self) -> Field | None:
return self._archetype_field return self._archetype_field
@property
def loop_order(self) -> tuple[int, ...]:
"""Return the loop order of this iteration space, ordered from slowest to fastest coordinate."""
if self._archetype_field is not None:
return self._archetype_field.layout
else:
return tuple(range(len(self.dimensions)))
def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]: def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]:
"""Return the dimensions of this iteration space ordered from the slowest to the fastest coordinate. """Return the dimensions of this iteration space ordered from the slowest to the fastest coordinate.
If an archetype field is specified, the field layout is used to determine the ideal loop order; If this iteration space has an `archetype field <FullIterationSpace.archetype_field>` set,
its field layout is used to determine the ideal loop order;
otherwise, the dimensions are returned as they are otherwise, the dimensions are returned as they are
""" """
if self._archetype_field is not None: return [self._dimensions[i] for i in self.loop_order]
return [self._dimensions[i] for i in self._archetype_field.layout]
else:
return self._dimensions
def actual_iterations( def actual_iterations(
self, dimension: int | FullIterationSpace.Dimension | None = None self, dimension: int | FullIterationSpace.Dimension | None = None
...@@ -359,7 +365,7 @@ def create_sparse_iteration_space( ...@@ -359,7 +365,7 @@ def create_sparse_iteration_space(
dim = archetype_field.spatial_dimensions dim = archetype_field.spatial_dimensions
coord_members = [ coord_members = [
PsStructType.Member(name, ctx.index_dtype) PsStructType.Member(name, ctx.index_dtype)
for name in DEFAULTS._index_struct_coordinate_names[:dim] for name in DEFAULTS.index_struct_coordinate_names[:dim]
] ]
# Determine index field # Determine index field
...@@ -379,7 +385,7 @@ def create_sparse_iteration_space( ...@@ -379,7 +385,7 @@ def create_sparse_iteration_space(
) )
spatial_counters = [ spatial_counters = [
ctx.get_symbol(name, constify(ctx.index_dtype)) ctx.get_symbol(name, ctx.index_dtype)
for name in DEFAULTS.spatial_counter_names[:dim] for name in DEFAULTS.spatial_counter_names[:dim]
] ]
......
...@@ -21,6 +21,7 @@ from ..constants import PsConstant ...@@ -21,6 +21,7 @@ from ..constants import PsConstant
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from .generic_cpu import GenericVectorCpu from .generic_cpu import GenericVectorCpu
from ..kernelcreation import KernelCreationContext
from ...types.quick import Fp, SInt from ...types.quick import Fp, SInt
from ..functions import CFunction from ..functions import CFunction
...@@ -77,6 +78,28 @@ class X86VectorArch(Enum): ...@@ -77,6 +78,28 @@ class X86VectorArch(Enum):
) )
return suffix return suffix
def intrin_type(self, vtype: PsVectorType):
scalar_type = vtype.scalar_type
match scalar_type:
case Fp(16) if self >= X86VectorArch.AVX512:
suffix = "h"
case Fp(32):
suffix = ""
case Fp(64):
suffix = "d"
case SInt(_):
suffix = "i"
case _:
raise MaterializationError(
f"x86/{self} does not support scalar type {scalar_type}"
)
if vtype.width > self.max_vector_width:
raise MaterializationError(
f"x86/{self} does not support {vtype}"
)
return PsCustomType(f"__m{vtype.width}{suffix}")
class X86VectorCpu(GenericVectorCpu): class X86VectorCpu(GenericVectorCpu):
...@@ -86,7 +109,8 @@ class X86VectorCpu(GenericVectorCpu): ...@@ -86,7 +109,8 @@ class X86VectorCpu(GenericVectorCpu):
https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html. https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html.
""" """
def __init__(self, vector_arch: X86VectorArch): def __init__(self, ctx: KernelCreationContext, vector_arch: X86VectorArch):
super().__init__(ctx)
self._vector_arch = vector_arch self._vector_arch = vector_arch
@property @property
...@@ -111,26 +135,7 @@ class X86VectorCpu(GenericVectorCpu): ...@@ -111,26 +135,7 @@ class X86VectorCpu(GenericVectorCpu):
return super().required_headers | headers return super().required_headers | headers
def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType: def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType:
scalar_type = vector_type.scalar_type return self._vector_arch.intrin_type(vector_type)
match scalar_type:
case Fp(16) if self._vector_arch >= X86VectorArch.AVX512:
suffix = "h"
case Fp(32):
suffix = ""
case Fp(64):
suffix = "d"
case SInt(_):
suffix = "i"
case _:
raise MaterializationError(
f"x86/{self._vector_arch} does not support scalar type {scalar_type}"
)
if vector_type.width > self._vector_arch.max_vector_width:
raise MaterializationError(
f"x86/{self._vector_arch} does not support {vector_type}"
)
return PsCustomType(f"__m{vector_type.width}{suffix}")
def constant_intrinsic(self, c: PsConstant) -> PsExpression: def constant_intrinsic(self, c: PsConstant) -> PsExpression:
vtype = c.dtype vtype = c.dtype
...@@ -212,12 +217,14 @@ def _x86_op_intrin( ...@@ -212,12 +217,14 @@ def _x86_op_intrin(
) -> CFunction: ) -> CFunction:
prefix = varch.intrin_prefix(vtype) prefix = varch.intrin_prefix(vtype)
suffix = varch.intrin_suffix(vtype) suffix = varch.intrin_suffix(vtype)
rtype = atype = varch.intrin_type(vtype)
match op: match op:
case PsVecBroadcast(): case PsVecBroadcast():
opstr = "set1" opstr = "set1"
if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4: if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4:
suffix += "x" suffix += "x"
atype = vtype.scalar_type
case PsAdd(): case PsAdd():
opstr = "add" opstr = "add"
case PsSub(): case PsSub():
...@@ -236,4 +243,4 @@ def _x86_op_intrin( ...@@ -236,4 +243,4 @@ def _x86_op_intrin(
raise MaterializationError(f"Unable to select operation intrinsic for {type(op)}") raise MaterializationError(f"Unable to select operation intrinsic for {type(op)}")
num_args = 1 if isinstance(op, PsUnOp) else 2 num_args = 1 if isinstance(op, PsUnOp) else 2
return CFunction(f"{prefix}_{opstr}_{suffix}", (vtype,) * num_args, vtype) return CFunction(f"{prefix}_{opstr}_{suffix}", (atype,) * num_args, rtype)
...@@ -45,7 +45,7 @@ from ...types import ( ...@@ -45,7 +45,7 @@ from ...types import (
PsBoolType, PsBoolType,
PsScalarType, PsScalarType,
PsVectorType, PsVectorType,
PsTypeError, constify
) )
...@@ -57,9 +57,9 @@ class ECContext: ...@@ -57,9 +57,9 @@ class ECContext:
self._ctx = ctx self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
from ..emission import CAstPrinter from ..emission import IRAstPrinter
self._printer = CAstPrinter(0) self._printer = IRAstPrinter(indent_width=0, annotate_constants=False)
@property @property
def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]: def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
...@@ -89,10 +89,7 @@ class ECContext: ...@@ -89,10 +89,7 @@ class ECContext:
if expr_wrapped not in self._extracted_constants: if expr_wrapped not in self._extracted_constants:
symb_name = self._get_symb_name(expr) symb_name = self._get_symb_name(expr)
try: symb = self._ctx.get_new_symbol(symb_name, constify(dtype))
symb = self._ctx.get_symbol(symb_name, dtype)
except PsTypeError:
symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype)
self._extracted_constants[expr_wrapped] = symb self._extracted_constants[expr_wrapped] = symb
else: else:
...@@ -133,6 +130,10 @@ class EliminateConstants: ...@@ -133,6 +130,10 @@ class EliminateConstants:
def __call__(self, node: PsExpression) -> PsExpression: def __call__(self, node: PsExpression) -> PsExpression:
pass pass
@overload
def __call__(self, node: PsBlock) -> PsBlock:
pass
@overload @overload
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
pass pass
......
...@@ -61,6 +61,24 @@ class LoopVectorizer: ...@@ -61,6 +61,24 @@ class LoopVectorizer:
self._vectorize_ast = AstVectorizer(ctx) self._vectorize_ast = AstVectorizer(ctx)
self._fold = EliminateConstants(ctx) self._fold = EliminateConstants(ctx)
@overload
def vectorize_select_loops(
self, node: PsBlock, predicate: Callable[[PsLoop], bool]
) -> PsBlock:
...
@overload
def vectorize_select_loops(
self, node: PsLoop, predicate: Callable[[PsLoop], bool]
) -> PsLoop | PsBlock:
...
@overload
def vectorize_select_loops(
self, node: PsAstNode, predicate: Callable[[PsLoop], bool]
) -> PsAstNode:
...
def vectorize_select_loops( def vectorize_select_loops(
self, node: PsAstNode, predicate: Callable[[PsLoop], bool] self, node: PsAstNode, predicate: Callable[[PsLoop], bool]
) -> PsAstNode: ) -> PsAstNode:
......
...@@ -37,9 +37,13 @@ class LowerToC: ...@@ -37,9 +37,13 @@ class LowerToC:
def __init__(self, ctx: KernelCreationContext) -> None: def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx self._ctx = ctx
self._substitutions: dict[PsSymbol, PsSymbol] = dict()
self._typify = Typifier(ctx) self._typify = Typifier(ctx)
self._substitutions: dict[PsSymbol, PsSymbol] = dict() from .eliminate_constants import EliminateConstants
self._fold = EliminateConstants(self._ctx)
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
self._substitutions = dict() self._substitutions = dict()
...@@ -65,7 +69,8 @@ class LowerToC: ...@@ -65,7 +69,8 @@ class LowerToC:
return i return i
summands: list[PsExpression] = [ summands: list[PsExpression] = [
maybe_cast(cast(PsExpression, self.visit(idx).clone())) * PsExpression.make(stride) maybe_cast(cast(PsExpression, self.visit(idx).clone()))
* PsExpression.make(stride)
for idx, stride in zip(indices, buf.strides, strict=True) for idx, stride in zip(indices, buf.strides, strict=True)
] ]
...@@ -77,9 +82,11 @@ class LowerToC: ...@@ -77,9 +82,11 @@ class LowerToC:
mem_acc = PsMemAcc(bptr.clone(), linearized_idx) mem_acc = PsMemAcc(bptr.clone(), linearized_idx)
return self._typify.typify_expression( return self._fold(
mem_acc, target_type=buf.element_type self._typify.typify_expression(
)[0] mem_acc, target_type=buf.element_type
)[0]
)
case PsLookup(aggr, member_name) if isinstance( case PsLookup(aggr, member_name) if isinstance(
aggr, PsBufferAcc aggr, PsBufferAcc
...@@ -115,10 +122,7 @@ class LowerToC: ...@@ -115,10 +122,7 @@ class LowerToC:
const=bp_type.const, const=bp_type.const,
restrict=bp_type.restrict, restrict=bp_type.restrict,
) )
type_erased_bp = PsSymbol( type_erased_bp = PsSymbol(bp.name, erased_type)
bp.name,
erased_type
)
type_erased_bp.add_property(BufferBasePtr(buf)) type_erased_bp.add_property(BufferBasePtr(buf))
self._substitutions[bp] = type_erased_bp self._substitutions[bp] = type_erased_bp
else: else:
......
...@@ -5,12 +5,18 @@ from warnings import warn ...@@ -5,12 +5,18 @@ from warnings import warn
from collections.abc import Collection from collections.abc import Collection
from typing import Sequence from typing import Sequence
from dataclasses import dataclass, InitVar from dataclasses import dataclass, InitVar, replace
from .target import Target from .target import Target
from .field import Field, FieldType from .field import Field, FieldType
from .types import PsIntegerType, UserTypeSpec, PsIeeeFloatType, create_type from .types import (
PsIntegerType,
UserTypeSpec,
PsIeeeFloatType,
PsScalarType,
create_type,
)
from .defaults import DEFAULTS from .defaults import DEFAULTS
...@@ -90,6 +96,14 @@ class CpuOptimConfig: ...@@ -90,6 +96,14 @@ class CpuOptimConfig:
to produce cacheline zeroing instructions where possible. to produce cacheline zeroing instructions where possible.
""" """
def get_vectorization_config(self) -> VectorizationConfig | None:
if self.vectorize is True:
return VectorizationConfig()
elif isinstance(self.vectorize, VectorizationConfig):
return self.vectorize
else:
return None
@dataclass @dataclass
class VectorizationConfig: class VectorizationConfig:
...@@ -99,14 +113,13 @@ class VectorizationConfig: ...@@ -99,14 +113,13 @@ class VectorizationConfig:
in `CreateKernelConfig.target`, an error will be raised. in `CreateKernelConfig.target`, an error will be raised.
""" """
vector_width: int | None = None lanes: int | None = None
"""Desired vector register width in bits. """Number of SIMD lanes to be used in vectorization.
If set to an integer value, the vectorizer will use this as the desired vector register width.
If set to `None`, the vector register width will be automatically set to the broadest possible. If set to `None` (the default), the vector register width will be automatically set to the broadest possible.
If the selected CPU does not support the given width, an error will be raised. If the CPU architecture specified in `target <CreateKernelConfig.target>` does not support some
operation contained in the kernel with the given number of lanes, an error will be raised.
""" """
use_nontemporal_stores: bool | Collection[str | Field] = False use_nontemporal_stores: bool | Collection[str | Field] = False
...@@ -134,6 +147,25 @@ class VectorizationConfig: ...@@ -134,6 +147,25 @@ class VectorizationConfig:
that is not equal to one, an error will be raised. that is not equal to one, an error will be raised.
""" """
@staticmethod
def default_lanes(target: Target, dtype: PsScalarType):
if not target.is_vector_cpu():
raise ValueError(f"Given target {target} is no vector CPU target.")
assert dtype.itemsize is not None
match target:
case Target.X86_SSE:
return 128 // (dtype.itemsize * 8)
case Target.X86_AVX:
return 256 // (dtype.itemsize * 8)
case Target.X86_AVX512 | Target.X86_AVX512_FP16:
return 512 // (dtype.itemsize * 8)
case _:
raise NotImplementedError(
f"No default number of lanes known for {dtype} on {target}"
)
@dataclass @dataclass
class GpuIndexingConfig: class GpuIndexingConfig:
...@@ -266,6 +298,13 @@ class CreateKernelConfig: ...@@ -266,6 +298,13 @@ class CreateKernelConfig:
# Getters # Getters
def get_target(self) -> Target:
match self.target:
case Target.CurrentCPU:
return Target.auto_cpu()
case _:
return self.target
def get_jit(self) -> JitBase: def get_jit(self) -> JitBase:
"""Returns either the user-specified JIT compiler, or infers one from the target if none is given.""" """Returns either the user-specified JIT compiler, or infers one from the target if none is given."""
if self.jit is None: if self.jit is None:
...@@ -371,7 +410,7 @@ class CreateKernelConfig: ...@@ -371,7 +410,7 @@ class CreateKernelConfig:
warn( warn(
"Setting the deprecated `data_type` will override the value of `default_dtype`. " "Setting the deprecated `data_type` will override the value of `default_dtype`. "
"Set `default_dtype` instead.", "Set `default_dtype` instead.",
FutureWarning, UserWarning,
) )
self.default_dtype = data_type self.default_dtype = data_type
...@@ -394,7 +433,52 @@ class CreateKernelConfig: ...@@ -394,7 +433,52 @@ class CreateKernelConfig:
if cpu_vectorize_info is not None: if cpu_vectorize_info is not None:
_deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize") _deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize")
raise NotImplementedError("CPU vectorization is not implemented yet") if "instruction_set" in cpu_vectorize_info:
if self.target != Target.GenericCPU:
raise PsOptionsError(
"Setting 'instruction_set' in the deprecated 'cpu_vectorize_info' option is only "
"valid if `target == Target.CPU`."
)
isa = cpu_vectorize_info["instruction_set"]
vec_target: Target
match isa:
case "best":
vec_target = Target.available_vector_cpu_targets().pop()
case "sse":
vec_target = Target.X86_SSE
case "avx":
vec_target = Target.X86_AVX
case "avx512":
vec_target = Target.X86_AVX512
case "avx512vl":
vec_target = Target.X86_AVX512 | Target._VL
case _:
raise PsOptionsError(
f'Value {isa} in `cpu_vectorize_info["instruction_set"]` is not supported.'
)
warn(
f"Value {isa} for `instruction_set` in deprecated `cpu_vectorize_info` "
"will override the `target` option. "
f"Set `target` to {vec_target} instead.",
UserWarning,
)
self.target = vec_target
deprecated_vec_opts = VectorizationConfig(
assume_inner_stride_one=cpu_vectorize_info.get(
"assume_inner_stride_one", False
),
assume_aligned=cpu_vectorize_info.get("assume_aligned", False),
use_nontemporal_stores=cpu_vectorize_info.get("nontemporal", False),
)
if optim is not None:
optim = replace(optim, vectorize=deprecated_vec_opts)
else:
optim = CpuOptimConfig(vectorize=deprecated_vec_opts)
if optim is not None: if optim is not None:
if self.cpu_optim is not None: if self.cpu_optim is not None:
......
from typing import TypeVar, Generic, Callable from .types import (
from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType PsIeeeFloatType,
PsIntegerType,
PsSignedIntegerType,
PsStructType,
UserTypeSpec,
create_type,
)
from pystencils.sympyextensions.typed_sympy import TypedSymbol from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
SymbolT = TypeVar("SymbolT")
class SympyDefaults:
class GenericDefaults(Generic[SymbolT]): def __init__(self):
def __init__(self, symcreate: Callable[[str, PsType], SymbolT]):
self.numeric_dtype = PsIeeeFloatType(64) self.numeric_dtype = PsIeeeFloatType(64)
"""Default data type for numerical computations""" """Default data type for numerical computations"""
...@@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]): ...@@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]):
"""Names of the default spatial counters""" """Names of the default spatial counters"""
self.spatial_counters = ( self.spatial_counters = (
symcreate("ctr_0", self.index_dtype), TypedSymbol("ctr_0", DynamicType.INDEX_TYPE),
symcreate("ctr_1", self.index_dtype), TypedSymbol("ctr_1", DynamicType.INDEX_TYPE),
symcreate("ctr_2", self.index_dtype), TypedSymbol("ctr_2", DynamicType.INDEX_TYPE),
) )
"""Default spatial counters""" """Default spatial counters"""
self._index_struct_coordinate_names = ("x", "y", "z") self.index_struct_coordinate_names = ("x", "y", "z")
"""Default names of spatial coordinate members in index list structures""" """Default names of spatial coordinate members in index list structures"""
self.index_struct_coordinates = (
PsStructType.Member("x", self.index_dtype),
PsStructType.Member("y", self.index_dtype),
PsStructType.Member("z", self.index_dtype),
)
"""Default spatial coordinate members in index list structures"""
self.sparse_counter_name = "sparse_idx" self.sparse_counter_name = "sparse_idx"
"""Name of the default sparse iteration counter""" """Name of the default sparse iteration counter"""
self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype) self.sparse_counter = TypedSymbol(
self.sparse_counter_name, DynamicType.INDEX_TYPE
)
"""Default sparse iteration counter.""" """Default sparse iteration counter."""
def field_shape_name(self, field_name: str, coord: int): def field_shape_name(self, field_name: str, coord: int):
return f"_size_{field_name}_{coord}" return f"_size_{field_name}_{coord}"
def field_stride_name(self, field_name: str, coord: int): def field_stride_name(self, field_name: str, coord: int):
return f"_stride_{field_name}_{coord}" return f"_stride_{field_name}_{coord}"
def field_pointer_name(self, field_name: str): def field_pointer_name(self, field_name: str):
return f"_data_{field_name}" return f"_data_{field_name}"
def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType:
idx_type = create_type(index_dtype)
return PsStructType(
[(name, idx_type) for name in self.index_struct_coordinate_names[:dim]]
)
DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol) DEFAULTS = SympyDefaults()
"""Default names and symbols used throughout code generation""" """Default names and symbols used throughout code generation"""
...@@ -2,27 +2,35 @@ from typing import cast, Sequence ...@@ -2,27 +2,35 @@ from typing import cast, Sequence
from dataclasses import replace from dataclasses import replace
from .target import Target from .target import Target
from .config import CreateKernelConfig from .config import (
CreateKernelConfig,
OpenMpConfig,
VectorizationConfig,
)
from .backend import KernelFunction from .backend import KernelFunction
from .types import create_numeric_type, PsIntegerType from .types import create_numeric_type, PsIntegerType, PsScalarType
from .backend.ast.structural import PsBlock from .backend.ast.structural import PsBlock, PsLoop
from .backend.kernelcreation import ( from .backend.kernelcreation import (
KernelCreationContext, KernelCreationContext,
KernelAnalysis, KernelAnalysis,
FreezeExpressions, FreezeExpressions,
Typifier, Typifier,
) )
from .backend.constants import PsConstant
from .backend.kernelcreation.iteration_space import ( from .backend.kernelcreation.iteration_space import (
create_sparse_iteration_space, create_sparse_iteration_space,
create_full_iteration_space, create_full_iteration_space,
FullIterationSpace,
) )
from .backend.platforms import Platform, GenericCpu, GenericVectorCpu, GenericGpu
from .backend.exceptions import VectorizationError
from .backend.transformations import ( from .backend.transformations import (
EliminateConstants, EliminateConstants,
LowerToC, LowerToC,
SelectFunctions, SelectFunctions,
CanonicalizeSymbols, CanonicalizeSymbols,
HoistLoopInvariantDeclarations,
) )
from .backend.kernelfunction import ( from .backend.kernelfunction import (
create_cpu_kernel_function, create_cpu_kernel_function,
...@@ -60,125 +68,245 @@ def create_kernel( ...@@ -60,125 +68,245 @@ def create_kernel(
if kwargs: if kwargs:
config = replace(config, **kwargs) config = replace(config, **kwargs)
idx_dtype = create_numeric_type(config.index_dtype) driver = DefaultKernelCreationDriver(config)
assert isinstance(idx_dtype, PsIntegerType) return driver(assignments)
ctx = KernelCreationContext(
default_dtype=create_numeric_type(config.default_dtype),
index_dtype=idx_dtype,
)
if isinstance(assignments, AssignmentBase): class DefaultKernelCreationDriver:
assignments = [assignments] def __init__(self, cfg: CreateKernelConfig):
self._cfg = cfg
if not isinstance(assignments, AssignmentCollection): idx_dtype = create_numeric_type(self._cfg.index_dtype)
assignments = AssignmentCollection(assignments) # type: ignore assert isinstance(idx_dtype, PsIntegerType)
_ = _parse_simplification_hints(assignments) self._ctx = KernelCreationContext(
default_dtype=create_numeric_type(self._cfg.default_dtype),
index_dtype=idx_dtype,
)
analysis = KernelAnalysis( self._target = self._cfg.get_target()
ctx, not config.skip_independence_check, not config.allow_double_writes self._platform = self._get_platform()
)
analysis(assignments)
if len(ctx.fields.index_fields) > 0 or config.index_field is not None: def __call__(
ispace = create_sparse_iteration_space( self,
ctx, assignments, index_field=config.index_field assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
) ):
else: if isinstance(assignments, AssignmentBase):
ispace = create_full_iteration_space( assignments = [assignments]
ctx,
assignments, if not isinstance(assignments, AssignmentCollection):
ghost_layers=config.ghost_layers, assignments = AssignmentCollection(assignments) # type: ignore
iteration_slice=config.iteration_slice,
_ = _parse_simplification_hints(assignments)
analysis = KernelAnalysis(
self._ctx,
not self._cfg.skip_independence_check,
not self._cfg.allow_double_writes,
) )
analysis(assignments)
ctx.set_iteration_space(ispace) if len(self._ctx.fields.index_fields) > 0 or self._cfg.index_field is not None:
ispace = create_sparse_iteration_space(
self._ctx, assignments, index_field=self._cfg.index_field
)
else:
ispace = create_full_iteration_space(
self._ctx,
assignments,
ghost_layers=self._cfg.ghost_layers,
iteration_slice=self._cfg.iteration_slice,
)
freeze = FreezeExpressions(ctx) self._ctx.set_iteration_space(ispace)
kernel_body = freeze(assignments)
typify = Typifier(ctx) freeze = FreezeExpressions(self._ctx)
kernel_body = typify(kernel_body) kernel_body = freeze(assignments)
match config.target: typify = Typifier(self._ctx)
case Target.GenericCPU: kernel_body = typify(kernel_body)
from .backend.platforms import GenericCpu
platform = GenericCpu(ctx) match self._platform:
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) case GenericCpu():
kernel_ast = self._platform.materialize_iteration_space(
kernel_body, ispace
)
case GenericGpu():
kernel_ast, gpu_threads = self._platform.materialize_iteration_space(
kernel_body, ispace
)
case target if target.is_gpu(): # Fold and extract constants
match target: elim_constants = EliminateConstants(self._ctx, extract_constant_exprs=True)
case Target.SYCL: kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
from .backend.platforms import SyclPlatform
platform = SyclPlatform(ctx, config.gpu_indexing) # Target-Specific optimizations
case Target.CUDA: if self._cfg.target.is_cpu():
from .backend.platforms import CudaPlatform kernel_ast = self._transform_for_cpu(kernel_ast)
# Note: After this point, the AST may contain intrinsics, so type-dependent
# transformations cannot be run any more
# Lowering
lower_to_c = LowerToC(self._ctx)
kernel_ast = cast(PsBlock, lower_to_c(kernel_ast))
select_functions = SelectFunctions(self._platform)
kernel_ast = cast(PsBlock, select_functions(kernel_ast))
# Late canonicalization pass: Canonicalize new symbols introduced by LowerToC
platform = CudaPlatform(ctx, config.gpu_indexing) canonicalize = CanonicalizeSymbols(self._ctx, True)
case _: kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
raise NotImplementedError(
f"Code generation for target {target} not implemented"
)
kernel_ast, gpu_threads = platform.materialize_iteration_space( if self._cfg.target.is_cpu():
kernel_body, ispace return create_cpu_kernel_function(
self._ctx,
self._platform,
kernel_ast,
self._cfg.function_name,
self._cfg.target,
self._cfg.get_jit(),
) )
else:
return create_gpu_kernel_function(
self._ctx,
self._platform,
kernel_ast,
gpu_threads,
self._cfg.function_name,
self._cfg.target,
self._cfg.get_jit(),
)
def _transform_for_cpu(self, kernel_ast: PsBlock):
canonicalize = CanonicalizeSymbols(self._ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
hoist_invariants = HoistLoopInvariantDeclarations(self._ctx)
kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
cpu_cfg = self._cfg.cpu_optim
if cpu_cfg is None:
return kernel_ast
case _: if cpu_cfg.loop_blocking:
raise NotImplementedError( raise NotImplementedError("Loop blocking not implemented yet.")
f"Code generation for target {target} not implemented"
kernel_ast = self._vectorize(kernel_ast)
if cpu_cfg.openmp is not False:
from .backend.transformations import AddOpenMP
params = (
cpu_cfg.openmp
if isinstance(cpu_cfg.openmp, OpenMpConfig)
else OpenMpConfig()
) )
add_omp = AddOpenMP(self._ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cpu_cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
# Fold and extract constants return kernel_ast
elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
# Target-Specific optimizations def _vectorize(self, kernel_ast: PsBlock) -> PsBlock:
if config.target.is_cpu(): assert self._cfg.cpu_optim is not None
from .backend.kernelcreation import optimize_cpu vec_config = self._cfg.cpu_optim.get_vectorization_config()
if vec_config is None:
return kernel_ast
assert isinstance(platform, GenericCpu) from .backend.transformations import LoopVectorizer, SelectIntrinsics
kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) assert isinstance(self._platform, GenericVectorCpu)
# Lowering ispace = self._ctx.get_iteration_space()
lower_to_c = LowerToC(ctx) if not isinstance(ispace, FullIterationSpace):
kernel_ast = cast(PsBlock, lower_to_c(kernel_ast)) raise VectorizationError(
"Unable to vectorize kernel: The kernel is not using a dense iteration space."
)
select_functions = SelectFunctions(platform) inner_loop_coord = ispace.loop_order[-1]
kernel_ast = cast(PsBlock, select_functions(kernel_ast)) inner_loop_dim = ispace.dimensions[inner_loop_coord]
# Apply stride (TODO: and alignment) assumptions
if vec_config.assume_inner_stride_one:
for field in self._ctx.fields:
buf = self._ctx.get_buffer(field)
inner_stride = buf.strides[inner_loop_coord]
if isinstance(inner_stride, PsConstant):
if inner_stride.value != 1:
raise VectorizationError(
f"Unable to apply assumption 'assume_inner_stride_one': "
f"Field {field} has fixed stride {inner_stride} "
f"set in the inner coordinate {inner_loop_coord}."
)
else:
buf.strides[inner_loop_coord] = PsConstant(1, buf.index_type)
# TODO: Communicate assumption to runtime system via a precondition
# Call loop vectorizer
if vec_config.lanes is None:
lanes = VectorizationConfig.default_lanes(
self._target, cast(PsScalarType, self._ctx.default_dtype)
)
else:
lanes = vec_config.lanes
vectorizer = LoopVectorizer(self._ctx, lanes)
def loop_predicate(loop: PsLoop):
return loop.counter.symbol == inner_loop_dim.counter
kernel_ast = vectorizer.vectorize_select_loops(kernel_ast, loop_predicate)
select_intrin = SelectIntrinsics(self._ctx, self._platform)
kernel_ast = cast(PsBlock, select_intrin(kernel_ast))
return kernel_ast
def _get_platform(self) -> Platform:
if Target._CPU in self._target:
if Target._X86 in self._target:
from .backend.platforms.x86 import X86VectorArch, X86VectorCpu
arch: X86VectorArch
if Target._SSE in self._target:
arch = X86VectorArch.SSE
elif Target._AVX in self._target:
arch = X86VectorArch.AVX
elif Target._AVX512 in self._target:
if Target._FP16 in self._target:
arch = X86VectorArch.AVX512_FP16
else:
arch = X86VectorArch.AVX512
else:
assert False, "unreachable code"
return X86VectorCpu(self._ctx, arch)
elif self._target == Target.GenericCPU:
return GenericCpu(self._ctx)
else:
raise NotImplementedError(
f"No platform is currently available for CPU target {self._target}"
)
elif Target._GPU in self._target:
match self._target:
case Target.SYCL:
from .backend.platforms import SyclPlatform
# Late canonicalization and constant elimination passes return SyclPlatform(self._ctx, self._cfg.gpu_indexing)
# * Since lowering introduces new index calculations and indexing symbols into the AST, case Target.CUDA:
# * these need to be handled here from .backend.platforms import CudaPlatform
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
late_fold_constants = EliminateConstants(ctx, extract_constant_exprs=False) return CudaPlatform(self._ctx, self._cfg.gpu_indexing)
kernel_ast = cast(PsBlock, late_fold_constants(kernel_ast))
if config.target.is_cpu(): raise NotImplementedError(
return create_cpu_kernel_function( f"Code generation for target {self._target} not implemented"
ctx,
platform,
kernel_ast,
config.function_name,
config.target,
config.get_jit(),
)
else:
return create_gpu_kernel_function(
ctx,
platform,
kernel_ast,
gpu_threads,
config.function_name,
config.target,
config.get_jit(),
) )
...@@ -192,6 +320,9 @@ def create_staggered_kernel( ...@@ -192,6 +320,9 @@ def create_staggered_kernel(
# Internals # Internals
def _parse_simplification_hints(ac: AssignmentCollection): def _parse_simplification_hints(ac: AssignmentCollection):
if "split_groups" in ac.simplification_hints: if "split_groups" in ac.simplification_hints:
raise NotImplementedError("Loop splitting was requested, but is not implemented yet") raise NotImplementedError(
"Loop splitting was requested, but is not implemented yet"
)
...@@ -118,7 +118,7 @@ class Target(Flag): ...@@ -118,7 +118,7 @@ class Target(Flag):
@staticmethod @staticmethod
def available_vector_cpu_targets() -> list[Target]: def available_vector_cpu_targets() -> list[Target]:
"""Returns a list of available (vector) CPU targets, ordered from least to most capable.""" """Returns a list of available vector CPU targets, ordered from least to most capable."""
return _available_vector_targets() return _available_vector_targets()
......
...@@ -138,7 +138,7 @@ class PsType(metaclass=PsTypeMeta): ...@@ -138,7 +138,7 @@ class PsType(metaclass=PsTypeMeta):
@property @property
def itemsize(self) -> int | None: def itemsize(self) -> int | None:
"""If this type has a valid in-memory size, return that size.""" """If this type has a valid in-memory size, return that size in bytes."""
return None return None
@property @property
......
...@@ -100,15 +100,19 @@ class PsPointerType(PsDereferencableType): ...@@ -100,15 +100,19 @@ class PsPointerType(PsDereferencableType):
class PsArrayType(PsDereferencableType): class PsArrayType(PsDereferencableType):
"""Multidimensional array of fixed shape. """Multidimensional array of fixed shape.
The element type of an array is never const; only the array itself can be. The element type of an array is never const; only the array itself can be.
If ``element_type`` is const, its constness will be removed. If ``element_type`` is const, its constness will be removed.
""" """
def __init__( def __init__(
self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False self,
element_type: PsType,
shape: SupportsIndex | Sequence[SupportsIndex],
const: bool = False,
): ):
from operator import index from operator import index
if isinstance(shape, SupportsIndex): if isinstance(shape, SupportsIndex):
shape = (index(shape),) shape = (index(shape),)
else: else:
...@@ -116,10 +120,10 @@ class PsArrayType(PsDereferencableType): ...@@ -116,10 +120,10 @@ class PsArrayType(PsDereferencableType):
if not shape or any(s <= 0 for s in shape): if not shape or any(s <= 0 for s in shape):
raise ValueError(f"Invalid array shape: {shape}") raise ValueError(f"Invalid array shape: {shape}")
if isinstance(element_type, PsArrayType): if isinstance(element_type, PsArrayType):
raise ValueError("Element type of array cannot be another array.") raise ValueError("Element type of array cannot be another array.")
element_type = deconstify(element_type) element_type = deconstify(element_type)
self._shape = shape self._shape = shape
...@@ -137,7 +141,7 @@ class PsArrayType(PsDereferencableType): ...@@ -137,7 +141,7 @@ class PsArrayType(PsDereferencableType):
def shape(self) -> tuple[int, ...]: def shape(self) -> tuple[int, ...]:
"""Shape of this array""" """Shape of this array"""
return self._shape return self._shape
@property @property
def dim(self) -> int: def dim(self) -> int:
"""Dimensionality of this array""" """Dimensionality of this array"""
...@@ -396,12 +400,13 @@ class PsVectorType(PsNumericType): ...@@ -396,12 +400,13 @@ class PsVectorType(PsNumericType):
return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,))) return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,)))
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
if ( if isinstance(value, np.ndarray):
isinstance(value, np.ndarray) if value.shape != (self._vector_entries,):
and value.dtype == self.scalar_type.numpy_dtype raise PsTypeError(
and value.shape == (self._vector_entries,) f"Cannot create constant of vector type {self} from array of shape {value.shape}"
): )
return value.copy()
return np.array([self._scalar_type.create_constant(v) for v in value])
element = self._scalar_type.create_constant(value) element = self._scalar_type.create_constant(value)
return np.array( return np.array(
...@@ -552,7 +557,7 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -552,7 +557,7 @@ class PsIntegerType(PsScalarType, ABC):
def c_string(self) -> str: def c_string(self) -> str:
return f"{self._const_string()}{self._str_without_const()}_t" return f"{self._const_string()}{self._str_without_const()}_t"
def __str__(self) -> str: def __str__(self) -> str:
return f"{self._const_string()}{self._str_without_const()}" return f"{self._const_string()}{self._str_without_const()}"
......
...@@ -2,32 +2,84 @@ import pytest ...@@ -2,32 +2,84 @@ import pytest
import sympy as sp import sympy as sp
import numpy as np import numpy as np
from pystencils import fields, Field, AssignmentCollection, Target, CreateKernelConfig from dataclasses import replace
from pystencils import (
fields,
Field,
AssignmentCollection,
Target,
CreateKernelConfig,
CpuOptimConfig,
VectorizationConfig,
)
from pystencils.assignment import assignment_from_stencil from pystencils.assignment import assignment_from_stencil
from pystencils.kernelcreation import create_kernel from pystencils.kernelcreation import create_kernel, KernelFunction
from pystencils.backend.emission import emit_code
AVAILABLE_TARGETS = [Target.GenericCPU]
@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) try:
def test_filter_kernel(target): import cupy
if target == Target.CUDA:
xp = pytest.importorskip("cupy") AVAILABLE_TARGETS += [Target.CUDA]
except ImportError:
pass
AVAILABLE_TARGETS += Target.available_vector_cpu_targets()
TEST_IDS = [t.name for t in AVAILABLE_TARGETS]
@pytest.fixture(params=AVAILABLE_TARGETS, ids=TEST_IDS)
def gen_config(request):
target: Target = request.param
gen_config = CreateKernelConfig(target=target)
if Target._VECTOR in target:
gen_config = replace(
gen_config,
cpu_optim=CpuOptimConfig(
vectorize=VectorizationConfig(assume_inner_stride_one=True)
),
)
return gen_config
def inspect_dp_kernel(kernel: KernelFunction, gen_config: CreateKernelConfig):
code = emit_code(kernel)
match gen_config.target:
case Target.X86_SSE:
assert "_mm_loadu_pd" in code
assert "_mm_storeu_pd" in code
case Target.X86_AVX:
assert "_mm256_loadu_pd" in code
assert "_mm256_storeu_pd" in code
case Target.X86_AVX512:
assert "_mm512_loadu_pd" in code
assert "_mm512_storeu_pd" in code
def test_filter_kernel(gen_config):
if gen_config.target == Target.CUDA:
import cupy as cp
xp = cp
else: else:
xp = np xp = np
weight = sp.Symbol("weight") weight = sp.Symbol("weight")
stencil = [ stencil = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]
]
src, dst = fields("src, dst: [2D]") src, dst = fields("src, dst: [2D]")
asm = assignment_from_stencil(stencil, src, dst, normalization_factor=weight) asm = assignment_from_stencil(stencil, src, dst, normalization_factor=weight)
asms = AssignmentCollection([asm]) asms = AssignmentCollection([asm])
gen_config = CreateKernelConfig(target=target)
ast = create_kernel(asms, gen_config) ast = create_kernel(asms, gen_config)
inspect_dp_kernel(ast, gen_config)
kernel = ast.compile() kernel = ast.compile()
src_arr = xp.ones((42, 31)) src_arr = xp.ones((42, 31))
...@@ -41,31 +93,28 @@ def test_filter_kernel(target): ...@@ -41,31 +93,28 @@ def test_filter_kernel(target):
xp.testing.assert_allclose(dst_arr, expected) xp.testing.assert_allclose(dst_arr, expected)
@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) def test_filter_kernel_fixedsize(gen_config):
def test_filter_kernel_fixedsize(target): if gen_config.target == Target.CUDA:
if target == Target.CUDA: import cupy as cp
xp = pytest.importorskip("cupy")
xp = cp
else: else:
xp = np xp = np
weight = sp.Symbol("weight") weight = sp.Symbol("weight")
stencil = [ stencil = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]
]
src_arr = xp.ones((42, 31)) src_arr = xp.ones((42, 31))
dst_arr = xp.zeros_like(src_arr) dst_arr = xp.zeros_like(src_arr)
src = Field.create_from_numpy_array("src", src_arr) src = Field.create_from_numpy_array("src", src_arr)
dst = Field.create_from_numpy_array("dst", dst_arr) dst = Field.create_from_numpy_array("dst", dst_arr)
asm = assignment_from_stencil(stencil, src, dst, normalization_factor=weight) asm = assignment_from_stencil(stencil, src, dst, normalization_factor=weight)
asms = AssignmentCollection([asm]) asms = AssignmentCollection([asm])
gen_config = CreateKernelConfig(target=target)
ast = create_kernel(asms, gen_config) ast = create_kernel(asms, gen_config)
inspect_dp_kernel(ast, gen_config)
kernel = ast.compile() kernel = ast.compile()
kernel(src=src_arr, dst=dst_arr, weight=2.0) kernel(src=src_arr, dst=dst_arr, weight=2.0)
......
...@@ -27,6 +27,3 @@ def test_sliced_iteration(): ...@@ -27,6 +27,3 @@ def test_sliced_iteration():
expected_result = np.zeros(size) expected_result = np.zeros(size)
expected_result[1:x_end_value, 1] = 1 expected_result[1:x_end_value, 1] = 1
np.testing.assert_almost_equal(expected_result, dst_arr) np.testing.assert_almost_equal(expected_result, dst_arr)
test_sliced_iteration()
import pytest
import numpy as np
from pystencils import (
Field,
Assignment,
create_kernel,
CreateKernelConfig,
DEFAULTS,
FieldType,
)
from pystencils.sympyextensions import CastFunc
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_dense(index_dtype):
# Parametrized over index_dtype to make sure the `DynamicType.INDEX` in the
# DEFAULTS works validly
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(z)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(x)),
]
cfg = CreateKernelConfig(index_dtype=index_dtype)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
kernel(f=f_arr)
expected = np.mgrid[0:16, 0:16, 0:16].astype(np.float64).transpose()
np.testing.assert_equal(f_arr, expected)
@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
def test_spatial_counters_sparse(index_dtype):
x, y, z = DEFAULTS.spatial_counters
f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
asms = [
Assignment(f(0), CastFunc.as_numeric(x)),
Assignment(f(1), CastFunc.as_numeric(y)),
Assignment(f(2), CastFunc.as_numeric(z)),
]
idx_struct = DEFAULTS.index_struct(index_dtype, 3)
idx_field = Field.create_generic(
"index", 1, idx_struct, field_type=FieldType.INDEXED
)
cfg = CreateKernelConfig(index_dtype=index_dtype, index_field=idx_field)
kernel = create_kernel(asms, cfg).compile()
f_arr = np.zeros((16, 16, 16, 3))
idx_arr = np.array(
[(1, 4, 3), (5, 1, 6), (9, 5, 1), (3, 13, 7)], dtype=idx_struct.numpy_dtype
)
kernel(f=f_arr, index=idx_arr)
for t in idx_arr:
assert f_arr[t[0], t[1], t[2], 0] == t[0].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 1] == t[1].astype(np.float64)
assert f_arr[t[0], t[1], t[2], 2] == t[2].astype(np.float64)
...@@ -3,6 +3,8 @@ import sympy as sp ...@@ -3,6 +3,8 @@ import sympy as sp
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain from itertools import chain
from functools import partial
from typing import Callable
from pystencils.backend.kernelcreation import ( from pystencils.backend.kernelcreation import (
KernelCreationContext, KernelCreationContext,
...@@ -28,64 +30,52 @@ from pystencils.types.quick import SInt, Fp ...@@ -28,64 +30,52 @@ from pystencils.types.quick import SInt, Fp
@dataclass @dataclass
class VectorTestSetup: class VectorTestSetup:
platform: GenericVectorCpu target: Target
platform_factory: Callable[[KernelCreationContext], GenericVectorCpu]
lanes: int lanes: int
numeric_dtype: PsScalarType numeric_dtype: PsScalarType
index_dtype: PsIntegerType index_dtype: PsIntegerType
@property @property
def name(self) -> str: def name(self) -> str:
if isinstance(self.platform, X86VectorCpu): return f"{self.target.name}/{self.numeric_dtype}<{self.lanes}>/{self.index_dtype}"
match self.platform.vector_arch:
case X86VectorArch.SSE:
isa = "SSE"
case X86VectorArch.AVX:
isa = "AVX"
case X86VectorArch.AVX512:
isa = "AVX512"
case X86VectorArch.AVX512_FP16:
isa = "AVX512_FP16"
else:
assert False
return f"{isa}/{self.numeric_dtype}<{self.lanes}>/{self.index_dtype}"
def get_setups(target: Target) -> list[VectorTestSetup]: def get_setups(target: Target) -> list[VectorTestSetup]:
match target: match target:
case Target.X86_SSE: case Target.X86_SSE:
sse_platform = X86VectorCpu(X86VectorArch.SSE) sse_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.SSE)
return [ return [
VectorTestSetup(sse_platform, 4, Fp(32), SInt(32)), VectorTestSetup(target, sse_platform, 4, Fp(32), SInt(32)),
VectorTestSetup(sse_platform, 2, Fp(64), SInt(64)), VectorTestSetup(target, sse_platform, 2, Fp(64), SInt(64)),
] ]
case Target.X86_AVX: case Target.X86_AVX:
avx_platform = X86VectorCpu(X86VectorArch.AVX) avx_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX)
return [ return [
VectorTestSetup(avx_platform, 4, Fp(32), SInt(32)), VectorTestSetup(target, avx_platform, 4, Fp(32), SInt(32)),
VectorTestSetup(avx_platform, 8, Fp(32), SInt(32)), VectorTestSetup(target, avx_platform, 8, Fp(32), SInt(32)),
VectorTestSetup(avx_platform, 2, Fp(64), SInt(64)), VectorTestSetup(target, avx_platform, 2, Fp(64), SInt(64)),
VectorTestSetup(avx_platform, 4, Fp(64), SInt(64)), VectorTestSetup(target, avx_platform, 4, Fp(64), SInt(64)),
] ]
case Target.X86_AVX512: case Target.X86_AVX512:
avx512_platform = X86VectorCpu(X86VectorArch.AVX512) avx512_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX512)
return [ return [
VectorTestSetup(avx512_platform, 4, Fp(32), SInt(32)), VectorTestSetup(target, avx512_platform, 4, Fp(32), SInt(32)),
VectorTestSetup(avx512_platform, 8, Fp(32), SInt(32)), VectorTestSetup(target, avx512_platform, 8, Fp(32), SInt(32)),
VectorTestSetup(avx512_platform, 16, Fp(32), SInt(32)), VectorTestSetup(target, avx512_platform, 16, Fp(32), SInt(32)),
VectorTestSetup(avx512_platform, 2, Fp(64), SInt(64)), VectorTestSetup(target, avx512_platform, 2, Fp(64), SInt(64)),
VectorTestSetup(avx512_platform, 4, Fp(64), SInt(64)), VectorTestSetup(target, avx512_platform, 4, Fp(64), SInt(64)),
VectorTestSetup(avx512_platform, 8, Fp(64), SInt(64)), VectorTestSetup(target, avx512_platform, 8, Fp(64), SInt(64)),
] ]
case Target.X86_AVX512_FP16: case Target.X86_AVX512_FP16:
avx512_platform = X86VectorCpu(X86VectorArch.AVX512_FP16) avx512_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX512_FP16)
return [ return [
VectorTestSetup(avx512_platform, 8, Fp(16), SInt(32)), VectorTestSetup(target, avx512_platform, 8, Fp(16), SInt(32)),
VectorTestSetup(avx512_platform, 16, Fp(16), SInt(32)), VectorTestSetup(target, avx512_platform, 16, Fp(16), SInt(32)),
VectorTestSetup(avx512_platform, 32, Fp(16), SInt(32)), VectorTestSetup(target, avx512_platform, 32, Fp(16), SInt(32)),
] ]
case _: case _:
...@@ -108,6 +98,7 @@ def create_vector_kernel( ...@@ -108,6 +98,7 @@ def create_vector_kernel(
ctx = KernelCreationContext( ctx = KernelCreationContext(
default_dtype=setup.numeric_dtype, index_dtype=setup.index_dtype default_dtype=setup.numeric_dtype, index_dtype=setup.index_dtype
) )
platform = setup.platform_factory(ctx)
factory = AstFactory(ctx) factory = AstFactory(ctx)
...@@ -129,7 +120,7 @@ def create_vector_kernel( ...@@ -129,7 +120,7 @@ def create_vector_kernel(
loop_nest, lambda l: l.counter.symbol.name == "ctr_0" loop_nest, lambda l: l.counter.symbol.name == "ctr_0"
) )
select_intrin = SelectIntrinsics(ctx, setup.platform) select_intrin = SelectIntrinsics(ctx, platform)
loop_nest = select_intrin(loop_nest) loop_nest = select_intrin(loop_nest)
lower = LowerToC(ctx) lower = LowerToC(ctx)
...@@ -137,7 +128,7 @@ def create_vector_kernel( ...@@ -137,7 +128,7 @@ def create_vector_kernel(
func = create_cpu_kernel_function( func = create_cpu_kernel_function(
ctx, ctx,
setup.platform, platform,
PsBlock([loop_nest]), PsBlock([loop_nest]),
"vector_kernel", "vector_kernel",
Target.CPU, Target.CPU,
......