From 9c93da4a34d90dcd0f6e1c8eb11987aeb1494673 Mon Sep 17 00:00:00 2001 From: Christoph Alt <christoph.alt@fau.de> Date: Tue, 14 Sep 2021 16:35:29 +0000 Subject: [PATCH] fixed create_kernel parameter data_type="float" to procucde single precision --- pystencils/transformations.py | 15 ++++++++++ pystencils_tests/test_kernel_data_type.py | 36 +++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 pystencils_tests/test_kernel_data_type.py diff --git a/pystencils/transformations.py b/pystencils/transformations.py index b037225b9..c3a62948e 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -960,6 +960,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'): type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) + type_for_symbol = adjust_c_single_precision_type(type_for_symbol) + check = KernelConstraintsCheck(type_for_symbol, check_independence_condition, check_double_write_condition=check_double_write_condition) @@ -1397,3 +1399,16 @@ def implement_interpolations(ast_node: ast.Node, ast_node.subs(substitutions) return ast_node + + +def adjust_c_single_precision_type(type_for_symbol): + """Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type.""" + def single_factory(): + return "single" + + for symbol in type_for_symbol: + if type_for_symbol[symbol] == "float": + type_for_symbol[symbol] = single_factory() + if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float": + type_for_symbol.default_factory = single_factory + return type_for_symbol diff --git a/pystencils_tests/test_kernel_data_type.py b/pystencils_tests/test_kernel_data_type.py new file mode 100644 index 000000000..2fbab3ff1 --- /dev/null +++ b/pystencils_tests/test_kernel_data_type.py @@ -0,0 +1,36 @@ +from collections import defaultdict + +import numpy as np +import pytest +from sympy.abc import x, y + +from pystencils import Assignment, create_kernel, fields, CreateKernelConfig +from pystencils.transformations import adjust_c_single_precision_type + + +@pytest.mark.parametrize("data_type", ("float", "double")) +def test_single_precision(data_type): + dtype = f"float{64 if data_type == 'double' else 32}" + s = fields(f"s: {dtype}[1D]") + assignments = [Assignment(x, y), Assignment(s[0], x)] + ast = create_kernel(assignments, config=CreateKernelConfig(data_type=data_type)) + assert ast.body.args[0].lhs.dtype.numpy_dtype == np.dtype(dtype) + assert ast.body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype) + assert ast.body.args[1].body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype) + + +def test_adjustment_dict(): + d = dict({"x": "float", "y": "double"}) + adjust_c_single_precision_type(d) + assert np.dtype(d["x"]) == np.dtype("float32") + assert np.dtype(d["y"]) == np.dtype("float64") + + +def test_adjustement_default_dict(): + dd = defaultdict(lambda: "float") + dd["x"] + adjust_c_single_precision_type(dd) + dd["y"] + assert np.dtype(dd["x"]) == np.dtype("float32") + assert np.dtype(dd["y"]) == np.dtype("float32") + assert np.dtype(dd["z"]) == np.dtype("float32") -- GitLab