Skip to content
Snippets Groups Projects
Commit 9c93da4a authored by Christoph Alt's avatar Christoph Alt Committed by Jan Hönig
Browse files

fixed create_kernel parameter data_type="float" to procucde single precision

parent 52775e94
No related merge requests found
......@@ -960,6 +960,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition, check_double_w
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
type_for_symbol = adjust_c_single_precision_type(type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition,
check_double_write_condition=check_double_write_condition)
......@@ -1397,3 +1399,16 @@ def implement_interpolations(ast_node: ast.Node,
ast_node.subs(substitutions)
return ast_node
def adjust_c_single_precision_type(type_for_symbol):
"""Replaces every occurrence of 'float' with 'single' to enforce the numpy single precision type."""
def single_factory():
return "single"
for symbol in type_for_symbol:
if type_for_symbol[symbol] == "float":
type_for_symbol[symbol] = single_factory()
if hasattr(type_for_symbol, "default_factory") and type_for_symbol.default_factory() == "float":
type_for_symbol.default_factory = single_factory
return type_for_symbol
from collections import defaultdict
import numpy as np
import pytest
from sympy.abc import x, y
from pystencils import Assignment, create_kernel, fields, CreateKernelConfig
from pystencils.transformations import adjust_c_single_precision_type
@pytest.mark.parametrize("data_type", ("float", "double"))
def test_single_precision(data_type):
dtype = f"float{64 if data_type == 'double' else 32}"
s = fields(f"s: {dtype}[1D]")
assignments = [Assignment(x, y), Assignment(s[0], x)]
ast = create_kernel(assignments, config=CreateKernelConfig(data_type=data_type))
assert ast.body.args[0].lhs.dtype.numpy_dtype == np.dtype(dtype)
assert ast.body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype)
assert ast.body.args[1].body.args[0].rhs.dtype.numpy_dtype == np.dtype(dtype)
def test_adjustment_dict():
d = dict({"x": "float", "y": "double"})
adjust_c_single_precision_type(d)
assert np.dtype(d["x"]) == np.dtype("float32")
assert np.dtype(d["y"]) == np.dtype("float64")
def test_adjustement_default_dict():
dd = defaultdict(lambda: "float")
dd["x"]
adjust_c_single_precision_type(dd)
dd["y"]
assert np.dtype(dd["x"]) == np.dtype("float32")
assert np.dtype(dd["y"]) == np.dtype("float32")
assert np.dtype(dd["z"]) == np.dtype("float32")
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment