From 1dde0c9df90255d89daf8f0e4f624cf60e1278ce Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Tue, 25 Oct 2022 11:16:03 +0200
Subject: [PATCH] Sane Defaults for CreateKernelConfig

---
 pystencils/config.py            |  59 ++++++++++++++---
 pystencils/typing/types.py      |  11 ++-
 pystencils_tests/test_config.py | 114 ++++++++++++++++++++++++++++++++
 3 files changed, 173 insertions(+), 11 deletions(-)
 create mode 100644 pystencils_tests/test_config.py

diff --git a/pystencils/config.py b/pystencils/config.py
index ef7f3b17d..9ef7b62dd 100644
--- a/pystencils/config.py
+++ b/pystencils/config.py
@@ -3,13 +3,17 @@ from copy import copy
 from collections import defaultdict
 from dataclasses import dataclass, field
 from types import MappingProxyType
-from typing import Union, Tuple, List, Dict, Callable, Any
+from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict
 
 from pystencils import Target, Backend, Field
 from pystencils.typing.typed_sympy import BasicType
+from pystencils.typing.utilities import collate_types
 
 import numpy as np
 
+# TODO: There exists DTypeLike in NumPy which would be better than type for type hinting, to new at the moment
+# from numpy.typing import DTypeLike
+
 
 # TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ...
 # Proposition: CreateKernelConfigs Classes for different targets?
@@ -30,17 +34,19 @@ class CreateKernelConfig:
     """
     Name of the generated function - only important if generated code is written out
     """
-    # TODO Sane defaults: config should check that the datatype is a Numpy type
-    # TODO Sane defaults: QoL default_number_float and default_number_int should be data_type if they are not specified
-    data_type: Union[str, Dict[str, BasicType]] = 'float64'
+    data_type: Union[type, str, DefaultDict[str, BasicType], Dict[str, BasicType]] = np.float64
     """
-    Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type
+    Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type.
+    If specified as a dict ideally a defaultdict is used to define a default value for symbols not listed in the
+    dict. If a plain dict is provided it will be transformed into a defaultdict internally. The default value 
+    will then be specified via type collation then.
     """
-    default_number_float: Union[str, np.dtype, BasicType] = 'float64'
+    default_number_float: Union[type, str, BasicType] = None
     """
-    Data type used for all untyped floating point numbers (i.e. 0.5)
+    Data type used for all untyped floating point numbers (i.e. 0.5). By default the value of data_type is used.
+    If data_type is given as a defaultdict its default_factory is used.
     """
-    default_number_int: Union[str, np.dtype, BasicType] = 'int64'
+    default_number_int: Union[type, str, BasicType] = np.int64
     """
     Data type used for all untyped integer numbers (i.e. 1)
     """
@@ -133,9 +139,22 @@ class CreateKernelConfig:
         def __call__(self):
             return BasicType(self.dt)
 
+    def _check_type(self, dtype_to_check):
+        if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'):
+            self._typing_error()
+
+        if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"):
+            # NumPy-types are also of type 'type'. However, they have more properties
+            self._typing_error()
+
+    @staticmethod
+    def _typing_error():
+        raise ValueError("It is not possible to use python types (float, int) for datatypes because these "
+                         "types are ambiguous. For example float will map to double. "
+                         "Also the string version like 'float' is not allowed, e.g. use 'float64' instead")
+
     def __post_init__(self):
         # ----  Legacy parameters
-        # TODO Sane defaults: Check for abmigous types like "float", python float, which are dangerous for users
         if isinstance(self.target, str):
             new_target = Target[self.target.upper()]
             warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead',
@@ -150,10 +169,30 @@ class CreateKernelConfig:
             else:
                 raise NotImplementedError(f'Target {self.target} has no default backend')
 
-        #  Normalise data types
+        # Normalise data types
+        for dtype in [self.data_type, self.default_number_float, self.default_number_int]:
+            self._check_type(dtype)
+
         if not isinstance(self.data_type, dict):
             dt = copy(self.data_type)  # The copy is necessary because BasicType has sympy shinanigans
             self.data_type = defaultdict(self.DataTypeFactory(dt))
+
+        if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict):
+            for dtype in self.data_type.values():
+                self._check_type(dtype)
+
+            dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()])
+            dtype_dict = self.data_type
+            self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict)
+
+        assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!"
+        for dtype in self.data_type.values():
+            self._check_type(dtype)
+        self._check_type(self.data_type.default_factory())
+
+        if self.default_number_float is None:
+            self.default_number_float = self.data_type.default_factory()
+
         if not isinstance(self.default_number_float, BasicType):
             self.default_number_float = BasicType(self.default_number_float)
         if not isinstance(self.default_number_int, BasicType):
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
index 06a2888ac..d1c473a0a 100644
--- a/pystencils/typing/types.py
+++ b/pystencils/typing/types.py
@@ -120,7 +120,10 @@ class BasicType(AbstractType):
         return f'{self.c_name}{" const" if self.const else ""}'
 
     def __repr__(self):
