Skip to content
Snippets Groups Projects
Commit e295e9f9 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'fhennig/api-compatibility' into 'backend-rework'

API Backward Compatibility and Deprecations

See merge request pycodegen/pystencils!402
parents 1e7b4f56 155a25c5
Branches
Tags
1 merge request!402API Backward Compatibility and Deprecations
Pipeline #67491 passed with stages
in 1 minute and 17 seconds
def _deprecated(feature, instead, version="2.1"):
from warnings import warn
warn(
f"{feature} is deprecated and will be removed in pystencils {version}."
f"Use {instead} instead.",
DeprecationWarning,
)
......@@ -169,15 +169,19 @@ class CAstPrinter:
def __call__(self, obj: PsAstNode | KernelFunction) -> str:
if isinstance(obj, KernelFunction):
params_str = ", ".join(
f"{p.dtype.c_string()} {p.name}" for p in obj.parameters
)
decl = f"FUNC_PREFIX void {obj.name} ({params_str})"
sig = self.print_signature(obj)
body_code = self.visit(obj.body, PrinterCtx())
return f"{decl}\n{body_code}"
return f"{sig}\n{body_code}"
else:
return self.visit(obj, PrinterCtx())
def print_signature(self, func: KernelFunction) -> str:
params_str = ", ".join(
f"{p.dtype.c_string()} {p.name}" for p in func.parameters
)
signature = f"FUNC_PREFIX void {func.name} ({params_str})"
return signature
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node:
case PsBlock(statements):
......
from __future__ import annotations
from typing import Iterable, Iterator
from typing import Iterable, Iterator, Any
from itertools import chain, count
from types import EllipsisType
from collections import namedtuple, defaultdict
......@@ -79,6 +79,8 @@ class KernelCreationContext:
self._constraints: list[KernelParamsConstraint] = []
self._req_headers: set[str] = set()
self._metadata: dict[str, Any] = dict()
@property
def default_dtype(self) -> PsNumericType:
return self._default_dtype
......@@ -95,6 +97,10 @@ class KernelCreationContext:
@property
def constraints(self) -> tuple[KernelParamsConstraint, ...]:
return tuple(self._constraints)
@property
def metadata(self) -> dict[str, Any]:
return self._metadata
# Symbols
......
......@@ -404,11 +404,9 @@ def create_full_iteration_space(
if len(domain_field_accesses) > 0:
archetype_field = get_archetype_field(ctx.fields.domain_fields)
inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses])
elif len(ctx.fields.custom_fields) > 0:
# TODO: Warn about inferring iteration space from custom fields
archetype_field = get_archetype_field(ctx.fields.custom_fields)
inferred_gls = 0
else:
raise PsInputError(
"Unable to construct iteration space: The kernel contains no accesses to domain or custom fields."
......@@ -419,6 +417,7 @@ def create_full_iteration_space(
# Otherwise, use the inferred ghost layers
if ghost_layers is not None:
ctx.metadata["ghost_layers"] = ghost_layers
return FullIterationSpace.create_with_ghost_layers(
ctx, ghost_layers, archetype_field
)
......@@ -427,6 +426,12 @@ def create_full_iteration_space(
ctx, iteration_slice, archetype_field
)
else:
if len(domain_field_accesses) > 0:
inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses])
else:
inferred_gls = 0
ctx.metadata["ghost_layers"] = inferred_gls
return FullIterationSpace.create_with_ghost_layers(
ctx, inferred_gls, archetype_field
)
from __future__ import annotations
from warnings import warn
from abc import ABC
from typing import Callable, Sequence
from typing import Callable, Sequence, Any
from .._deprecation import _deprecated
from .ast.structural import PsBlock
......@@ -11,6 +14,12 @@ from .jit import JitBase, no_jit
from ..enums import Target
from ..field import Field
from ..sympyextensions import TypedSymbol
from ..sympyextensions.typed_sympy import (
FieldShapeSymbol,
FieldStrideSymbol,
FieldPointerSymbol,
)
class KernelParameter:
......@@ -49,6 +58,46 @@ class KernelParameter:
def __repr__(self) -> str:
return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"
@property
def symbol(self) -> TypedSymbol:
return TypedSymbol(self.name, self.dtype)
@property
def is_field_parameter(self) -> bool:
warn(
"`is_field_parameter` is deprecated and will be removed in a future version of pystencils. "
"Use `isinstance(param, FieldParameter)` instead.",
DeprecationWarning,
)
return isinstance(self, FieldParameter)
@property
def is_field_pointer(self) -> bool:
warn(
"`is_field_pointer` is deprecated and will be removed in a future version of pystencils. "
"Use `isinstance(param, FieldPointerParam)` instead.",
DeprecationWarning,
)
return isinstance(self, FieldPointerParam)
@property
def is_field_stride(self) -> bool:
warn(
"`is_field_stride` is deprecated and will be removed in a future version of pystencils. "
"Use `isinstance(param, FieldStrideParam)` instead.",
DeprecationWarning,
)
return isinstance(self, FieldStrideParam)
@property
def is_field_shape(self) -> bool:
warn(
"`is_field_shape` is deprecated and will be removed in a future version of pystencils. "
"Use `isinstance(param, FieldShapeParam)` instead.",
DeprecationWarning,
)
return isinstance(self, FieldShapeParam)
class FieldParameter(KernelParameter, ABC):
__match_args__ = KernelParameter.__match_args__ + ("field",)
......@@ -61,6 +110,25 @@ class FieldParameter(KernelParameter, ABC):
def field(self):
return self._field
@property
def fields(self):
warn(
"`fields` is deprecated and will be removed in a future version of pystencils. "
"In pystencils >= 2.0, field parameters are only associated with a single field."
"Use the `field` property instead.",
DeprecationWarning,
)
return [self._field]
@property
def field_name(self) -> str:
warn(
"`field_name` is deprecated and will be removed in a future version of pystencils. "
"Use `field.name` instead.",
DeprecationWarning,
)
return self._field.name
def _hashable_contents(self):
return super()._hashable_contents() + (self._field,)
......@@ -76,6 +144,10 @@ class FieldShapeParam(FieldParameter):
def coordinate(self):
return self._coordinate
@property
def symbol(self) -> FieldShapeSymbol:
return FieldShapeSymbol(self.field.name, self.coordinate, self.dtype)
def _hashable_contents(self):
return super()._hashable_contents() + (self._coordinate,)
......@@ -91,6 +163,10 @@ class FieldStrideParam(FieldParameter):
def coordinate(self):
return self._coordinate
@property
def symbol(self) -> FieldStrideSymbol:
return FieldStrideSymbol(self.field.name, self.coordinate, self.dtype)
def _hashable_contents(self):
return super()._hashable_contents() + (self._coordinate,)
......@@ -99,6 +175,10 @@ class FieldPointerParam(FieldParameter):
def __init__(self, name: str, dtype: PsType, field: Field):
super().__init__(name, dtype, field)
@property
def symbol(self) -> FieldPointerSymbol:
return FieldPointerSymbol(self.field.name, self.field.dtype, const=True)
class KernelFunction:
"""A pystencils kernel function.
......@@ -125,6 +205,11 @@ class KernelFunction:
self._required_headers = required_headers
self._constraints = tuple(constraints)
self._jit = jit
self._metadata: dict[str, Any] = dict()
@property
def metadata(self) -> dict[str, Any]:
return self._metadata
@property
def body(self) -> PsBlock:
......@@ -144,13 +229,34 @@ class KernelFunction:
@property
def function_name(self) -> str:
"""For backward compatibility"""
_deprecated("function_name", "name")
return self._name
@function_name.setter
def function_name(self, n: str):
_deprecated("function_name", "name")
self._name = n
@property
def parameters(self) -> tuple[KernelParameter, ...]:
return self._params
def get_parameters(self) -> tuple[KernelParameter, ...]:
_deprecated("KernelFunction.get_parameters", "KernelFunction.parameters")
return self.parameters
def get_fields(self) -> set[Field]:
return set(p.field for p in self._params if isinstance(p, FieldParameter))
@property
def fields_accessed(self) -> set[Field]:
warn(
"`fields_accessed` is deprecated and will be removed in a future version of pystencils. "
"Use `get_fields` instead.",
DeprecationWarning,
)
return self.get_fields()
@property
def required_headers(self) -> set[str]:
return self._required_headers
......
......@@ -107,6 +107,9 @@ class AddOpenMP:
pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
pragma_text += f" for schedule({omp_params.schedule})"
if omp_params.num_threads is not None:
pragma_text += f" num_threads({str(omp_params.num_threads)})"
if omp_params.collapse > 0:
pragma_text += f" collapse({str(omp_params.collapse)})"
......
from __future__ import annotations
from warnings import warn
from collections.abc import Collection
from typing import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, InitVar
from .enums import Target
from .field import Field, FieldType
......@@ -28,12 +29,21 @@ class OpenMpConfig:
schedule: str = "static"
"""Argument to the OpenMP ``schedule`` clause"""
num_threads: int | None = None
"""Set the number of OpenMP threads to execute the parallel region."""
omit_parallel_construct: bool = False
"""If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``.
Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region.
"""
def __post_init__(self):
if self.omit_parallel_construct and self.num_threads is not None:
raise PsOptionsError(
"Cannot specify `num_threads` if `omit_parallel_construct` is set."
)
@dataclass
class CpuOptimConfig:
......@@ -181,7 +191,24 @@ class CreateKernelConfig:
If this parameter is set while `target` is a non-CPU target, an error will be raised.
"""
def __post_init__(self):
# Deprecated Options
data_type: InitVar[UserTypeSpec | None] = None
"""Deprecated; use `default_dtype` instead"""
cpu_openmp: InitVar[bool | int | None] = None
"""Deprecated; use `cpu_optim.openmp` instead."""
cpu_vectorize_info: InitVar[dict | None] = None
"""Deprecated; use `cpu_optim.vectorize` instead."""
# Postprocessing
def __post_init__(self, *args):
# Check deprecated options
self._check_deprecations(*args)
# Check iteration space argument consistency
if (
int(self.iteration_slice is not None)
......@@ -228,3 +255,59 @@ class CreateKernelConfig:
raise NotImplementedError(
f"No default JIT compiler implemented yet for target {self.target}"
)
def _check_deprecations(
self,
data_type: UserTypeSpec | None,
cpu_openmp: bool | int | None,
cpu_vectorize_info: dict | None,
):
optim: CpuOptimConfig | None = None
if data_type is not None:
_deprecated_option("data_type", "default_dtype")
warn(
"Setting the deprecated `data_type` will override the value of `default_dtype`. "
"Set `default_dtype` instead.",
FutureWarning,
)
self.default_dtype = data_type
if cpu_openmp is not None:
_deprecated_option("cpu_openmp", "cpu_optim.openmp")
match cpu_openmp:
case True:
deprecated_omp = OpenMpConfig()
case False:
deprecated_omp = False
case int():
deprecated_omp = OpenMpConfig(num_threads=cpu_openmp)
case _:
raise PsOptionsError(
f"Invalid option for `cpu_openmp`: {cpu_openmp}"
)
optim = CpuOptimConfig(openmp=deprecated_omp)
if cpu_vectorize_info is not None:
_deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize")
raise NotImplementedError("CPU vectorization is not implemented yet")
if optim is not None:
if self.cpu_optim is not None:
raise PsOptionsError(
"Cannot specify both `cpu_optim` and a deprecated legacy optimization option at the same time."
)
else:
self.cpu_optim = optim
def _deprecated_option(name, instead):
from warnings import warn
warn(
f"The `{name}` option of CreateKernelConfig is deprecated and will be removed in pystencils 2.1. "
f"Use `{instead}` instead.",
FutureWarning,
)
......@@ -55,7 +55,8 @@ def create_kernel(
"""
ctx = KernelCreationContext(
default_dtype=create_numeric_type(config.default_dtype), index_dtype=config.index_dtype
default_dtype=create_numeric_type(config.default_dtype),
index_dtype=config.index_dtype,
)
if isinstance(assignments, Assignment):
......@@ -150,9 +151,11 @@ def create_kernel_function(
req_headers = collect_required_headers(body)
req_headers |= ctx.required_headers
return KernelFunction(
kfunc = KernelFunction(
body, target_spec, function_name, params, req_headers, ctx.constraints, jit
)
kfunc.metadata.update(ctx.metadata)
return kfunc
def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs):
......
......@@ -36,6 +36,7 @@ of types, as well as for const-conversion.
from __future__ import annotations
from warnings import warn
from abc import ABCMeta, abstractmethod
from typing import TypeVar, Any, cast
import numpy as np
......@@ -159,6 +160,16 @@ class PsType(metaclass=PsTypeMeta):
def c_string(self) -> str:
pass
@property
def c_name(self) -> str:
"""Returns the C name of this type without const-qualifiers."""
warn(
"`c_name` is deprecated and will be removed in a future version of pystencils. "
"Use `c_string()` instead.",
DeprecationWarning,
)
return deconstify(self).c_string()
def __str__(self) -> str:
return self.c_string()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment