diff --git a/src/pystencils/_deprecation.py b/src/pystencils/_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..29ee648a7a9655e556986ae5404e335b94924fc9 --- /dev/null +++ b/src/pystencils/_deprecation.py @@ -0,0 +1,8 @@ +def _deprecated(feature, instead, version="2.1"): + from warnings import warn + + warn( + f"{feature} is deprecated and will be removed in pystencils {version}." + f"Use {instead} instead.", + DeprecationWarning, + ) diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 3b957f3b321e212158eb2eb5d6229037f286b307..0f220149d2913b218f3a41bb8ea5b35018aa251d 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -169,15 +169,19 @@ class CAstPrinter: def __call__(self, obj: PsAstNode | KernelFunction) -> str: if isinstance(obj, KernelFunction): - params_str = ", ".join( - f"{p.dtype.c_string()} {p.name}" for p in obj.parameters - ) - decl = f"FUNC_PREFIX void {obj.name} ({params_str})" + sig = self.print_signature(obj) body_code = self.visit(obj.body, PrinterCtx()) - return f"{decl}\n{body_code}" + return f"{sig}\n{body_code}" else: return self.visit(obj, PrinterCtx()) + def print_signature(self, func: KernelFunction) -> str: + params_str = ", ".join( + f"{p.dtype.c_string()} {p.name}" for p in func.parameters + ) + signature = f"FUNC_PREFIX void {func.name} ({params_str})" + return signature + def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: match node: case PsBlock(statements): diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 263c2f48ecfff15dd4c6271f77ca2a7578b86d09..9df186470086f8115fe0c832917a8676d04aa7bf 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterable, Iterator +from typing import Iterable, Iterator, Any from itertools import chain, count from types import EllipsisType from collections import namedtuple, defaultdict @@ -79,6 +79,8 @@ class KernelCreationContext: self._constraints: list[KernelParamsConstraint] = [] self._req_headers: set[str] = set() + self._metadata: dict[str, Any] = dict() + @property def default_dtype(self) -> PsNumericType: return self._default_dtype @@ -95,6 +97,10 @@ class KernelCreationContext: @property def constraints(self) -> tuple[KernelParamsConstraint, ...]: return tuple(self._constraints) + + @property + def metadata(self) -> dict[str, Any]: + return self._metadata # Symbols diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 5208c906caa86a3b687057d777908daaf5f6e2ab..56e58648966fab4e60b4eea64ab5442b92f91709 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -404,11 +404,9 @@ def create_full_iteration_space( if len(domain_field_accesses) > 0: archetype_field = get_archetype_field(ctx.fields.domain_fields) - inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses]) elif len(ctx.fields.custom_fields) > 0: # TODO: Warn about inferring iteration space from custom fields archetype_field = get_archetype_field(ctx.fields.custom_fields) - inferred_gls = 0 else: raise PsInputError( "Unable to construct iteration space: The kernel contains no accesses to domain or custom fields." @@ -419,6 +417,7 @@ def create_full_iteration_space( # Otherwise, use the inferred ghost layers if ghost_layers is not None: + ctx.metadata["ghost_layers"] = ghost_layers return FullIterationSpace.create_with_ghost_layers( ctx, ghost_layers, archetype_field ) @@ -427,6 +426,12 @@ def create_full_iteration_space( ctx, iteration_slice, archetype_field ) else: + if len(domain_field_accesses) > 0: + inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses]) + else: + inferred_gls = 0 + + ctx.metadata["ghost_layers"] = inferred_gls return FullIterationSpace.create_with_ghost_layers( ctx, inferred_gls, archetype_field ) diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index 97837492b65cf24d2954bad72d1711d454c8925b..985f0bfa30cd1a7f3e31d4d7f99964d59a9f4e9e 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -1,7 +1,10 @@ from __future__ import annotations +from warnings import warn from abc import ABC -from typing import Callable, Sequence +from typing import Callable, Sequence, Any + +from .._deprecation import _deprecated from .ast.structural import PsBlock @@ -11,6 +14,12 @@ from .jit import JitBase, no_jit from ..enums import Target from ..field import Field +from ..sympyextensions import TypedSymbol +from ..sympyextensions.typed_sympy import ( + FieldShapeSymbol, + FieldStrideSymbol, + FieldPointerSymbol, +) class KernelParameter: @@ -49,6 +58,46 @@ class KernelParameter: def __repr__(self) -> str: return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})" + @property + def symbol(self) -> TypedSymbol: + return TypedSymbol(self.name, self.dtype) + + @property + def is_field_parameter(self) -> bool: + warn( + "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. " + "Use `isinstance(param, FieldParameter)` instead.", + DeprecationWarning, + ) + return isinstance(self, FieldParameter) + + @property + def is_field_pointer(self) -> bool: + warn( + "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. " + "Use `isinstance(param, FieldPointerParam)` instead.", + DeprecationWarning, + ) + return isinstance(self, FieldPointerParam) + + @property + def is_field_stride(self) -> bool: + warn( + "`is_field_stride` is deprecated and will be removed in a future version of pystencils. " + "Use `isinstance(param, FieldStrideParam)` instead.", + DeprecationWarning, + ) + return isinstance(self, FieldStrideParam) + + @property + def is_field_shape(self) -> bool: + warn( + "`is_field_shape` is deprecated and will be removed in a future version of pystencils. " + "Use `isinstance(param, FieldShapeParam)` instead.", + DeprecationWarning, + ) + return isinstance(self, FieldShapeParam) + class FieldParameter(KernelParameter, ABC): __match_args__ = KernelParameter.__match_args__ + ("field",) @@ -61,6 +110,25 @@ class FieldParameter(KernelParameter, ABC): def field(self): return self._field + @property + def fields(self): + warn( + "`fields` is deprecated and will be removed in a future version of pystencils. " + "In pystencils >= 2.0, field parameters are only associated with a single field." + "Use the `field` property instead.", + DeprecationWarning, + ) + return [self._field] + + @property + def field_name(self) -> str: + warn( + "`field_name` is deprecated and will be removed in a future version of pystencils. " + "Use `field.name` instead.", + DeprecationWarning, + ) + return self._field.name + def _hashable_contents(self): return super()._hashable_contents() + (self._field,) @@ -76,6 +144,10 @@ class FieldShapeParam(FieldParameter): def coordinate(self): return self._coordinate + @property + def symbol(self) -> FieldShapeSymbol: + return FieldShapeSymbol(self.field.name, self.coordinate, self.dtype) + def _hashable_contents(self): return super()._hashable_contents() + (self._coordinate,) @@ -91,6 +163,10 @@ class FieldStrideParam(FieldParameter): def coordinate(self): return self._coordinate + @property + def symbol(self) -> FieldStrideSymbol: + return FieldStrideSymbol(self.field.name, self.coordinate, self.dtype) + def _hashable_contents(self): return super()._hashable_contents() + (self._coordinate,) @@ -99,6 +175,10 @@ class FieldPointerParam(FieldParameter): def __init__(self, name: str, dtype: PsType, field: Field): super().__init__(name, dtype, field) + @property + def symbol(self) -> FieldPointerSymbol: + return FieldPointerSymbol(self.field.name, self.field.dtype, const=True) + class KernelFunction: """A pystencils kernel function. @@ -125,6 +205,11 @@ class KernelFunction: self._required_headers = required_headers self._constraints = tuple(constraints) self._jit = jit + self._metadata: dict[str, Any] = dict() + + @property + def metadata(self) -> dict[str, Any]: + return self._metadata @property def body(self) -> PsBlock: @@ -144,13 +229,34 @@ class KernelFunction: @property def function_name(self) -> str: - """For backward compatibility""" + _deprecated("function_name", "name") return self._name + @function_name.setter + def function_name(self, n: str): + _deprecated("function_name", "name") + self._name = n + @property def parameters(self) -> tuple[KernelParameter, ...]: return self._params + def get_parameters(self) -> tuple[KernelParameter, ...]: + _deprecated("KernelFunction.get_parameters", "KernelFunction.parameters") + return self.parameters + + def get_fields(self) -> set[Field]: + return set(p.field for p in self._params if isinstance(p, FieldParameter)) + + @property + def fields_accessed(self) -> set[Field]: + warn( + "`fields_accessed` is deprecated and will be removed in a future version of pystencils. " + "Use `get_fields` instead.", + DeprecationWarning, + ) + return self.get_fields() + @property def required_headers(self) -> set[str]: return self._required_headers diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index c7015ccb62042e6beaf131e68485f7c8186b2a2a..bd7f592fea3ea07af33454ac5f07b3e8065f3a8a 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -107,6 +107,9 @@ class AddOpenMP: pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" pragma_text += f" for schedule({omp_params.schedule})" + if omp_params.num_threads is not None: + pragma_text += f" num_threads({str(omp_params.num_threads)})" + if omp_params.collapse > 0: pragma_text += f" collapse({str(omp_params.collapse)})" diff --git a/src/pystencils/config.py b/src/pystencils/config.py index fe0e87900800c019d042b0994211bfb9cdd99b15..69eb418e310c85c3e9e566c8a31ce201fc2b9814 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -1,9 +1,10 @@ from __future__ import annotations +from warnings import warn from collections.abc import Collection from typing import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, InitVar from .enums import Target from .field import Field, FieldType @@ -28,12 +29,21 @@ class OpenMpConfig: schedule: str = "static" """Argument to the OpenMP ``schedule`` clause""" + num_threads: int | None = None + """Set the number of OpenMP threads to execute the parallel region.""" + omit_parallel_construct: bool = False """If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``. Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region. """ + def __post_init__(self): + if self.omit_parallel_construct and self.num_threads is not None: + raise PsOptionsError( + "Cannot specify `num_threads` if `omit_parallel_construct` is set." + ) + @dataclass class CpuOptimConfig: @@ -181,7 +191,24 @@ class CreateKernelConfig: If this parameter is set while `target` is a non-CPU target, an error will be raised. """ - def __post_init__(self): + # Deprecated Options + + data_type: InitVar[UserTypeSpec | None] = None + """Deprecated; use `default_dtype` instead""" + + cpu_openmp: InitVar[bool | int | None] = None + """Deprecated; use `cpu_optim.openmp` instead.""" + + cpu_vectorize_info: InitVar[dict | None] = None + """Deprecated; use `cpu_optim.vectorize` instead.""" + + # Postprocessing + + def __post_init__(self, *args): + + # Check deprecated options + self._check_deprecations(*args) + # Check iteration space argument consistency if ( int(self.iteration_slice is not None) @@ -228,3 +255,59 @@ class CreateKernelConfig: raise NotImplementedError( f"No default JIT compiler implemented yet for target {self.target}" ) + + def _check_deprecations( + self, + data_type: UserTypeSpec | None, + cpu_openmp: bool | int | None, + cpu_vectorize_info: dict | None, + ): + optim: CpuOptimConfig | None = None + + if data_type is not None: + _deprecated_option("data_type", "default_dtype") + warn( + "Setting the deprecated `data_type` will override the value of `default_dtype`. " + "Set `default_dtype` instead.", + FutureWarning, + ) + self.default_dtype = data_type + + if cpu_openmp is not None: + _deprecated_option("cpu_openmp", "cpu_optim.openmp") + + match cpu_openmp: + case True: + deprecated_omp = OpenMpConfig() + case False: + deprecated_omp = False + case int(): + deprecated_omp = OpenMpConfig(num_threads=cpu_openmp) + case _: + raise PsOptionsError( + f"Invalid option for `cpu_openmp`: {cpu_openmp}" + ) + + optim = CpuOptimConfig(openmp=deprecated_omp) + + if cpu_vectorize_info is not None: + _deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize") + raise NotImplementedError("CPU vectorization is not implemented yet") + + if optim is not None: + if self.cpu_optim is not None: + raise PsOptionsError( + "Cannot specify both `cpu_optim` and a deprecated legacy optimization option at the same time." + ) + else: + self.cpu_optim = optim + + +def _deprecated_option(name, instead): + from warnings import warn + + warn( + f"The `{name}` option of CreateKernelConfig is deprecated and will be removed in pystencils 2.1. " + f"Use `{instead}` instead.", + FutureWarning, + ) diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index e31ef2fd42ec80b8c3c123cb20b11fa2a61c47df..71ff965487265f47193e471afccd6dedc802eddb 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -55,7 +55,8 @@ def create_kernel( """ ctx = KernelCreationContext( - default_dtype=create_numeric_type(config.default_dtype), index_dtype=config.index_dtype + default_dtype=create_numeric_type(config.default_dtype), + index_dtype=config.index_dtype, ) if isinstance(assignments, Assignment): @@ -150,9 +151,11 @@ def create_kernel_function( req_headers = collect_required_headers(body) req_headers |= ctx.required_headers - return KernelFunction( + kfunc = KernelFunction( body, target_spec, function_name, params, req_headers, ctx.constraints, jit ) + kfunc.metadata.update(ctx.metadata) + return kfunc def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs): diff --git a/src/pystencils/types/meta.py b/src/pystencils/types/meta.py index 1d605edf89d7f30169dbc2601c85fa4787bb3128..9389279bcb8167c0c7832dd5a3c1289dcbf51ca0 100644 --- a/src/pystencils/types/meta.py +++ b/src/pystencils/types/meta.py @@ -36,6 +36,7 @@ of types, as well as for const-conversion. from __future__ import annotations +from warnings import warn from abc import ABCMeta, abstractmethod from typing import TypeVar, Any, cast import numpy as np @@ -159,6 +160,16 @@ class PsType(metaclass=PsTypeMeta): def c_string(self) -> str: pass + @property + def c_name(self) -> str: + """Returns the C name of this type without const-qualifiers.""" + warn( + "`c_name` is deprecated and will be removed in a future version of pystencils. " + "Use `c_string()` instead.", + DeprecationWarning, + ) + return deconstify(self).c_string() + def __str__(self) -> str: return self.c_string()