-        return str(self)
+        return f'BasicType( {str(self)} )'
+
+    def _repr_html_(self):
+        return f'BasicType( {str(self)} )'
 
     def __eq__(self, other):
         return self.dtype_eq(other) and self.const == other.const
@@ -216,6 +219,9 @@ class PointerType(AbstractType):
     def __repr__(self):
         return str(self)
 
+    def _repr_html_(self):
+        return str(self)
+
     def __hash__(self):
         return hash((self._base_type, self.const, self.restrict))
 
@@ -273,6 +279,9 @@ class StructType(AbstractType):
     def __repr__(self):
         return str(self)
 
+    def _repr_html_(self):
+        return str(self)
+
     def __hash__(self):
         return hash((self.numpy_dtype, self.const))
 
diff --git a/pystencils_tests/test_config.py b/pystencils_tests/test_config.py
new file mode 100644
index 000000000..31f6f9ec9
--- /dev/null
+++ b/pystencils_tests/test_config.py
@@ -0,0 +1,114 @@
+from collections import defaultdict
+import numpy as np
+import pytest
+
+from pystencils import CreateKernelConfig, Target, Backend
+from pystencils.typing import BasicType
+
+
+def test_config():
+    # targets
+    config = CreateKernelConfig(target=Target.CPU)
+    assert config.target == Target.CPU
+    assert config.backend == Backend.C
+
+    config = CreateKernelConfig(target=Target.GPU)
+    assert config.target == Target.GPU
+    assert config.backend == Backend.CUDA
+
+    # typing
+    config = CreateKernelConfig(data_type=np.float64)
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float64')
+    assert config.default_number_float == BasicType('float64')
+    assert config.default_number_int == BasicType('int64')
+
+    config = CreateKernelConfig(data_type=np.float32)
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float32')
+    assert config.default_number_float == BasicType('float32')
+    assert config.default_number_int == BasicType('int64')
+
+    config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64)
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float32')
+    assert config.default_number_float == BasicType('float64')
+    assert config.default_number_int == BasicType('int64')
+
+    config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64, default_number_int=np.int16)
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float32')
+    assert config.default_number_float == BasicType('float64')
+    assert config.default_number_int == BasicType('int16')
+
+    config = CreateKernelConfig(data_type='float64')
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float64')
+    assert config.default_number_float == BasicType('float64')
+    assert config.default_number_int == BasicType('int64')
+
+    config = CreateKernelConfig(data_type={'a': np.float64, 'b': np.float32})
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float64')
+    assert config.default_number_float == BasicType('float64')
+    assert config.default_number_int == BasicType('int64')
+
+    config = CreateKernelConfig(data_type={'a': np.float32, 'b': np.int32})
+    assert isinstance(config.data_type, defaultdict)
+    assert config.data_type.default_factory() == BasicType('float32')
+    assert config.default_number_float == BasicType('float32')
+    assert config.default_number_int == BasicType('int64')
+
+
+def test_config_python_types():
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type=float)
+
+
+def test_config_python_types2():
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type={'a': float})
+
+
+def test_config_python_types3():
+    with pytest.raises(ValueError):
+        CreateKernelConfig(default_number_float=float)
+
+
+def test_config_python_types4():
+    with pytest.raises(ValueError):
+        CreateKernelConfig(default_number_int=int)
+
+
+def test_config_python_types5():
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type="float")
+
+
+def test_config_python_types6():
+    with pytest.raises(ValueError):
+        CreateKernelConfig(default_number_float="float")
+
+
+def test_config_python_types7():
+    dtype = defaultdict(lambda: 'float', {'a': np.float64, 'b': np.int64})
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type=dtype)
+
+
+def test_config_python_types8():
+    dtype = defaultdict(lambda: float, {'a': np.float64, 'b': np.int64})
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type=dtype)
+
+
+def test_config_python_types9():
+    dtype = defaultdict(lambda: 'float32', {'a': 'float', 'b': np.int64})
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type=dtype)
+
+
+def test_config_python_types10():
+    dtype = defaultdict(lambda: 'float32', {'a': float, 'b': np.int64})
+    with pytest.raises(ValueError):
+        CreateKernelConfig(data_type=dtype)
-- 
GitLab