From 9c93da4a34d90dcd0f6e1c8eb11987aeb1494673 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Tue, 14 Sep 2021 16:35:29 +0000
Subject: [PATCH] fixed create_kernel parameter data_type="float" to procucde
 single precision

---
 pystencils/transformations.py             | 15 ++++++++++
 pystencils_tests/test_kernel_data_type.py | 36 +++++++++++++++++++++++
 2 files changed, 51 insertions(+)
 create mode 100644 pystencils_tests/test_kernel_data_type.py

diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index b037225b9..c3a62948e 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -960,6 +960,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w
     if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
         type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
 
+    type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
+
     check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
                                    check_double_write_condition=check_double_write_condition)
 
@@ -1397,3 +1399,16 @@ def implement_interpolations(ast_node: ast.Node,
             ast_node.subs(substitutions)
 
     return ast_node
+
+
+def adjust_c_single_precision_type(type_for_symbol):
+    """Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type."""
+    def single_factory():
+        return "single"
+
+    for symbol in type_for_symbol:
+        if type_for_symbol[symbol] == "float":
+            type_for_symbol[symbol] = single_factory()
+    if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float":
+        type_for_symbol.default_factory = single_factory
+    return type_for_symbol
diff --git a/pystencils_tests/test_kernel_data_type.py b/pystencils_tests/test_kernel_data_type.py
new file mode 100644
index 000000000..2fbab3ff1
--- /dev/null
+++ b/pystencils_tests/test_kernel_data_type.py
@@ -0,0 +1,36 @@
+from collections import defaultdict
+
+import numpy as np
+import pytest
+from sympy.abc import x, y
+
+from pystencils import Assignment, create_kernel, fields, CreateKernelConfig
+from pystencils.transformations import adjust_c_single_precision_type
+
+
+@pytest.mark.parametrize("data_type", ("float", "double"))
+def test_single_precision(data_type):
+    dtype = f"float{64 if data_type == 'double' else 32}"
+    s = fields(f"s: {dtype}[1D]")
+    assignments = [Assignment(x, y), Assignment(s[0], x)]
+    ast = create_kernel(assignments, config=CreateKernelConfig(data_type=data_type))
+    assert ast.body.args[0].lhs.dtype.numpy_dtype == np.dtype(dtype)
+    assert ast.body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype)
+    assert ast.body.args[1].body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype)
+
+
+def test_adjustment_dict():
+    d = dict({"x": "float", "y": "double"})
+    adjust_c_single_precision_type(d)
+    assert np.dtype(d["x"]) == np.dtype("float32")
+    assert np.dtype(d["y"]) == np.dtype("float64")
+
+
+def test_adjustement_default_dict():
+    dd = defaultdict(lambda: "float")
+    dd["x"]
+    adjust_c_single_precision_type(dd)
+    dd["y"]
+    assert np.dtype(dd["x"]) == np.dtype("float32")
+    assert np.dtype(dd["y"]) == np.dtype("float32")
+    assert np.dtype(dd["z"]) == np.dtype("float32")
-- 
GitLab