From 5a7cb0dbcc6f32898bbbda87f89baca378d5a44f Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 14 Mar 2024 14:00:20 +0100
Subject: [PATCH] fix field.create_fixed_size dtype param

---
 src/pystencils/__init__.py | 4 +++-
 src/pystencils/field.py    | 9 +++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index 56cf6c2bb..61016e14f 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -7,7 +7,7 @@ from . import stencil as stencil
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
 from .field import Field, FieldType, fields
 from .cache import clear_cache
-from .config import CreateKernelConfig
+from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig
 from .kernel_decorator import kernel, kernel_config
 from .kernelcreation import create_kernel
 from .backend.kernelfunction import KernelFunction
@@ -36,6 +36,8 @@ __all__ = [
     "TypedSymbol",
     "make_slice",
     "CreateKernelConfig",
+    "CpuOptimConfig",
+    "VectorizationConfig",
     "create_kernel",
     "KernelFunction",
     "Target",
diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index b055ccb6b..3f019f566 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -160,6 +160,7 @@ class Field:
         dtype = create_type(dtype)
         np_data_type = dtype.numpy_dtype
         assert np_data_type is not None
+        
         if np_data_type.fields is not None:
             if index_dimensions != 0:
                 raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
@@ -207,7 +208,8 @@ class Field:
 
     @staticmethod
     def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
-                          dtype=np.float64, layout: str = 'numpy', strides: Optional[Sequence[int]] = None,
+                          dtype: UserTypeSpec = np.float64, layout: str = 'numpy',
+                          strides: Optional[Sequence[int]] = None,
                           field_type=FieldType.GENERIC) -> 'Field':
         """
         Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout
@@ -234,7 +236,10 @@ class Field:
             assert len(strides) == len(shape)
             strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
 
-        numpy_dtype = np.dtype(dtype)
+        dtype = create_type(dtype)
+        numpy_dtype = dtype.numpy_dtype
+        assert numpy_dtype is not None
+
         if numpy_dtype.fields is not None:
             if index_dimensions != 0:
                 raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
-- 
GitLab