From 5a7cb0dbcc6f32898bbbda87f89baca378d5a44f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 14 Mar 2024 14:00:20 +0100 Subject: [PATCH] fix field.create_fixed_size dtype param --- src/pystencils/__init__.py | 4 +++- src/pystencils/field.py | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 56cf6c2bb..61016e14f 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -7,7 +7,7 @@ from . import stencil as stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields from .cache import clear_cache -from .config import CreateKernelConfig +from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel from .backend.kernelfunction import KernelFunction @@ -36,6 +36,8 @@ __all__ = [ "TypedSymbol", "make_slice", "CreateKernelConfig", + "CpuOptimConfig", + "VectorizationConfig", "create_kernel", "KernelFunction", "Target", diff --git a/src/pystencils/field.py b/src/pystencils/field.py index b055ccb6b..3f019f566 100644 --- a/src/pystencils/field.py +++ b/src/pystencils/field.py @@ -160,6 +160,7 @@ class Field: dtype = create_type(dtype) np_data_type = dtype.numpy_dtype assert np_data_type is not None + if np_data_type.fields is not None: if index_dimensions != 0: raise ValueError("Structured arrays/fields are not allowed to have an index dimension") @@ -207,7 +208,8 @@ class Field: @staticmethod def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0, - dtype=np.float64, layout: str = 'numpy', strides: Optional[Sequence[int]] = None, + dtype: UserTypeSpec = np.float64, layout: str = 'numpy', + strides: Optional[Sequence[int]] = None, field_type=FieldType.GENERIC) -> 'Field': """ Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout @@ -234,7 +236,10 @@ class Field: assert len(strides) == len(shape) strides = tuple([s // np.dtype(dtype).itemsize for s in strides]) - numpy_dtype = np.dtype(dtype) + dtype = create_type(dtype) + numpy_dtype = dtype.numpy_dtype + assert numpy_dtype is not None + if numpy_dtype.fields is not None: if index_dimensions != 0: raise ValueError("Structured arrays/fields are not allowed to have an index dimension") -- GitLab