From 1dde0c9df90255d89daf8f0e4f624cf60e1278ce Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Tue, 25 Oct 2022 11:16:03 +0200 Subject: [PATCH] Sane Defaults for CreateKernelConfig --- pystencils/config.py | 59 ++++++++++++++--- pystencils/typing/types.py | 11 ++- pystencils_tests/test_config.py | 114 ++++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 11 deletions(-) create mode 100644 pystencils_tests/test_config.py diff --git a/pystencils/config.py b/pystencils/config.py index ef7f3b17d..9ef7b62dd 100644 --- a/pystencils/config.py +++ b/pystencils/config.py @@ -3,13 +3,17 @@ from copy import copy from collections import defaultdict from dataclasses import dataclass, field from types import MappingProxyType -from typing import Union, Tuple, List, Dict, Callable, Any +from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict from pystencils import Target, Backend, Field from pystencils.typing.typed_sympy import BasicType +from pystencils.typing.utilities import collate_types import numpy as np +# TODO: There exists DTypeLike in NumPy which would be better than type for type hinting, to new at the moment +# from numpy.typing import DTypeLike + # TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ... # Proposition: CreateKernelConfigs Classes for different targets? @@ -30,17 +34,19 @@ class CreateKernelConfig: """ Name of the generated function - only important if generated code is written out """ - # TODO Sane defaults: config should check that the datatype is a Numpy type - # TODO Sane defaults: QoL default_number_float and default_number_int should be data_type if they are not specified - data_type: Union[str, Dict[str, BasicType]] = 'float64' + data_type: Union[type, str, DefaultDict[str, BasicType], Dict[str, BasicType]] = np.float64 """ - Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type + Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type. + If specified as a dict ideally a defaultdict is used to define a default value for symbols not listed in the + dict. If a plain dict is provided it will be transformed into a defaultdict internally. The default value + will then be specified via type collation then. """ - default_number_float: Union[str, np.dtype, BasicType] = 'float64' + default_number_float: Union[type, str, BasicType] = None """ - Data type used for all untyped floating point numbers (i.e. 0.5) + Data type used for all untyped floating point numbers (i.e. 0.5). By default the value of data_type is used. + If data_type is given as a defaultdict its default_factory is used. """ - default_number_int: Union[str, np.dtype, BasicType] = 'int64' + default_number_int: Union[type, str, BasicType] = np.int64 """ Data type used for all untyped integer numbers (i.e. 1) """ @@ -133,9 +139,22 @@ class CreateKernelConfig: def __call__(self): return BasicType(self.dt) + def _check_type(self, dtype_to_check): + if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'): + self._typing_error() + + if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"): + # NumPy-types are also of type 'type'. However, they have more properties + self._typing_error() + + @staticmethod + def _typing_error(): + raise ValueError("It is not possible to use python types (float, int) for datatypes because these " + "types are ambiguous. For example float will map to double. " + "Also the string version like 'float' is not allowed, e.g. use 'float64' instead") + def __post_init__(self): # ---- Legacy parameters - # TODO Sane defaults: Check for abmigous types like "float", python float, which are dangerous for users if isinstance(self.target, str): new_target = Target[self.target.upper()] warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead', @@ -150,10 +169,30 @@ class CreateKernelConfig: else: raise NotImplementedError(f'Target {self.target} has no default backend') - # Normalise data types + # Normalise data types + for dtype in [self.data_type, self.default_number_float, self.default_number_int]: + self._check_type(dtype) + if not isinstance(self.data_type, dict): dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans self.data_type = defaultdict(self.DataTypeFactory(dt)) + + if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict): + for dtype in self.data_type.values(): + self._check_type(dtype) + + dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()]) + dtype_dict = self.data_type + self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict) + + assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!" + for dtype in self.data_type.values(): + self._check_type(dtype) + self._check_type(self.data_type.default_factory()) + + if self.default_number_float is None: + self.default_number_float = self.data_type.default_factory() + if not isinstance(self.default_number_float, BasicType): self.default_number_float = BasicType(self.default_number_float) if not isinstance(self.default_number_int, BasicType): diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py index 06a2888ac..d1c473a0a 100644 --- a/pystencils/typing/types.py +++ b/pystencils/typing/types.py @@ -120,7 +120,10 @@ class BasicType(AbstractType): return f'{self.c_name}{" const" if self.const else ""}' def __repr__(self): - return str(self) + return f'BasicType( {str(self)} )' + + def _repr_html_(self): + return f'BasicType( {str(self)} )' def __eq__(self, other): return self.dtype_eq(other) and self.const == other.const @@ -216,6 +219,9 @@ class PointerType(AbstractType): def __repr__(self): return str(self) + def _repr_html_(self): + return str(self) + def __hash__(self): return hash((self._base_type, self.const, self.restrict)) @@ -273,6 +279,9 @@ class StructType(AbstractType): def __repr__(self): return str(self) + def _repr_html_(self): + return str(self) + def __hash__(self): return hash((self.numpy_dtype, self.const)) diff --git a/pystencils_tests/test_config.py b/pystencils_tests/test_config.py new file mode 100644 index 000000000..31f6f9ec9 --- /dev/null +++ b/pystencils_tests/test_config.py @@ -0,0 +1,114 @@ +from collections import defaultdict +import numpy as np +import pytest + +from pystencils import CreateKernelConfig, Target, Backend +from pystencils.typing import BasicType + + +def test_config(): + # targets + config = CreateKernelConfig(target=Target.CPU) + assert config.target == Target.CPU + assert config.backend == Backend.C + + config = CreateKernelConfig(target=Target.GPU) + assert config.target == Target.GPU + assert config.backend == Backend.CUDA + + # typing + config = CreateKernelConfig(data_type=np.float64) + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float64') + assert config.default_number_float == BasicType('float64') + assert config.default_number_int == BasicType('int64') + + config = CreateKernelConfig(data_type=np.float32) + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float32') + assert config.default_number_float == BasicType('float32') + assert config.default_number_int == BasicType('int64') + + config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64) + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float32') + assert config.default_number_float == BasicType('float64') + assert config.default_number_int == BasicType('int64') + + config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64, default_number_int=np.int16) + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float32') + assert config.default_number_float == BasicType('float64') + assert config.default_number_int == BasicType('int16') + + config = CreateKernelConfig(data_type='float64') + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float64') + assert config.default_number_float == BasicType('float64') + assert config.default_number_int == BasicType('int64') + + config = CreateKernelConfig(data_type={'a': np.float64, 'b': np.float32}) + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float64') + assert config.default_number_float == BasicType('float64') + assert config.default_number_int == BasicType('int64') + + config = CreateKernelConfig(data_type={'a': np.float32, 'b': np.int32}) + assert isinstance(config.data_type, defaultdict) + assert config.data_type.default_factory() == BasicType('float32') + assert config.default_number_float == BasicType('float32') + assert config.default_number_int == BasicType('int64') + + +def test_config_python_types(): + with pytest.raises(ValueError): + CreateKernelConfig(data_type=float) + + +def test_config_python_types2(): + with pytest.raises(ValueError): + CreateKernelConfig(data_type={'a': float}) + + +def test_config_python_types3(): + with pytest.raises(ValueError): + CreateKernelConfig(default_number_float=float) + + +def test_config_python_types4(): + with pytest.raises(ValueError): + CreateKernelConfig(default_number_int=int) + + +def test_config_python_types5(): + with pytest.raises(ValueError): + CreateKernelConfig(data_type="float") + + +def test_config_python_types6(): + with pytest.raises(ValueError): + CreateKernelConfig(default_number_float="float") + + +def test_config_python_types7(): + dtype = defaultdict(lambda: 'float', {'a': np.float64, 'b': np.int64}) + with pytest.raises(ValueError): + CreateKernelConfig(data_type=dtype) + + +def test_config_python_types8(): + dtype = defaultdict(lambda: float, {'a': np.float64, 'b': np.int64}) + with pytest.raises(ValueError): + CreateKernelConfig(data_type=dtype) + + +def test_config_python_types9(): + dtype = defaultdict(lambda: 'float32', {'a': 'float', 'b': np.int64}) + with pytest.raises(ValueError): + CreateKernelConfig(data_type=dtype) + + +def test_config_python_types10(): + dtype = defaultdict(lambda: 'float32', {'a': float, 'b': np.int64}) + with pytest.raises(ValueError): + CreateKernelConfig(data_type=dtype) -- GitLab