Commit e295e9f9 authored by Markus Holzer

Merge branch 'fhennig/api-compatibility' into 'backend-rework'

API Backward Compatibility and Deprecations

See merge request !402
parents 1e7b4f56 155a25c5
Pipeline #67491 passed with stages in 1 minute and 17 seconds
def _deprecated(feature, instead, version="2.1"):
    from warnings import warn

    warn(
        f"{feature} is deprecated and will be removed in pystencils {version}. "
        f"Use {instead} instead.",
        DeprecationWarning,
    )
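A minimal, self-contained sketch of how this helper is meant to be used; the `LegacyKernel` class below is purely hypothetical and only illustrates the calling pattern (the real call sites appear in the KernelFunction changes further down):

import warnings

def _deprecated(feature, instead, version="2.1"):
    # local copy of the helper above, so this sketch runs on its own
    warnings.warn(
        f"{feature} is deprecated and will be removed in pystencils {version}. "
        f"Use {instead} instead.",
        DeprecationWarning,
    )

class LegacyKernel:
    # hypothetical stand-in for KernelFunction, only to show the calling pattern
    def __init__(self, name):
        self._name = name

    @property
    def function_name(self):
        _deprecated("function_name", "name")
        return self._name

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = LegacyKernel("my_kernel").function_name
assert issubclass(caught[0].category, DeprecationWarning)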
@@ -169,15 +169,19 @@ class CAstPrinter:
     def __call__(self, obj: PsAstNode | KernelFunction) -> str:
         if isinstance(obj, KernelFunction):
-            params_str = ", ".join(
-                f"{p.dtype.c_string()} {p.name}" for p in obj.parameters
-            )
-            decl = f"FUNC_PREFIX void {obj.name} ({params_str})"
+            sig = self.print_signature(obj)
             body_code = self.visit(obj.body, PrinterCtx())
-            return f"{decl}\n{body_code}"
+            return f"{sig}\n{body_code}"
         else:
             return self.visit(obj, PrinterCtx())

+    def print_signature(self, func: KernelFunction) -> str:
+        params_str = ", ".join(
+            f"{p.dtype.c_string()} {p.name}" for p in func.parameters
+        )
+        signature = f"FUNC_PREFIX void {func.name} ({params_str})"
+        return signature
+
     def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
         match node:
             case PsBlock(statements):
...
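For illustration, a stand-alone sketch of the string the new print_signature assembles; the parameter and dtype stand-ins below are hypothetical and only mimic the `name`/`dtype.c_string()` interface used above:

from dataclasses import dataclass

@dataclass
class _DType:
    # stand-in for a PsType: only provides c_string()
    name: str
    def c_string(self) -> str:
        return self.name

@dataclass
class _Param:
    # stand-in for a KernelParameter: only provides name and dtype
    name: str
    dtype: _DType

def mock_print_signature(func_name, parameters):
    # mirrors the string assembly in CAstPrinter.print_signature above
    params_str = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in parameters)
    return f"FUNC_PREFIX void {func_name} ({params_str})"

params = [_Param("_data_f", _DType("double *")), _Param("_size_f_0", _DType("int64_t"))]
print(mock_print_signature("kernel", params))
# FUNC_PREFIX void kernel (double * _data_f, int64_t _size_f_0)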
 from __future__ import annotations
-from typing import Iterable, Iterator
+from typing import Iterable, Iterator, Any
 from itertools import chain, count
 from types import EllipsisType
 from collections import namedtuple, defaultdict
@@ -79,6 +79,8 @@ class KernelCreationContext:
         self._constraints: list[KernelParamsConstraint] = []
         self._req_headers: set[str] = set()

+        self._metadata: dict[str, Any] = dict()
+
     @property
     def default_dtype(self) -> PsNumericType:
         return self._default_dtype
@@ -95,6 +97,10 @@ class KernelCreationContext:
     @property
     def constraints(self) -> tuple[KernelParamsConstraint, ...]:
         return tuple(self._constraints)

+    @property
+    def metadata(self) -> dict[str, Any]:
+        return self._metadata
+
     # Symbols
...
@@ -404,11 +404,9 @@ def create_full_iteration_space(
     if len(domain_field_accesses) > 0:
         archetype_field = get_archetype_field(ctx.fields.domain_fields)
-        inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses])
     elif len(ctx.fields.custom_fields) > 0:
         # TODO: Warn about inferring iteration space from custom fields
         archetype_field = get_archetype_field(ctx.fields.custom_fields)
-        inferred_gls = 0
     else:
         raise PsInputError(
             "Unable to construct iteration space: The kernel contains no accesses to domain or custom fields."
@@ -419,6 +417,7 @@ def create_full_iteration_space(
     # Otherwise, use the inferred ghost layers
     if ghost_layers is not None:
+        ctx.metadata["ghost_layers"] = ghost_layers
         return FullIterationSpace.create_with_ghost_layers(
             ctx, ghost_layers, archetype_field
         )
@@ -427,6 +426,12 @@ def create_full_iteration_space(
             ctx, iteration_slice, archetype_field
         )
     else:
+        if len(domain_field_accesses) > 0:
+            inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses])
+        else:
+            inferred_gls = 0
+
+        ctx.metadata["ghost_layers"] = inferred_gls
         return FullIterationSpace.create_with_ghost_layers(
             ctx, inferred_gls, archetype_field
         )
 from __future__ import annotations

+from warnings import warn
+
 from abc import ABC
-from typing import Callable, Sequence
+from typing import Callable, Sequence, Any

+from .._deprecation import _deprecated
+
 from .ast.structural import PsBlock
@@ -11,6 +14,12 @@ from .jit import JitBase, no_jit
 from ..enums import Target
 from ..field import Field

+from ..sympyextensions import TypedSymbol
+from ..sympyextensions.typed_sympy import (
+    FieldShapeSymbol,
+    FieldStrideSymbol,
+    FieldPointerSymbol,
+)
+

 class KernelParameter:
@@ -49,6 +58,46 @@ class KernelParameter:
     def __repr__(self) -> str:
         return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"

+    @property
+    def symbol(self) -> TypedSymbol:
+        return TypedSymbol(self.name, self.dtype)
+
+    @property
+    def is_field_parameter(self) -> bool:
+        warn(
+            "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. "
+            "Use `isinstance(param, FieldParameter)` instead.",
+            DeprecationWarning,
+        )
+        return isinstance(self, FieldParameter)
+
+    @property
+    def is_field_pointer(self) -> bool:
+        warn(
+            "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. "
+            "Use `isinstance(param, FieldPointerParam)` instead.",
+            DeprecationWarning,
+        )
+        return isinstance(self, FieldPointerParam)
+
+    @property
+    def is_field_stride(self) -> bool:
+        warn(
+            "`is_field_stride` is deprecated and will be removed in a future version of pystencils. "
+            "Use `isinstance(param, FieldStrideParam)` instead.",
+            DeprecationWarning,
+        )
+        return isinstance(self, FieldStrideParam)
+
+    @property
+    def is_field_shape(self) -> bool:
+        warn(
+            "`is_field_shape` is deprecated and will be removed in a future version of pystencils. "
+            "Use `isinstance(param, FieldShapeParam)` instead.",
+            DeprecationWarning,
+        )
+        return isinstance(self, FieldShapeParam)
+
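The deprecated boolean properties are now thin wrappers around isinstance checks. A self-contained sketch of the old versus new classification style; the tiny classes below only mirror the KernelParameter hierarchy from this file and are not the real implementations:

# stand-ins mirroring the class hierarchy above
class KernelParameter: ...
class FieldParameter(KernelParameter): ...
class FieldPointerParam(FieldParameter): ...
class ScalarParam(KernelParameter): ...   # hypothetical non-field parameter

params = [FieldPointerParam(), ScalarParam()]

# pystencils >= 2.0 style: classify by type
field_params = [p for p in params if isinstance(p, FieldParameter)]
pointer_params = [p for p in params if isinstance(p, FieldPointerParam)]

# legacy style (still works on the real classes, but emits DeprecationWarning):
#     field_params = [p for p in params if p.is_field_parameter]
assert len(field_params) == len(pointer_params) == 1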
 class FieldParameter(KernelParameter, ABC):
     __match_args__ = KernelParameter.__match_args__ + ("field",)
@@ -61,6 +110,25 @@ class FieldParameter(KernelParameter, ABC):
     def field(self):
         return self._field

+    @property
+    def fields(self):
+        warn(
+            "`fields` is deprecated and will be removed in a future version of pystencils. "
+            "In pystencils >= 2.0, field parameters are only associated with a single field. "
+            "Use the `field` property instead.",
+            DeprecationWarning,
+        )
+        return [self._field]
+
+    @property
+    def field_name(self) -> str:
+        warn(
+            "`field_name` is deprecated and will be removed in a future version of pystencils. "
+            "Use `field.name` instead.",
+            DeprecationWarning,
+        )
+        return self._field.name
+
     def _hashable_contents(self):
         return super()._hashable_contents() + (self._field,)
@@ -76,6 +144,10 @@ class FieldShapeParam(FieldParameter):
     def coordinate(self):
         return self._coordinate

+    @property
+    def symbol(self) -> FieldShapeSymbol:
+        return FieldShapeSymbol(self.field.name, self.coordinate, self.dtype)
+
     def _hashable_contents(self):
         return super()._hashable_contents() + (self._coordinate,)
@@ -91,6 +163,10 @@ class FieldStrideParam(FieldParameter):
     def coordinate(self):
         return self._coordinate

+    @property
+    def symbol(self) -> FieldStrideSymbol:
+        return FieldStrideSymbol(self.field.name, self.coordinate, self.dtype)
+
     def _hashable_contents(self):
         return super()._hashable_contents() + (self._coordinate,)
@@ -99,6 +175,10 @@ class FieldPointerParam(FieldParameter):
     def __init__(self, name: str, dtype: PsType, field: Field):
         super().__init__(name, dtype, field)

+    @property
+    def symbol(self) -> FieldPointerSymbol:
+        return FieldPointerSymbol(self.field.name, self.field.dtype, const=True)
+
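With these properties, every backend parameter can be mapped back to the SymPy-level symbol type that pystencils 1.x front-end code expects. A hedged sketch, assuming `kernel` is a KernelFunction exposing the `parameters` property shown below:

def legacy_symbols(kernel):
    # Each parameter dispatches to TypedSymbol, FieldShapeSymbol,
    # FieldStrideSymbol or FieldPointerSymbol via its `symbol` property.
    return [param.symbol for param in kernel.parameters]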
 class KernelFunction:
     """A pystencils kernel function.
@@ -125,6 +205,11 @@ class KernelFunction:
         self._required_headers = required_headers
         self._constraints = tuple(constraints)
         self._jit = jit
+        self._metadata: dict[str, Any] = dict()
+
+    @property
+    def metadata(self) -> dict[str, Any]:
+        return self._metadata

     @property
     def body(self) -> PsBlock:
@@ -144,13 +229,34 @@ class KernelFunction:
     @property
     def function_name(self) -> str:
-        """For backward compatibility"""
+        _deprecated("function_name", "name")
         return self._name

+    @function_name.setter
+    def function_name(self, n: str):
+        _deprecated("function_name", "name")
+        self._name = n
+
     @property
     def parameters(self) -> tuple[KernelParameter, ...]:
         return self._params

+    def get_parameters(self) -> tuple[KernelParameter, ...]:
+        _deprecated("KernelFunction.get_parameters", "KernelFunction.parameters")
+        return self.parameters
+
+    def get_fields(self) -> set[Field]:
+        return set(p.field for p in self._params if isinstance(p, FieldParameter))
+
+    @property
+    def fields_accessed(self) -> set[Field]:
+        warn(
+            "`fields_accessed` is deprecated and will be removed in a future version of pystencils. "
+            "Use `get_fields` instead.",
+            DeprecationWarning,
+        )
+        return self.get_fields()
+
     @property
     def required_headers(self) -> set[str]:
         return self._required_headers
...
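A hedged migration summary for code written against pystencils 1.x, assuming `kernel` is a KernelFunction from this backend and that it exposes the plain `name` property the deprecation message points to:

# old (deprecated, emits DeprecationWarning)    replacement
# kernel.function_name                          kernel.name
# kernel.get_parameters()                       kernel.parameters
# kernel.fields_accessed                        kernel.get_fields()
# param.is_field_parameter                      isinstance(param, FieldParameter)

def kernel_summary(kernel):
    # new spellings only; the deprecated ones are listed in the comment above
    return {
        "name": kernel.name,
        "parameters": kernel.parameters,
        "fields": {f.name for f in kernel.get_fields()},
    }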
@@ -107,6 +107,9 @@ class AddOpenMP:
         pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
         pragma_text += f" for schedule({omp_params.schedule})"

+        if omp_params.num_threads is not None:
+            pragma_text += f" num_threads({str(omp_params.num_threads)})"
+
         if omp_params.collapse > 0:
             pragma_text += f" collapse({str(omp_params.collapse)})"
...
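A stand-alone sketch reproducing the clause assembly above, to show the expected pragma for a given configuration. The "#pragma omp" prefix and the `_OmpParams` stand-in are assumptions; only the clause handling is taken from the diff:

from dataclasses import dataclass

@dataclass
class _OmpParams:  # stand-in for OpenMpConfig
    schedule: str = "static"
    num_threads: int | None = None
    collapse: int = 0
    omit_parallel_construct: bool = False

def omp_pragma(omp_params: _OmpParams) -> str:
    pragma_text = "#pragma omp"   # assumed prefix, added before the hunk shown above
    pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
    pragma_text += f" for schedule({omp_params.schedule})"
    if omp_params.num_threads is not None:
        pragma_text += f" num_threads({str(omp_params.num_threads)})"
    if omp_params.collapse > 0:
        pragma_text += f" collapse({str(omp_params.collapse)})"
    return pragma_text

print(omp_pragma(_OmpParams(num_threads=8, collapse=2)))
# #pragma omp parallel for schedule(static) num_threads(8) collapse(2)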
 from __future__ import annotations

+from warnings import warn
+
 from collections.abc import Collection
 from typing import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, InitVar

 from .enums import Target
 from .field import Field, FieldType
@@ -28,12 +29,21 @@ class OpenMpConfig:
     schedule: str = "static"
     """Argument to the OpenMP ``schedule`` clause"""

+    num_threads: int | None = None
+    """Set the number of OpenMP threads to execute the parallel region."""
+
     omit_parallel_construct: bool = False
     """If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``.

     Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region.
     """

+    def __post_init__(self):
+        if self.omit_parallel_construct and self.num_threads is not None:
+            raise PsOptionsError(
+                "Cannot specify `num_threads` if `omit_parallel_construct` is set."
+            )
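A hedged usage sketch of the new option; the import path of OpenMpConfig and PsOptionsError is an assumption, since neither import is shown in this diff:

from pystencils.config import OpenMpConfig, PsOptionsError  # assumed location

omp = OpenMpConfig(num_threads=4)   # thread count for the parallel region

try:
    OpenMpConfig(num_threads=4, omit_parallel_construct=True)
except PsOptionsError:
    # num_threads belongs to the omitted `parallel` construct, so this is rejected
    pass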
 @dataclass
 class CpuOptimConfig:
@@ -181,7 +191,24 @@ class CreateKernelConfig:
     If this parameter is set while `target` is a non-CPU target, an error will be raised.
     """

-    def __post_init__(self):
+    # Deprecated Options
+
+    data_type: InitVar[UserTypeSpec | None] = None
+    """Deprecated; use `default_dtype` instead"""
+
+    cpu_openmp: InitVar[bool | int | None] = None
+    """Deprecated; use `cpu_optim.openmp` instead."""
+
+    cpu_vectorize_info: InitVar[dict | None] = None
+    """Deprecated; use `cpu_optim.vectorize` instead."""
+
+    # Postprocessing
+
+    def __post_init__(self, *args):
+        # Check deprecated options
+        self._check_deprecations(*args)
+
         # Check iteration space argument consistency
         if (
             int(self.iteration_slice is not None)
@@ -228,3 +255,59 @@ class CreateKernelConfig:
             raise NotImplementedError(
                 f"No default JIT compiler implemented yet for target {self.target}"
             )
+
+    def _check_deprecations(
+        self,
+        data_type: UserTypeSpec | None,
+        cpu_openmp: bool | int | None,
+        cpu_vectorize_info: dict | None,
+    ):
+        optim: CpuOptimConfig | None = None
+
+        if data_type is not None:
+            _deprecated_option("data_type", "default_dtype")
+            warn(
+                "Setting the deprecated `data_type` will override the value of `default_dtype`. "
+                "Set `default_dtype` instead.",
+                FutureWarning,
+            )
+            self.default_dtype = data_type
+
+        if cpu_openmp is not None:
+            _deprecated_option("cpu_openmp", "cpu_optim.openmp")
+
+            match cpu_openmp:
+                case True:
+                    deprecated_omp = OpenMpConfig()
+                case False:
+                    deprecated_omp = False
+                case int():
+                    deprecated_omp = OpenMpConfig(num_threads=cpu_openmp)
+                case _:
+                    raise PsOptionsError(
+                        f"Invalid option for `cpu_openmp`: {cpu_openmp}"
+                    )
+
+            optim = CpuOptimConfig(openmp=deprecated_omp)
+
+        if cpu_vectorize_info is not None:
+            _deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize")
+            raise NotImplementedError("CPU vectorization is not implemented yet")
+
+        if optim is not None:
+            if self.cpu_optim is not None:
+                raise PsOptionsError(
+                    "Cannot specify both `cpu_optim` and a deprecated legacy optimization option at the same time."
+                )
+            else:
+                self.cpu_optim = optim
+
+
+def _deprecated_option(name, instead):
+    from warnings import warn
+
+    warn(
+        f"The `{name}` option of CreateKernelConfig is deprecated and will be removed in pystencils 2.1. "
+        f"Use `{instead}` instead.",
+        FutureWarning,
+    )
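The translation performed by _check_deprecations, summarized as a hedged old-to-new mapping (class names as defined in this file; combining a legacy option with an explicit `cpu_optim` raises PsOptionsError):

legacy_to_new = {
    "cpu_openmp=True":     "cpu_optim=CpuOptimConfig(openmp=OpenMpConfig())",
    "cpu_openmp=8":        "cpu_optim=CpuOptimConfig(openmp=OpenMpConfig(num_threads=8))",
    "cpu_openmp=False":    "cpu_optim=CpuOptimConfig(openmp=False)",
    "data_type='float32'": "default_dtype='float32'",
}
# Every legacy spelling additionally emits a FutureWarning via _deprecated_option;
# cpu_vectorize_info currently raises NotImplementedError.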
@@ -55,7 +55,8 @@ def create_kernel(
     """
     ctx = KernelCreationContext(
-        default_dtype=create_numeric_type(config.default_dtype), index_dtype=config.index_dtype
+        default_dtype=create_numeric_type(config.default_dtype),
+        index_dtype=config.index_dtype,
     )

     if isinstance(assignments, Assignment):
@@ -150,9 +151,11 @@ def create_kernel_function(
     req_headers = collect_required_headers(body)
     req_headers |= ctx.required_headers

-    return KernelFunction(
+    kfunc = KernelFunction(
         body, target_spec, function_name, params, req_headers, ctx.constraints, jit
     )
+    kfunc.metadata.update(ctx.metadata)
+    return kfunc


 def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs):
...
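Since the context metadata is now copied onto the kernel object, information such as the ghost layers recorded in create_full_iteration_space can be read back from the finished kernel. A hedged sketch, assuming `kernel` is the KernelFunction returned by create_kernel:

def ghost_layers_of(kernel):
    # "ghost_layers" is written by create_full_iteration_space (see above);
    # returns None if the kernel was built from an explicit iteration slice.
    return kernel.metadata.get("ghost_layers")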
@@ -36,6 +36,7 @@ of types, as well as for const-conversion.
 from __future__ import annotations

+from warnings import warn
 from abc import ABCMeta, abstractmethod
 from typing import TypeVar, Any, cast
 import numpy as np
@@ -159,6 +160,16 @@ class PsType(metaclass=PsTypeMeta):
     def c_string(self) -> str:
         pass

+    @property
+    def c_name(self) -> str:
+        """Returns the C name of this type without const-qualifiers."""
+        warn(
+            "`c_name` is deprecated and will be removed in a future version of pystencils. "
+            "Use `c_string()` instead.",
+            DeprecationWarning,
+        )
+        return deconstify(self).c_string()
+
     def __str__(self) -> str:
         return self.c_string()
...
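A hedged note on migrating off `c_name`; here `t` stands for any concrete PsType instance (how such an instance is constructed is not shown in this diff), and the executable part uses only the standard library:

# t.c_name                    # deprecated, emits DeprecationWarning
# deconstify(t).c_string()    # replacement with identical result

# Escalate remaining uses to errors during testing:
import warnings
warnings.filterwarnings("error", category=DeprecationWarning, message=".*c_name.*")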