diff --git a/docs/source/api/kernelcreation/index.rst b/docs/source/api/kernelcreation/index.rst index 582ee16682f913cd8e1fac5a97d90eda505588f7..f455ab9604bf14ce8cbda4ccbca3cc229648d411 100644 --- a/docs/source/api/kernelcreation/index.rst +++ b/docs/source/api/kernelcreation/index.rst @@ -4,10 +4,15 @@ Kernel Creation The primary interface for creating numerical kernels in pystencils is the function `create_kernel`. +Configuration +============= -.. autoclass:: pystencils.CreateKernelConfig +.. automodule:: pystencils.config :members: +Creation +======== + .. autofunction:: pystencils.create_kernel .. autoclass:: pystencils.backend.KernelFunction diff --git a/docs/source/backend/ast.rst b/docs/source/backend/ast.rst index 0d0d794ab28ea230effa0b871e64c77d6a79dec2..41f23016664002fd100544d72c509f9f73d72bdd 100644 --- a/docs/source/backend/ast.rst +++ b/docs/source/backend/ast.rst @@ -2,11 +2,29 @@ Abstract Syntax Tree ******************** +Inheritance Diagramm +==================== + +.. inheritance-diagram:: pystencils.backend.ast.astnode.PsAstNode pystencils.backend.ast.structural pystencils.backend.ast.expressions pystencils.backend.extensions.foreign_ast + :top-classes: pystencils.types.PsAstNode + :parts: 1 + + +Base Classes +============ + .. automodule:: pystencils.backend.ast.astnode :members: +Structural Nodes +================ + .. automodule:: pystencils.backend.ast.structural :members: + +Expressions +=========== + .. automodule:: pystencils.backend.ast.expressions :members: diff --git a/docs/source/backend/extensions.rst b/docs/source/backend/extensions.rst new file mode 100644 index 0000000000000000000000000000000000000000..6fb95cda06b0cc1a7440764fa2f2fe10cd1f84c1 --- /dev/null +++ b/docs/source/backend/extensions.rst @@ -0,0 +1,5 @@ +************************************ +Extensions and Experimental Features +************************************ + +.. automodule:: pystencils.backend.extensions diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index e0e914b4d423fb5b9e32950185c6aa3474976d39..f2fe9346dbe4d38722b69dd9c279d0eb11c98773 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -16,6 +16,7 @@ who wish to customize or extend the behaviour of the code generator in their app platforms transformations jit + extensions Internal Representation ----------------------- diff --git a/mypy.ini b/mypy.ini index 07228fe24009da6ea4f21cb6cdf15a0516041149..e89adf9f5eaa918753f94ffbf6ba00b1be6e39cd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -19,3 +19,6 @@ ignore_missing_imports=true [mypy-islpy.*] ignore_missing_imports=true + +[mypy-cupy.*] +ignore_missing_imports=true diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 1761c88f0e69c06f30935007d187a3af947c900f..f5cb3e10b646dd6ead4673827d0166c9f0f9e5ea 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -13,10 +13,12 @@ from .config import ( CpuOptimConfig, VectorizationConfig, OpenMpConfig, + GpuIndexingConfig, ) from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel, create_staggered_kernel from .backend.kernelfunction import KernelFunction +from .backend.jit import no_jit from .slicing import make_slice from .spatial_coordinates import ( x_, @@ -47,11 +49,13 @@ __all__ = [ "CreateKernelConfig", "CpuOptimConfig", "VectorizationConfig", + "GpuIndexingConfig", "OpenMpConfig", "create_kernel", "create_staggered_kernel", "KernelFunction", "Target", + "no_jit", "show_code", "to_dot", "get_code_obj", diff --git a/src/pystencils/backend/__init__.py b/src/pystencils/backend/__init__.py index 646fc3055a68c4ca3ff63035b2a72a80591e51a4..a0b1c8f747984e3fffde5a336f40e2aa46ad631d 100644 --- a/src/pystencils/backend/__init__.py +++ b/src/pystencils/backend/__init__.py @@ -5,6 +5,7 @@ from .kernelfunction import ( FieldStrideParam, FieldPointerParam, KernelFunction, + GpuKernelFunction, ) from .constraints import KernelParamsConstraint @@ -16,5 +17,6 @@ __all__ = [ "FieldStrideParam", "FieldPointerParam", "KernelFunction", + "GpuKernelFunction", "KernelParamsConstraint", ] diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 3b76e514e566c9c7113c1bbcaf1b63df71a8cb57..5f9c95d5d53d824620b030f7a618b79e8b81564f 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -761,3 +761,33 @@ class PsArrayInitList(PsExpression): def __repr__(self) -> str: return f"PsArrayInitList({repr(self._items)})" + + +def evaluate_expression( + expr: PsExpression, valuation: dict[str, Any] +) -> Any: + """Evaluate a pystencils backend expression tree with values assigned to symbols according to the given valuation. + + Only a subset of expression nodes can be processed by this evaluator. + """ + + def visit(node): + match node: + case PsSymbolExpr(symb): + return valuation[symb.name] + + case PsConstantExpr(c): + return c.value + + case PsUnOp(op1) if node.python_operator is not None: + return node.python_operator(visit(op1)) + + case PsBinOp(op1, op2) if node.python_operator is not None: + return node.python_operator(visit(op1), visit(op2)) + + case other: + raise NotImplementedError( + f"Unable to evaluate {other}: No implementation available." + ) + + return visit(expr) diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 0f220149d2913b218f3a41bb8ea5b35018aa251d..8928cc6894284a417341f0b12a9e8d6a07f8ba48 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -1,6 +1,8 @@ from __future__ import annotations from enum import Enum +from ..enums import Target + from .ast.structural import ( PsAstNode, PsBlock, @@ -50,10 +52,12 @@ from .ast.expressions import ( PsLe, ) +from .extensions.foreign_ast import PsForeignExpression + from .symbols import PsSymbol from ..types import PsScalarType, PsArrayType -from .kernelfunction import KernelFunction +from .kernelfunction import KernelFunction, GpuKernelFunction __all__ = ["emit_code", "CAstPrinter"] @@ -176,10 +180,11 @@ class CAstPrinter: return self.visit(obj, PrinterCtx()) def print_signature(self, func: KernelFunction) -> str: + prefix = self._func_prefix(func) params_str = ", ".join( f"{p.dtype.c_string()} {p.name}" for p in func.parameters ) - signature = f"FUNC_PREFIX void {func.name} ({params_str})" + signature = " ".join([prefix, "void", func.name, f"({params_str})"]) return signature def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: @@ -356,9 +361,21 @@ class CAstPrinter: pc.pop_op() return "{ " + items_str + " }" + case PsForeignExpression(children): + pc.push_op(Ops.Weakest, LR.Middle) + foreign_code = node.get_code(self.visit(c, pc) for c in children) + pc.pop_op() + return foreign_code + case _: raise NotImplementedError(f"Don't know how to print {node}") + def _func_prefix(self, func: KernelFunction): + if isinstance(func, GpuKernelFunction) and func.target == Target.CUDA: + return "__global__" + else: + return "FUNC_PREFIX" + def _symbol_decl(self, symb: PsSymbol): dtype = symb.get_dtype() diff --git a/src/pystencils/backend/exceptions.py b/src/pystencils/backend/exceptions.py index 4c081224913bfc61c4f542501ce8f4b5a1ddc59c..d42f7c11fdc4dc13b7c520119057336da3b6e3e2 100644 --- a/src/pystencils/backend/exceptions.py +++ b/src/pystencils/backend/exceptions.py @@ -5,10 +5,6 @@ class PsInternalCompilerError(Exception): """Indicates an internal error during kernel translation, most likely due to a bug inside pystencils.""" -class PsOptionsError(Exception): - """Indicates an option clash in the `CreateKernelConfig`.""" - - class PsInputError(Exception): """Indicates unsupported user input to the translation system""" diff --git a/src/pystencils/backend/extensions/__init__.py b/src/pystencils/backend/extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c699a3ff02c5ff311c872c70a861a52fea13419 --- /dev/null +++ b/src/pystencils/backend/extensions/__init__.py @@ -0,0 +1,29 @@ +""" +The module `pystencils.backend.extensions` contains extensions to the pystencils code generator +beyond its core functionality. + +The tools and classes of this module are considered experimental; +their support by the remaining code generator is limited. +They can be used to model and generate code outside of the usual scope of pystencils, +such as non-standard syntax and types. +At the moment, the primary use case is the modelling of C++ syntax. + + +Foreign Syntax Support +====================== + +.. automodule:: pystencils.backend.extensions.foreign_ast + :members: + + +C++ Language Support +==================== + +.. automodule:: pystencils.backend.extensions.cpp + :members: + +""" + +from .foreign_ast import PsForeignExpression + +__all__ = ["PsForeignExpression"] diff --git a/src/pystencils/backend/extensions/cpp.py b/src/pystencils/backend/extensions/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..1055b79e9ab197d62c4307b70ac5b2a71c13f139 --- /dev/null +++ b/src/pystencils/backend/extensions/cpp.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from typing import Iterable, cast + +from pystencils.backend.ast.astnode import PsAstNode + +from ..ast.expressions import PsExpression +from .foreign_ast import PsForeignExpression +from ...types import PsType + + +class CppMethodCall(PsForeignExpression): + """C++ method call on an expression.""" + + def __init__( + self, obj: PsExpression, method: str, return_type: PsType, args: Iterable = () + ): + self._method = method + self._return_type = return_type + children = [obj] + list(args) + super().__init__(children, return_type) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, CppMethodCall): + return False + + return super().structurally_equal(other) and self._method == other._method + + def clone(self) -> CppMethodCall: + return CppMethodCall( + cast(PsExpression, self.children[0]), + self._method, + self._return_type, + self.children[1:], + ) + + def get_code(self, children_code: Iterable[str]) -> str: + cs = list(children_code) + obj_code = cs[0] + args_code = cs[1:] + args = ", ".join(args_code) + return f"({obj_code}).{self._method}({args})" diff --git a/src/pystencils/backend/extensions/foreign_ast.py b/src/pystencils/backend/extensions/foreign_ast.py new file mode 100644 index 0000000000000000000000000000000000000000..67362ce37e050558d16e06ebc6dbb6b30ff5c6e1 --- /dev/null +++ b/src/pystencils/backend/extensions/foreign_ast.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from typing import Iterable +from abc import ABC, abstractmethod + +from pystencils.backend.ast.astnode import PsAstNode + +from ..ast.expressions import PsExpression +from ..ast.util import failing_cast +from ...types import PsType + + +class PsForeignExpression(PsExpression, ABC): + """Base class for foreign expressions. + + Foreign expressions are expressions whose properties are not modelled by the pystencils AST, + and which pystencils therefore does not understand. + + There are many situations where non-supported expressions are needed; + the most common use case is C++ syntax. + Support for foreign expressions by the code generator is therefore very limited; + as a rule of thumb, only printing is supported. + Type checking and most transformations will fail when encountering a `PsForeignExpression`. + """ + + __match_args__ = ("children",) + + def __init__(self, children: Iterable[PsExpression], dtype: PsType | None = None): + self._children = list(children) + super().__init__(dtype) + + @abstractmethod + def get_code(self, children_code: Iterable[str]) -> str: + """Print this expression, with the given code for each of its children.""" + pass + + def get_children(self) -> tuple[PsAstNode, ...]: + return tuple(self._children) + + def set_child(self, idx: int, c: PsAstNode): + self._children[idx] = failing_cast(PsExpression, c) + + def __repr__(self) -> str: + return f"{type(self)}({self._children})" diff --git a/src/pystencils/backend/jit/__init__.py b/src/pystencils/backend/jit/__init__.py index 7938f70831bb077bd5bef0ecafaba92729b71a42..edc755dfd18e7ef27bcef04e34e44222752e93f6 100644 --- a/src/pystencils/backend/jit/__init__.py +++ b/src/pystencils/backend/jit/__init__.py @@ -26,15 +26,18 @@ Both are available here through `LegacyCpuJit` and `LegacyGpuJit`. """ -from .jit import JitBase, NoJit, LegacyCpuJit, LegacyGpuJit +from .jit import JitBase, NoJit, KernelWrapper, LegacyCpuJit, LegacyGpuJit +from .gpu_cupy import CupyJit no_jit = NoJit() """Disables just-in-time compilation for a kernel.""" __all__ = [ "JitBase", + "KernelWrapper", "LegacyCpuJit", "NoJit", "no_jit", "LegacyGpuJit", + "CupyJit", ] diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9b7b8a9ee199401cac05247fdb4627f9ae256e --- /dev/null +++ b/src/pystencils/backend/jit/gpu_cupy.py @@ -0,0 +1,252 @@ +from typing import Callable, Any +from dataclasses import dataclass + +try: + import cupy as cp + + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False + +from ...enums import Target +from ...field import FieldType + +from ...types import PsType +from .jit import JitBase, JitError, KernelWrapper +from ..kernelfunction import ( + KernelFunction, + GpuKernelFunction, + FieldPointerParam, + FieldShapeParam, + FieldStrideParam, + KernelParameter, +) +from ..emission import emit_code + +from ...include import get_pystencils_include_path + + +@dataclass +class LaunchGrid: + grid: tuple[int, int, int] + block: tuple[int, int, int] + + +class CupyKernelWrapper(KernelWrapper): + def __init__( + self, + kfunc: GpuKernelFunction, + raw_kernel: Any, + block_size: tuple[int, int, int], + ): + self._kfunc: GpuKernelFunction = kfunc + self._raw_kernel = raw_kernel + self._block_size = block_size + self._args_cache: dict[Any, tuple] = dict() + + @property + def kernel_function(self) -> GpuKernelFunction: + return self._kfunc + + @property + def raw_kernel(self): + return self._raw_kernel + + @property + def block_size(self) -> tuple[int, int, int]: + return self._block_size + + @block_size.setter + def block_size(self, bs: tuple[int, int, int]): + self._block_size = bs + + def __call__(self, **kwargs: Any): + kernel_args, launch_grid = self._get_cached_args(**kwargs) + device = self._get_device(kernel_args) + with cp.cuda.Device(device): + self._raw_kernel(launch_grid.grid, launch_grid.block, kernel_args) + + def _get_device(self, kernel_args): + devices = set(a.device.id for a in kernel_args if type(a) is cp.ndarray) + if len(devices) != 1: + raise JitError("Could not determine CUDA device to execute on") + return devices.pop() + + def _get_cached_args(self, **kwargs): + key = tuple( + ( + (k, v.dtype, v.strides, v.shape) + if isinstance(v, cp.ndarray) + else (k, id(v)) + ) + for k, v in kwargs.items() + ) + + if key not in self._args_cache: + args = self._get_args(**kwargs) + self._args_cache[key] = args + return args + else: + return self._args_cache[key] + + def _get_args(self, **kwargs) -> tuple[tuple, LaunchGrid]: + args = [] + valuation: dict[str, Any] = dict() + + def add_arg(name: str, arg: Any, dtype: PsType): + nptype = dtype.numpy_dtype + assert nptype is not None + typecast = nptype.type + arg = typecast(arg) + args.append(arg) + valuation[name] = arg + + field_shapes = set() + index_shapes = set() + + def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray): + field = field_ptr.field + + if field.has_fixed_shape: + expected_shape = tuple(int(s) for s in field.shape) + actual_shape = arr.shape + if expected_shape != actual_shape: + raise ValueError( + f"Array kernel argument {field.name} had unexpected shape:\n" + f" Expected {expected_shape}, but got {actual_shape}" + ) + + expected_strides = tuple(int(s) for s in field.strides) + actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides) + if expected_strides != actual_strides: + raise ValueError( + f"Array kernel argument {field.name} had unexpected strides:\n" + f" Expected {expected_strides}, but got {actual_strides}" + ) + + match field.field_type: + case FieldType.GENERIC: + field_shapes.add(arr.shape) + + if len(field_shapes) > 1: + raise ValueError( + "Incompatible array shapes:" + "All arrays passed for generic fields to a kernel must have the same shape." + ) + + case FieldType.INDEXED: + index_shapes.add(arr.shape) + + if len(index_shapes) > 1: + raise ValueError( + "Incompatible array shapes:" + "All arrays passed for index fields to a kernel must have the same shape." + ) + + # Collect parameter values + arr: cp.ndarray + + for kparam in self._kfunc.parameters: + match kparam: + case FieldPointerParam(_, dtype, field): + arr = kwargs[field.name] + if arr.dtype != field.dtype.numpy_dtype: + raise JitError( + f"Data type mismatch at array argument {field.name}:" + f"Expected {field.dtype}, got {arr.dtype}" + ) + check_shape(kparam, arr) + args.append(arr) + + case FieldShapeParam(name, dtype, field, coord): + arr = kwargs[field.name] + add_arg(name, arr.shape[coord], dtype) + + case FieldStrideParam(name, dtype, field, coord): + arr = kwargs[field.name] + add_arg(name, arr.strides[coord] // arr.dtype.itemsize, dtype) + + case KernelParameter(name, dtype): + val: Any = kwargs[name] + add_arg(name, val, dtype) + + # Determine launch grid + from ..ast.expressions import evaluate_expression + + symbolic_threads_range = self._kfunc.threads_range + + threads_range: list[int] = [ + evaluate_expression(expr, valuation) + for expr in symbolic_threads_range.num_work_items + ] + + if symbolic_threads_range.dim < 3: + threads_range += [1] * (3 - symbolic_threads_range.dim) + + def div_ceil(a, b): + return a // b if a % b == 0 else a // b + 1 + + # TODO: Refine this? + grid_size = tuple( + div_ceil(threads, tpb) + for threads, tpb in zip(threads_range, self._block_size) + ) + assert len(grid_size) == 3 + + launch_grid = LaunchGrid(grid_size, self._block_size) + + return tuple(args), launch_grid + + +class CupyJit(JitBase): + + def __init__(self, default_block_size: tuple[int, int, int] = (128, 2, 1)): + self._runtime_headers = {"<cstdint>"} + self._default_block_size = default_block_size + + def compile(self, kfunc: KernelFunction) -> Callable[..., None]: + if not HAVE_CUPY: + raise JitError( + "`cupy` is not installed: just-in-time-compilation of CUDA kernels is unavailable." + ) + + if not isinstance(kfunc, GpuKernelFunction) or kfunc.target != Target.CUDA: + raise ValueError( + "The CupyJit just-in-time compiler only accepts kernels generated for CUDA or HIP" + ) + + options = self._compiler_options() + prelude = self._prelude(kfunc) + kernel_code = self._kernel_code(kfunc) + code = prelude + kernel_code + + raw_kernel = cp.RawKernel( + code, kfunc.name, options=options, backend="nvrtc", jitify=True + ) + return CupyKernelWrapper(kfunc, raw_kernel, self._default_block_size) + + def _compiler_options(self) -> tuple[str, ...]: + options = ["-w", "-std=c++11"] + options.append("-I" + get_pystencils_include_path()) + return tuple(options) + + def _prelude(self, kfunc: GpuKernelFunction) -> str: + headers = self._runtime_headers + headers |= kfunc.required_headers + + if '"half_precision.h"' in headers: + headers.remove('"half_precision.h"') + if cp.cuda.runtime.is_hip: + headers.add("<hip/hip_fp16.h>") + else: + headers.add("<cuda_fp16.h>") + + code = "\n".join(f"#include {header}" for header in headers) + + code += "\n\n#define RESTRICT __restrict__\n\n" + + return code + + def _kernel_code(self, kfunc: GpuKernelFunction) -> str: + kernel_code = emit_code(kfunc) + return f'extern "C" {kernel_code}' diff --git a/src/pystencils/backend/jit/jit.py b/src/pystencils/backend/jit/jit.py index d2c7bec9efd887d177499a2964ff068f2da868c4..b455c368054c26f7ed6b86893bcb9dca8b877daf 100644 --- a/src/pystencils/backend/jit/jit.py +++ b/src/pystencils/backend/jit/jit.py @@ -18,6 +18,21 @@ class JitBase(ABC): """Compile a kernel function and return a callable object which invokes the kernel.""" +class KernelWrapper: + def __init__(self, kfunc: KernelFunction) -> None: + self._kfunc = kfunc + + @property + def kernel_function(self) -> KernelFunction: + return self._kfunc + + @property + def code(self) -> str: + from pystencils.display_utils import get_code_str + + return get_code_str(str) + + class NoJit(JitBase): """Not a JIT compiler: Used to explicitly disable JIT compilation on an AST.""" diff --git a/src/pystencils/backend/kernelcreation/__init__.py b/src/pystencils/backend/kernelcreation/__init__.py index 5de83caadb3b4aa50112ef2b65c28c1ca7932aae..abba9d9d8d571fa7540f82807d009e02d522849f 100644 --- a/src/pystencils/backend/kernelcreation/__init__.py +++ b/src/pystencils/backend/kernelcreation/__init__.py @@ -5,6 +5,7 @@ from .typification import Typifier from .ast_factory import AstFactory from .iteration_space import ( + IterationSpace, FullIterationSpace, SparseIterationSpace, create_full_iteration_space, @@ -19,6 +20,7 @@ __all__ = [ "FreezeExpressions", "Typifier", "AstFactory", + "IterationSpace", "FullIterationSpace", "SparseIterationSpace", "create_full_iteration_space", diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py index 29b133ff164e856783f14eb83357c8382db9ba5d..46fef660303fe9762383659883e62b9f2178bc71 100644 --- a/src/pystencils/backend/kernelcreation/cpu_optimization.py +++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py @@ -1,12 +1,14 @@ from __future__ import annotations -from typing import cast +from typing import cast, TYPE_CHECKING from .context import KernelCreationContext -from ..platforms import GenericCpu from ..ast.structural import PsBlock from ...config import CpuOptimConfig, OpenMpConfig +if TYPE_CHECKING: + from ..platforms import GenericCpu + def optimize_cpu( ctx: KernelCreationContext, diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 56e58648966fab4e60b4eea64ab5442b92f91709..bfebd5af6925f560eae9bbddc7a75f41d2e5a876 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -46,7 +46,7 @@ class IterationSpace(ABC): return self._spatial_indices @property - def dim(self) -> int: + def rank(self) -> int: return len(self._spatial_indices) @@ -209,7 +209,20 @@ class FullIterationSpace(IterationSpace): def archetype_field(self) -> Field | None: return self._archetype_field - def actual_iterations(self, dimension: int | None = None) -> PsExpression: + def dimensions_in_loop_order(self) -> Sequence[FullIterationSpace.Dimension]: + """Return the dimensions of this iteration space ordered from the slowest to the fastest coordinate. + + If an archetype field is specified, the field layout is used to determine the ideal loop order; + otherwise, the dimensions are returned as they are + """ + if self._archetype_field is not None: + return [self._dimensions[i] for i in self._archetype_field.layout] + else: + return self._dimensions + + def actual_iterations( + self, dimension: int | FullIterationSpace.Dimension | None = None + ) -> PsExpression: from .typification import Typifier from ..transformations import EliminateConstants @@ -229,7 +242,10 @@ class FullIterationSpace(IterationSpace): ) ) else: - dim = self.dimensions[dimension] + if isinstance(dimension, FullIterationSpace.Dimension): + dim = dimension + else: + dim = self.dimensions[dimension] one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype)) zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype)) return fold( @@ -246,7 +262,7 @@ class FullIterationSpace(IterationSpace): """Expression counting the actual number of items processed at the iteration defined by the counter tuple. Used primarily for indexing buffers.""" - actual_iters = [self.actual_iterations(d) for d in range(self.dim)] + actual_iters = [self.actual_iterations(d) for d in range(self.rank)] compressed_counters = [ (PsExpression.make(dim.counter) - dim.start) / dim.step for dim in self.dimensions diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 985f0bfa30cd1a7f3e31d4d7f99964d59a9f4e9e..32510731c9773f6fce8395c145b6eeaf3bb55b63 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -2,15 +2,19 @@ from __future__ import annotations from warnings import warn from abc import ABC -from typing import Callable, Sequence, Any +from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING from .._deprecation import _deprecated from .ast.structural import PsBlock +from .ast.analysis import collect_required_headers, collect_undefined_symbols +from .arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer +from .symbols import PsSymbol +from .kernelcreation.context import KernelCreationContext +from .platforms import Platform, GpuThreadsRange from .constraints import KernelParamsConstraint from ..types import PsType -from .jit import JitBase, no_jit from ..enums import Target from ..field import Field @@ -21,6 +25,9 @@ from ..sympyextensions.typed_sympy import ( FieldPointerSymbol, ) +if TYPE_CHECKING: + from .jit import JitBase + class KernelParameter: __match_args__ = ("name", "dtype") @@ -196,7 +203,7 @@ class KernelFunction: parameters: Sequence[KernelParameter], required_headers: set[str], constraints: Sequence[KernelParamsConstraint], - jit: JitBase = no_jit, + jit: JitBase, ): self._body: PsBlock = body self._target = target @@ -267,3 +274,102 @@ class KernelFunction: def compile(self) -> Callable[..., None]: return self._jit.compile(self) + + +def create_cpu_kernel_function( + ctx: KernelCreationContext, + platform: Platform, + body: PsBlock, + function_name: str, + target_spec: Target, + jit: JitBase, +): + undef_symbols = collect_undefined_symbols(body) + + params = _get_function_params(ctx, undef_symbols) + req_headers = _get_headers(ctx, platform, body) + + kfunc = KernelFunction( + body, target_spec, function_name, params, req_headers, ctx.constraints, jit + ) + kfunc.metadata.update(ctx.metadata) + return kfunc + + +class GpuKernelFunction(KernelFunction): + def __init__( + self, + body: PsBlock, + threads_range: GpuThreadsRange, + target: Target, + name: str, + parameters: Sequence[KernelParameter], + required_headers: set[str], + constraints: Sequence[KernelParamsConstraint], + jit: JitBase, + ): + super().__init__( + body, target, name, parameters, required_headers, constraints, jit + ) + self._threads_range = threads_range + + @property + def threads_range(self) -> GpuThreadsRange: + return self._threads_range + + +def create_gpu_kernel_function( + ctx: KernelCreationContext, + platform: Platform, + body: PsBlock, + threads_range: GpuThreadsRange, + function_name: str, + target_spec: Target, + jit: JitBase, +): + undef_symbols = collect_undefined_symbols(body) + for threads in threads_range.num_work_items: + undef_symbols |= collect_undefined_symbols(threads) + + params = _get_function_params(ctx, undef_symbols) + req_headers = _get_headers(ctx, platform, body) + + kfunc = GpuKernelFunction( + body, + threads_range, + target_spec, + function_name, + params, + req_headers, + ctx.constraints, + jit, + ) + kfunc.metadata.update(ctx.metadata) + return kfunc + + +def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): + params: list[KernelParameter] = [] + for symb in symbols: + match symb: + case PsArrayShapeSymbol(name, _, arr, coord): + field = ctx.find_field(arr.name) + params.append(FieldShapeParam(name, symb.get_dtype(), field, coord)) + case PsArrayStrideSymbol(name, _, arr, coord): + field = ctx.find_field(arr.name) + params.append(FieldStrideParam(name, symb.get_dtype(), field, coord)) + case PsArrayBasePointer(name, _, arr): + field = ctx.find_field(arr.name) + params.append(FieldPointerParam(name, symb.get_dtype(), field)) + case PsSymbol(name, _): + params.append(KernelParameter(name, symb.get_dtype())) + + params.sort(key=lambda p: p.name) + return params + + +def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock): + req_headers = collect_required_headers(body) + req_headers |= platform.required_headers + req_headers |= ctx.required_headers + return req_headers diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 0b816bf9396bcc0b746d8858417eb087de0ad46f..9332453c6c1b60255f1869f011bfa661ee670ea0 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -1,7 +1,9 @@ from .platform import Platform from .generic_cpu import GenericCpu, GenericVectorCpu -from .generic_gpu import GenericGpu +from .generic_gpu import GenericGpu, GpuThreadsRange +from .cuda import CudaPlatform from .x86 import X86VectorCpu, X86VectorArch +from .sycl import SyclPlatform __all__ = [ "Platform", @@ -10,4 +12,7 @@ __all__ = [ "X86VectorCpu", "X86VectorArch", "GenericGpu", + "GpuThreadsRange", + "CudaPlatform", + "SyclPlatform", ] diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..4233784825659362bb028f40d5f15d1c28402be1 --- /dev/null +++ b/src/pystencils/backend/platforms/cuda.py @@ -0,0 +1,168 @@ +from ...types import constify +from ..exceptions import MaterializationError +from .generic_gpu import GenericGpu, GpuThreadsRange + +from ..kernelcreation import ( + Typifier, + IterationSpace, + FullIterationSpace, + SparseIterationSpace, +) + +from ..kernelcreation.context import KernelCreationContext +from ..ast.structural import PsBlock, PsConditional, PsDeclaration +from ..ast.expressions import PsExpression, PsLiteralExpr, PsCast, PsCall +from ..ast.expressions import PsLt, PsAnd +from ...types import PsSignedIntegerType, PsIeeeFloatType +from ..literals import PsLiteral +from ..functions import PsMathFunction, MathFunctions, CFunction +from ...config import GpuIndexingConfig + +int32 = PsSignedIntegerType(width=32, const=False) + +BLOCK_IDX = [ + PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") +] +THREAD_IDX = [ + PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") +] +BLOCK_DIM = [ + PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") +] +GRID_DIM = [ + PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") +] + + +class CudaPlatform(GenericGpu): + + def __init__( + self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None + ) -> None: + super().__init__(ctx) + self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig() + self._typify = Typifier(ctx) + + @property + def required_headers(self) -> set[str]: + return {'"gpu_defines.h"'} + + def materialize_iteration_space( + self, body: PsBlock, ispace: IterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + if isinstance(ispace, FullIterationSpace): + return self._prepend_dense_translation(body, ispace) + elif isinstance(ispace, SparseIterationSpace): + return self._prepend_sparse_translation(body, ispace) + else: + raise MaterializationError(f"Unknown type of iteration space: {ispace}") + + def select_function(self, call: PsCall) -> PsExpression: + assert isinstance(call.function, PsMathFunction) + + func = call.function.func + dtype = call.get_dtype() + arg_types = (dtype,) * func.num_args + + if isinstance(dtype, PsIeeeFloatType): + match func: + case ( + MathFunctions.Exp + | MathFunctions.Log + | MathFunctions.Sin + | MathFunctions.Cos + | MathFunctions.Ceil + | MathFunctions.Floor + ) if dtype.width in (16, 32, 64): + prefix = "h" if dtype.width == 16 else "" + suffix = "f" if dtype.width == 32 else "" + name = f"{prefix}{func.function_name}{suffix}" + cfunc = CFunction(name, arg_types, dtype) + + case ( + MathFunctions.Pow + | MathFunctions.Tan + | MathFunctions.Sinh + | MathFunctions.Cosh + | MathFunctions.ASin + | MathFunctions.ACos + | MathFunctions.ATan + | MathFunctions.ATan2 + ) if dtype.width in (32, 64): + # These are unavailable for fp16 + suffix = "f" if dtype.width == 32 else "" + name = f"{func.function_name}{suffix}" + cfunc = CFunction(name, arg_types, dtype) + + case ( + MathFunctions.Min | MathFunctions.Max | MathFunctions.Abs + ) if dtype.width in (32, 64): + suffix = "f" if dtype.width == 32 else "" + name = f"f{func.function_name}{suffix}" + cfunc = CFunction(name, arg_types, dtype) + + case MathFunctions.Abs if dtype.width == 16: + cfunc = CFunction(" __habs", arg_types, dtype) + + call.function = cfunc + return call + + raise MaterializationError( + f"No implementation available for function {func} on data type {dtype}" + ) + + # Internals + + def _prepend_dense_translation( + self, body: PsBlock, ispace: FullIterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + dimensions = ispace.dimensions_in_loop_order() + launch_config = GpuThreadsRange.from_ispace(ispace) + + indexing_decls = [] + conds = [] + for i, dim in enumerate(dimensions[::-1]): + dim.counter.dtype = constify(dim.counter.get_dtype()) + + ctr = PsExpression.make(dim.counter) + indexing_decls.append( + self._typify( + PsDeclaration( + ctr, + dim.start + + dim.step + * PsCast(ctr.get_dtype(), self._linear_thread_idx(i)), + ) + ) + ) + if not self._cfg.omit_range_check: + conds.append(PsLt(ctr, dim.stop)) + + if conds: + condition: PsExpression = conds[0] + for cond in conds[1:]: + condition = PsAnd(condition, cond) + ast = PsBlock(indexing_decls + [PsConditional(condition, body)]) + else: + body.statements = indexing_decls + body.statements + ast = body + + return ast, launch_config + + def _prepend_sparse_translation( + self, body: PsBlock, ispace: SparseIterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) + + ctr = PsExpression.make(ispace.sparse_counter) + thread_idx = self._linear_thread_idx(0) + idx_decl = self._typify(PsDeclaration(ctr, PsCast(ctr.get_dtype(), thread_idx))) + body.statements = [idx_decl] + body.statements + + return body, GpuThreadsRange.from_ispace(ispace) + + def _linear_thread_idx(self, coord: int): + block_size = BLOCK_DIM[coord] + block_idx = BLOCK_IDX[coord] + thread_idx = THREAD_IDX[coord] + return block_idx * block_size + thread_idx diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 25228ba8fba81f84d45844ea23bd8813fd15ce73..a1505e672862264a50a8b035a83dc8dcdfb0769d 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -54,7 +54,7 @@ class GenericCpu(Platform): elif isinstance(ispace, SparseIterationSpace): return self._create_sparse_loop(body, ispace) else: - assert False, "unreachable code" + raise MaterializationError(f"Unknown type of iteration space: {ispace}") def select_function(self, call: PsCall) -> PsExpression: assert isinstance(call.function, PsMathFunction) diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 1403b8f5ca1329812749bcc9d9bd50d4fcf4ac98..774b9405cd04b8dc6489cd6b6ae36e4aa563f157 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,85 +1,68 @@ -from .platform import Platform +from __future__ import annotations +from typing import Sequence +from abc import abstractmethod +from ..ast.expressions import PsExpression +from ..ast.structural import PsBlock from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, - # SparseIterationSpace, -) - -from ..ast.structural import PsBlock, PsConditional -from ..ast.expressions import ( - PsExpression, - PsLiteralExpr, - PsAdd, - PsCall + SparseIterationSpace, ) -from ..ast.expressions import PsLt, PsAnd -from ...types import PsSignedIntegerType -from ..literals import PsLiteral - -int32 = PsSignedIntegerType(width=32, const=False) - -BLOCK_IDX = [ - PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") -] -THREAD_IDX = [ - PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") -] -BLOCK_DIM = [ - PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") -] -GRID_DIM = [ - PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") -] +from .platform import Platform -class GenericGpu(Platform): +class GpuThreadsRange: + """Number of threads required by a GPU kernel, in order (x, y, z).""" - @property - def required_headers(self) -> set[str]: - return {"gpu_defines.h"} - - def materialize_iteration_space( - self, body: PsBlock, ispace: IterationSpace - ) -> PsBlock: + @staticmethod + def from_ispace(ispace: IterationSpace) -> GpuThreadsRange: if isinstance(ispace, FullIterationSpace): - return self._guard_full_iteration_space(body, ispace) + return GpuThreadsRange._from_full_ispace(ispace) + elif isinstance(ispace, SparseIterationSpace): + work_items = (PsExpression.make(ispace.index_list.shape[0]),) + return GpuThreadsRange(work_items) else: - assert False, "unreachable code" + assert False - def cuda_indices(self, dim): - block_size = BLOCK_DIM - indices = [ - block_index * bs + thread_idx - for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX) - ] + def __init__( + self, + num_work_items: Sequence[PsExpression], + ): + self._dim = len(num_work_items) + self._num_work_items = tuple(num_work_items) - return indices[:dim] + # @property + # def grid_size(self) -> tuple[PsExpression, ...]: + # return self._grid_size - def select_function(self, call: PsCall) -> PsExpression: - raise NotImplementedError() + # @property + # def block_size(self) -> tuple[PsExpression, ...]: + # return self._block_size - # Internals - def _guard_full_iteration_space( - self, body: PsBlock, ispace: FullIterationSpace - ) -> PsBlock: - - dimensions = ispace.dimensions + @property + def num_work_items(self) -> tuple[PsExpression, ...]: + """Number of work items in (x, y, z)-order.""" + return self._num_work_items - # Determine loop order by permuting dimensions - archetype_field = ispace.archetype_field - if archetype_field is not None: - loop_order = archetype_field.layout - dimensions = [dimensions[coordinate] for coordinate in loop_order] + @property + def dim(self) -> int: + return self._dim - start = [ - PsAdd(c, d.start) - for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1]) - ] - conditions = [PsLt(c, d.stop) for c, d in zip(start, dimensions[::-1])] + @staticmethod + def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange: + dimensions = ispace.dimensions_in_loop_order()[::-1] + if len(dimensions) > 3: + raise NotImplementedError( + f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space" + ) + work_items = [ispace.actual_iterations(dim) for dim in dimensions] + return GpuThreadsRange(work_items) - condition: PsExpression = conditions[0] - for c in conditions[1:]: - condition = PsAnd(condition, c) - return PsBlock([PsConditional(condition, body)]) +class GenericGpu(Platform): + @abstractmethod + def materialize_iteration_space( + self, block: PsBlock, ispace: IterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + pass diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 3f8912e81c6f42ba776dfd5e9cd7d8895f93ae4b..cab4d0a7b0143eabd4e40602b45ef2f66592ee8c 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any from ..ast.structural import PsBlock from ..ast.expressions import PsCall, PsExpression @@ -27,7 +28,7 @@ class Platform(ABC): @abstractmethod def materialize_iteration_space( self, block: PsBlock, ispace: IterationSpace - ) -> PsBlock: + ) -> PsBlock | tuple[PsBlock, Any]: pass @abstractmethod diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py new file mode 100644 index 0000000000000000000000000000000000000000..52953115ab3a50139c6bfdf047cc8f9fa53714d8 --- /dev/null +++ b/src/pystencils/backend/platforms/sycl.py @@ -0,0 +1,171 @@ +from ..functions import CFunction, PsMathFunction, MathFunctions +from ..kernelcreation.iteration_space import ( + IterationSpace, + FullIterationSpace, + SparseIterationSpace, +) +from ..ast.structural import PsDeclaration, PsBlock, PsConditional +from ..ast.expressions import PsExpression, PsSymbolExpr, PsSubscript, PsLt, PsAnd, PsCall, PsGe, PsLe, PsTernary +from ..extensions.cpp import CppMethodCall + +from ..kernelcreation.context import KernelCreationContext +from ..constants import PsConstant +from .generic_gpu import GenericGpu, GpuThreadsRange +from ..exceptions import MaterializationError +from ...types import PsCustomType, PsIeeeFloatType, constify, PsIntegerType +from ...config import GpuIndexingConfig + + +class SyclPlatform(GenericGpu): + + def __init__( + self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None + ): + super().__init__(ctx) + self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig() + + @property + def required_headers(self) -> set[str]: + return {"<sycl/sycl.hpp>"} + + def materialize_iteration_space( + self, body: PsBlock, ispace: IterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + if isinstance(ispace, FullIterationSpace): + return self._prepend_dense_translation(body, ispace) + elif isinstance(ispace, SparseIterationSpace): + return self._prepend_sparse_translation(body, ispace) + else: + raise MaterializationError(f"Unknown type of iteration space: {ispace}") + + def select_function(self, call: PsCall) -> PsExpression: + assert isinstance(call.function, PsMathFunction) + + func = call.function.func + dtype = call.get_dtype() + arg_types = (dtype,) * func.num_args + + if isinstance(dtype, PsIeeeFloatType) and dtype.width in (16, 32, 64): + match func: + case ( + MathFunctions.Exp + | MathFunctions.Log + | MathFunctions.Sin + | MathFunctions.Cos + | MathFunctions.Tan + | MathFunctions.Sinh + | MathFunctions.Cosh + | MathFunctions.ASin + | MathFunctions.ACos + | MathFunctions.ATan + | MathFunctions.ATan2 + | MathFunctions.Pow + | MathFunctions.Floor + | MathFunctions.Ceil + ): + cfunc = CFunction(f"sycl::{func.function_name}", arg_types, dtype) + + case MathFunctions.Abs | MathFunctions.Min | MathFunctions.Max: + cfunc = CFunction(f"sycl::f{func.function_name}", arg_types, dtype) + + call.function = cfunc + return call + + if isinstance(dtype, PsIntegerType): + match func: + case MathFunctions.Abs: + zero = PsExpression.make(PsConstant(0, dtype)) + arg = call.args[0] + return PsTernary(PsGe(arg, zero), arg, - arg) + case MathFunctions.Min: + arg1, arg2 = call.args + return PsTernary(PsLe(arg1, arg2), arg1, arg2) + case MathFunctions.Max: + arg1, arg2 = call.args + return PsTernary(PsGe(arg1, arg2), arg1, arg2) + + raise MaterializationError( + f"No implementation available for function {func} on data type {dtype}" + ) + + def _prepend_dense_translation( + self, body: PsBlock, ispace: FullIterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + rank = ispace.rank + id_type = self._id_type(rank) + id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) + id_decl = self._id_declaration(rank, id_symbol) + + dimensions = ispace.dimensions_in_loop_order() + launch_config = GpuThreadsRange.from_ispace(ispace) + + indexing_decls = [id_decl] + conds = [] + + # Other than in CUDA, SYCL ids are linearized in C order + # The leftmost entry of an ID varies slowest, and the rightmost entry varies fastest + # See https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:multi-dim-linearization + + for i, dim in enumerate(dimensions): + # Slowest to fastest + coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) + work_item_idx = PsSubscript(id_symbol, coord) + + dim.counter.dtype = constify(dim.counter.get_dtype()) + work_item_idx.dtype = dim.counter.get_dtype() + + ctr = PsExpression.make(dim.counter) + indexing_decls.append( + PsDeclaration(ctr, dim.start + work_item_idx * dim.step) + ) + if not self._cfg.omit_range_check: + conds.append(PsLt(ctr, dim.stop)) + + if conds: + condition: PsExpression = conds[0] + for cond in conds[1:]: + condition = PsAnd(condition, cond) + ast = PsBlock(indexing_decls + [PsConditional(condition, body)]) + else: + body.statements = indexing_decls + body.statements + ast = body + + return ast, launch_config + + def _prepend_sparse_translation( + self, body: PsBlock, ispace: SparseIterationSpace + ) -> tuple[PsBlock, GpuThreadsRange]: + id_type = PsCustomType("sycl::id< 1 >", const=True) + id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) + + zero = PsExpression.make(PsConstant(0, self._ctx.index_dtype)) + subscript = PsSubscript(id_symbol, zero) + + ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) + subscript.dtype = ispace.sparse_counter.get_dtype() + + ctr = PsExpression.make(ispace.sparse_counter) + unpacking = PsDeclaration(ctr, subscript) + body.statements = [unpacking] + body.statements + + return body, GpuThreadsRange.from_ispace(ispace) + + def _item_type(self, rank: int): + if not self._cfg.sycl_automatic_block_size: + return PsCustomType(f"sycl::nd_item< {rank} >", const=True) + else: + return PsCustomType(f"sycl::item< {rank} >", const=True) + + def _id_type(self, rank: int): + return PsCustomType(f"sycl::id< {rank} >", const=True) + + def _id_declaration(self, rank: int, id: PsSymbolExpr) -> PsDeclaration: + item_type = self._item_type(rank) + item = PsExpression.make(self._ctx.get_symbol("sycl_item", item_type)) + + if not self._cfg.sycl_automatic_block_size: + rhs = CppMethodCall(item, "get_global_id", self._id_type(rank)) + else: + rhs = CppMethodCall(item, "get_id", self._id_type(rank)) + + return PsDeclaration(id, rhs) diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 69eb418e310c85c3e9e566c8a31ce201fc2b9814..8e9909b31d2e8d13306794f9914f6a6a22c5170c 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import TYPE_CHECKING from warnings import warn from collections.abc import Collection @@ -9,12 +10,17 @@ from dataclasses import dataclass, InitVar from .enums import Target from .field import Field, FieldType -from .backend.jit import JitBase -from .backend.exceptions import PsOptionsError from .types import PsIntegerType, UserTypeSpec, PsIeeeFloatType from .defaults import DEFAULTS +if TYPE_CHECKING: + from .backend.jit import JitBase + + +class PsOptionsError(Exception): + """Indicates an option clash in the `CreateKernelConfig`.""" + @dataclass class OpenMpConfig: @@ -129,6 +135,32 @@ class VectorizationConfig: """ +@dataclass +class GpuIndexingConfig: + """Configure index translation behaviour for kernels generated for GPU targets.""" + + omit_range_check: bool = False + """If set to `True`, omit the iteration counter range check. + + By default, the code generator introduces a check if the iteration counters computed from GPU block and thread + indices are within the prescribed loop range. + This check can be discarded through this option, at your own peril. + """ + + sycl_automatic_block_size: bool = True + """If set to `True` while generating for `Target.SYCL`, let the SYCL runtime decide on the block size. + + If set to `True`, the kernel is generated for execution via + `parallel_for <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_parallel_for_invoke>`_ + -dispatch using + a flat `sycl::range`. In this case, the GPU block size will be inferred by the SYCL runtime. + + If set to `False`, the kernel will receive an `nd_item` and has to be executed using + `parallel_for <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_parallel_for_invoke>`_ + with an `nd_range`. This allows manual specification of the block size. + """ + + @dataclass class CreateKernelConfig: """Options for create_kernel.""" @@ -191,6 +223,12 @@ class CreateKernelConfig: If this parameter is set while `target` is a non-CPU target, an error will be raised. """ + gpu_indexing: None | GpuIndexingConfig = None + """Configure index translation for GPU kernels. + + It this parameter is set while `target` is not a GPU target, an error will be raised. + """ + # Deprecated Options data_type: InitVar[UserTypeSpec | None] = None @@ -245,12 +283,32 @@ class CreateKernelConfig: f"Cannot enable auto-vectorization for non-vector CPU target {self.target}" ) + if self.gpu_indexing is not None: + if self.target != Target.SYCL: + raise PsOptionsError( + f"`gpu_indexing` cannot be set for non-SYCL target {self.target}" + ) + # Infer JIT if self.jit is None: if self.target.is_cpu(): from .backend.jit import LegacyCpuJit self.jit = LegacyCpuJit() + elif self.target == Target.CUDA: + try: + from .backend.jit.gpu_cupy import CupyJit + + self.jit = CupyJit() + except ImportError: + from .backend.jit import no_jit + + self.jit = no_jit + + elif self.target == Target.SYCL: + from .backend.jit import no_jit + + self.jit = no_jit else: raise NotImplementedError( f"No default JIT compiler implemented yet for target {self.target}" diff --git a/src/pystencils/display_utils.py b/src/pystencils/display_utils.py index a2fa13adf2984c4676ae7d3e1cc565619a88ac3b..301cdef0f106fc2dbaef6f016b21e14c5e34911d 100644 --- a/src/pystencils/display_utils.py +++ b/src/pystencils/display_utils.py @@ -3,7 +3,8 @@ from typing import Any, Dict, Optional import sympy as sp from pystencils.backend import KernelFunction -from pystencils.kernel_wrapper import KernelWrapper +from pystencils.kernel_wrapper import KernelWrapper as OldKernelWrapper +from .backend.jit import KernelWrapper def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True): @@ -47,8 +48,10 @@ def get_code_obj(ast: KernelWrapper | KernelFunction, custom_backend=None): """ from pystencils.backend.emission import emit_code - if isinstance(ast, KernelWrapper): + if isinstance(ast, OldKernelWrapper): ast = ast.ast + elif isinstance(ast, KernelWrapper): + ast = ast.kernel_function class CodeDisplay: def __init__(self, ast_input): diff --git a/src/pystencils/enums.py b/src/pystencils/enums.py index 632c753f4cfa1f61f9a482c6724f7d03fe10e91d..23c255ef0949e02ac5b0af57551ceec1bf6cfee2 100644 --- a/src/pystencils/enums.py +++ b/src/pystencils/enums.py @@ -25,6 +25,8 @@ class Target(Flag): _CUDA = auto() + _SYCL = auto() + _AUTOMATIC = auto() # ------------------ Actual Targets ------------------------------------------------------------------- @@ -63,23 +65,27 @@ class Target(Flag): """ARM architecture with SVE vector extensions""" CurrentGPU = _GPU | _AUTOMATIC - """ - Auto-best GPU target. + """Auto-best GPU target. `CurrentGPU` causes the code generator to automatically select a GPU target according to GPU devices found on the current machine and runtime environment. """ - GenericCUDA = _GPU | _CUDA - """ - Generic CUDA GPU target. + CUDA = _GPU | _CUDA + """Generic CUDA GPU target. Generate a CUDA kernel for a generic Nvidia GPU. """ - GPU = GenericCUDA + GPU = CUDA """Alias for backward compatibility.""" + SYCL = _GPU | _SYCL + """SYCL kernel target. + + Generate a function to be called within a SYCL parallel command. + """ + def is_automatic(self) -> bool: return Target._AUTOMATIC in self diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 71ff965487265f47193e471afccd6dedc802eddb..54a2c3585a4948193edf01a5d7ae86e4fb523a71 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -2,18 +2,9 @@ from typing import cast from .enums import Target from .config import CreateKernelConfig +from .backend import KernelFunction from .types import create_numeric_type -from .backend import ( - KernelFunction, - KernelParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, -) -from .backend.symbols import PsSymbol -from .backend.jit import JitBase from .backend.ast.structural import PsBlock -from .backend.arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer from .backend.kernelcreation import ( KernelCreationContext, KernelAnalysis, @@ -25,12 +16,16 @@ from .backend.kernelcreation.iteration_space import ( create_full_iteration_space, ) -from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols + from .backend.transformations import ( EliminateConstants, EraseAnonymousStructTypes, SelectFunctions, ) +from .backend.kernelfunction import ( + create_cpu_kernel_function, + create_gpu_kernel_function, +) from .simp import AssignmentCollection from .assignment import Assignment @@ -93,12 +88,31 @@ def create_kernel( from .backend.platforms import GenericCpu platform = GenericCpu(ctx) - case _: - # TODO: CUDA/HIP platform - # TODO: SYCL platform (?) - raise NotImplementedError("Target platform not implemented") + kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) + + case target if target.is_gpu(): + match target: + case Target.SYCL: + from .backend.platforms import SyclPlatform - kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) + platform = SyclPlatform(ctx, config.gpu_indexing) + case Target.CUDA: + from .backend.platforms import CudaPlatform + + platform = CudaPlatform(ctx, config.gpu_indexing) + case _: + raise NotImplementedError( + f"Code generation for target {target} not implemented" + ) + + kernel_ast, gpu_threads = platform.materialize_iteration_space( + kernel_body, ispace + ) + + case _: + raise NotImplementedError( + f"Code generation for target {target} not implemented" + ) # Simplifying transformations elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) @@ -108,6 +122,8 @@ def create_kernel( if config.target.is_cpu(): from .backend.kernelcreation import optimize_cpu + assert isinstance(platform, GenericCpu) + kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) erase_anons = EraseAnonymousStructTypes(ctx) @@ -117,45 +133,21 @@ def create_kernel( kernel_ast = cast(PsBlock, select_functions(kernel_ast)) assert config.jit is not None - return create_kernel_function( - ctx, kernel_ast, config.function_name, config.target, config.jit - ) - -def create_kernel_function( - ctx: KernelCreationContext, - body: PsBlock, - function_name: str, - target_spec: Target, - jit: JitBase, -): - undef_symbols = collect_undefined_symbols(body) - - params = [] - for symb in undef_symbols: - match symb: - case PsArrayShapeSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldShapeParam(name, symb.get_dtype(), field, coord)) - case PsArrayStrideSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldStrideParam(name, symb.get_dtype(), field, coord)) - case PsArrayBasePointer(name, _, arr): - field = ctx.find_field(arr.name) - params.append(FieldPointerParam(name, symb.get_dtype(), field)) - case PsSymbol(name, _): - params.append(KernelParameter(name, symb.get_dtype())) - - params.sort(key=lambda p: p.name) - - req_headers = collect_required_headers(body) - req_headers |= ctx.required_headers - - kfunc = KernelFunction( - body, target_spec, function_name, params, req_headers, ctx.constraints, jit - ) - kfunc.metadata.update(ctx.metadata) - return kfunc + if config.target.is_cpu(): + return create_cpu_kernel_function( + ctx, platform, kernel_ast, config.function_name, config.target, config.jit + ) + else: + return create_gpu_kernel_function( + ctx, + platform, + kernel_ast, + gpu_threads, + config.function_name, + config.target, + config.jit, + ) def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs): diff --git a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py b/tests/nbackend/kernelcreation/platform/test_basic_gpu.py deleted file mode 100644 index e47f38e4d30c3f94ded9469c7a7351e9a3f298da..0000000000000000000000000000000000000000 --- a/tests/nbackend/kernelcreation/platform/test_basic_gpu.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -from pystencils.field import Field - -from pystencils.backend.kernelcreation import ( - KernelCreationContext, - FullIterationSpace -) - -from pystencils.backend.ast.structural import PsBlock, PsLoop, PsComment -from pystencils.backend.ast.expressions import PsExpression -from pystencils.backend.ast import dfs_preorder - -from pystencils.backend.platforms import GenericGpu - - -@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"]) -def test_loop_nest(layout): - ctx = KernelCreationContext() - - body = PsBlock([PsComment("Loop body goes here")]) - platform = GenericGpu(ctx) - - # FZYX Order - archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout=layout) - ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field) - - condition = platform.materialize_iteration_space(body, ispace) diff --git a/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..da2b3a5ad3a0e224bc47a5dd0fa4f16b0ccde520 --- /dev/null +++ b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py @@ -0,0 +1,43 @@ +import pytest + +from pystencils.field import Field + +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + FullIterationSpace +) + +from pystencils.backend.ast.structural import PsBlock, PsComment + +from pystencils.backend.platforms import CudaPlatform, SyclPlatform + + +@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"]) +@pytest.mark.parametrize("platform_class", [CudaPlatform, SyclPlatform]) +def test_thread_range(platform_class, layout): + ctx = KernelCreationContext() + + body = PsBlock([PsComment("Kernel body goes here")]) + platform = platform_class(ctx) + + dim = 3 + archetype_field = Field.create_generic("field", spatial_dimensions=dim, layout=layout) + ispace = FullIterationSpace.create_with_ghost_layers(ctx, 1, archetype_field) + + _, threads_range = platform.materialize_iteration_space(body, ispace) + + assert threads_range.dim == dim + + match layout: + case "fzyx" | "zyxf" | "f": + indexing_order = [0, 1, 2] + case "c": + indexing_order = [2, 1, 0] + + for i in range(dim): + # Slowest to fastest coordinate + coordinate = indexing_order[i] + dimension = ispace.dimensions[coordinate] + witems = threads_range.num_work_items[i] + desired = dimension.stop - dimension.start + assert witems.structurally_equal(desired) diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index c9cc81abbe988be4d9c92dc1d833260f38691433..5850c94d79b8eb5293c5853f65bf67c91cfd452d 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -1,13 +1,20 @@ +import pytest import sympy as sp import numpy as np -from pystencils import fields, Field, AssignmentCollection +from pystencils import fields, Field, AssignmentCollection, Target, CreateKernelConfig from pystencils.assignment import assignment_from_stencil from pystencils.kernelcreation import create_kernel -def test_filter_kernel(): +@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) +def test_filter_kernel(target): + if target == Target.CUDA: + xp = pytest.importorskip("cupy") + else: + xp = np + weight = sp.Symbol("weight") stencil = [ [1, 1, 1], @@ -19,21 +26,28 @@ def test_filter_kernel(): asm = assignment_from_stencil(stencil, src, dst, normalization_factor=weight) asms = AssignmentCollection([asm]) - ast = create_kernel(asms) + gen_config = CreateKernelConfig(target=target) + ast = create_kernel(asms, gen_config) kernel = ast.compile() - src_arr = np.ones((42, 42)) - dst_arr = np.zeros_like(src_arr) + src_arr = xp.ones((42, 31)) + dst_arr = xp.zeros_like(src_arr) kernel(src=src_arr, dst=dst_arr, weight=2.0) - expected = np.zeros_like(src_arr) + expected = xp.zeros_like(src_arr) expected[1:-1, 1:-1].fill(18.0) - np.testing.assert_allclose(dst_arr, expected) + xp.testing.assert_allclose(dst_arr, expected) + +@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) +def test_filter_kernel_fixedsize(target): + if target == Target.CUDA: + xp = pytest.importorskip("cupy") + else: + xp = np -def test_filter_kernel_fixedsize(): weight = sp.Symbol("weight") stencil = [ [1, 1, 1], @@ -41,8 +55,8 @@ def test_filter_kernel_fixedsize(): [1, 1, 1] ] - src_arr = np.ones((42, 42)) - dst_arr = np.zeros_like(src_arr) + src_arr = xp.ones((42, 31)) + dst_arr = xp.zeros_like(src_arr) src = Field.create_from_numpy_array("src", src_arr) dst = Field.create_from_numpy_array("dst", dst_arr) @@ -50,12 +64,13 @@ def test_filter_kernel_fixedsize(): asm = assignment_from_stencil(stencil, src, dst, normalization_factor=weight) asms = AssignmentCollection([asm]) - ast = create_kernel(asms) + gen_config = CreateKernelConfig(target=target) + ast = create_kernel(asms, gen_config) kernel = ast.compile() kernel(src=src_arr, dst=dst_arr, weight=2.0) - expected = np.zeros_like(src_arr) + expected = xp.zeros_like(src_arr) expected[1:-1, 1:-1].fill(18.0) - np.testing.assert_allclose(dst_arr, expected) + xp.testing.assert_allclose(dst_arr, expected) diff --git a/tests/nbackend/test_functions.py b/tests/nbackend/test_functions.py index c14e118a088a8ffb18b9444f8481f266b85f03e3..325e593402036c5add3d14d9a2d8a3e2d9c3e9ba 100644 --- a/tests/nbackend/test_functions.py +++ b/tests/nbackend/test_functions.py @@ -4,43 +4,69 @@ import pytest from pystencils import create_kernel, CreateKernelConfig, Target, Assignment, Field -UNARY_FUNCTIONS = { - "exp": (sp.exp, np.exp), - "log": (sp.log, np.log), - "sin": (sp.sin, np.sin), - "cos": (sp.cos, np.cos), - "tan": (sp.tan, np.tan), - "sinh": (sp.sinh, np.sinh), - "cosh": (sp.cosh, np.cosh), - "asin": (sp.asin, np.arcsin), - "acos": (sp.acos, np.arccos), - "atan": (sp.atan, np.arctan), - "abs": (sp.Abs, np.abs), - "floor": (sp.floor, np.floor), - "ceil": (sp.ceiling, np.ceil), -} - -BINARY_FUNCTIONS = { - "min": (sp.Min, np.fmin), - "max": (sp.Max, np.fmax), - "pow": (sp.Pow, np.power), - "atan2": (sp.atan2, np.arctan2), -} - - -@pytest.mark.parametrize("target", (Target.GenericCPU,)) -@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS.keys()) + +def unary_function(name, xp): + return { + "exp": (sp.exp, xp.exp), + "log": (sp.log, xp.log), + "sin": (sp.sin, xp.sin), + "cos": (sp.cos, xp.cos), + "tan": (sp.tan, xp.tan), + "sinh": (sp.sinh, xp.sinh), + "cosh": (sp.cosh, xp.cosh), + "asin": (sp.asin, xp.arcsin), + "acos": (sp.acos, xp.arccos), + "atan": (sp.atan, xp.arctan), + "abs": (sp.Abs, xp.abs), + "floor": (sp.floor, xp.floor), + "ceil": (sp.ceiling, xp.ceil), + }[name] + + +def binary_function(name, xp): + return { + "min": (sp.Min, xp.fmin), + "max": (sp.Max, xp.fmax), + "pow": (sp.Pow, xp.power), + "atan2": (sp.atan2, xp.arctan2), + }[name] + + +@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) +@pytest.mark.parametrize( + "function_name", + ( + "exp", + "log", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "asin", + "acos", + "atan", + "abs", + "floor", + "ceil", + ), +) @pytest.mark.parametrize("dtype", (np.float32, np.float64)) def test_unary_functions(target, function_name, dtype): - sp_func, np_func = UNARY_FUNCTIONS[function_name] + if target == Target.CUDA: + xp = pytest.importorskip("cupy") + else: + xp = np + + sp_func, xp_func = unary_function(function_name, xp) resolution: dtype = np.finfo(dtype).resolution - inp = np.array( - [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [np.pi, np.e, 0.0]], dtype=dtype + inp = xp.array( + [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [xp.pi, xp.e, 0.0]], dtype=dtype ) - outp = np.zeros_like(inp) + outp = xp.zeros_like(inp) - reference = np_func(inp) + reference = xp_func(inp) inp_field = Field.create_from_numpy_array("inp", inp) outp_field = inp_field.new_field_with_different_name("outp") @@ -52,23 +78,29 @@ def test_unary_functions(target, function_name, dtype): kfunc = kernel.compile() kfunc(inp=inp, outp=outp) - np.testing.assert_allclose(outp, reference, rtol=resolution) + xp.testing.assert_allclose(outp, reference, rtol=resolution) -@pytest.mark.parametrize("target", (Target.GenericCPU,)) -@pytest.mark.parametrize("function_name", BINARY_FUNCTIONS.keys()) +@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) +@pytest.mark.parametrize("function_name", ("min", "max", "pow", "atan2")) @pytest.mark.parametrize("dtype", (np.float32, np.float64)) def test_binary_functions(target, function_name, dtype): - sp_func, np_func = BINARY_FUNCTIONS[function_name] + if target == Target.CUDA: + xp = pytest.importorskip("cupy") + else: + xp = np + + sp_func, np_func = binary_function(function_name, xp) resolution: dtype = np.finfo(dtype).resolution - inp = np.array( - [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [np.pi, np.e, 0.0]], dtype=dtype + inp = xp.array( + [[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [xp.pi, xp.e, 0.0]], dtype=dtype ) - inp2 = np.array( - [[3.1, -0.5, 21.409], [11.0, 1.0, -14e3], [2.0 * np.pi, - np.e, 0.0]], dtype=dtype + inp2 = xp.array( + [[3.1, -0.5, 21.409], [11.0, 1.0, -14e3], [2.0 * xp.pi, -xp.e, 0.0]], + dtype=dtype, ) - outp = np.zeros_like(inp) + outp = xp.zeros_like(inp) reference = np_func(inp, inp2) @@ -76,11 +108,15 @@ def test_binary_functions(target, function_name, dtype): inp2_field = Field.create_from_numpy_array("inp2", inp) outp_field = inp_field.new_field_with_different_name("outp") - asms = [Assignment(outp_field.center(), sp_func(inp_field.center(), inp2_field.center()))] + asms = [ + Assignment( + outp_field.center(), sp_func(inp_field.center(), inp2_field.center()) + ) + ] gen_config = CreateKernelConfig(target=target, default_dtype=dtype) kernel = create_kernel(asms, gen_config) kfunc = kernel.compile() kfunc(inp=inp, inp2=inp2, outp=outp) - np.testing.assert_allclose(outp, reference, rtol=resolution) + xp.testing.assert_allclose(outp, reference, rtol=resolution) diff --git a/tests/symbolics/test_abs.py b/tests/symbolics/test_abs.py index daa4b17c17c51909c3124079677dcfa1eb87cb33..daab354a3804332c5d36358a0b850f13b8d4f0a6 100644 --- a/tests/symbolics/test_abs.py +++ b/tests/symbolics/test_abs.py @@ -8,7 +8,7 @@ import sympy def test_abs(target): if target == ps.Target.GPU: # FIXME - pytest.skip("GPU target not ready yet") + pytest.xfail("GPU target not ready yet") x, y, z = ps.fields('x, y, z: int64[2d]')