Skip to content
Snippets Groups Projects
Commit 0460532f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add call code for complex scalars as arguments

parent 5a5a878c
Branches
Tags
No related merge requests found
......@@ -255,6 +255,8 @@ type_mapping = {
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
......@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
......@@ -396,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
name=param.symbol.name)
if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
extract_function_imag=extract_function[1],
target_type=target_type,
name=param.symbol.name)
else:
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
target_type=target_type,
name=param.symbol.name)
parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields)
......
......@@ -106,7 +106,8 @@ def test_complex_numbers_64(assignment, target):
@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_execution(dtype, target):
@pytest.mark.parametrize('with_complex_argument', ('with_complex_argument', False))
def test_complex_execution(dtype, target, with_complex_argument):
complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')
......@@ -114,8 +115,13 @@ def test_complex_execution(dtype, target):
x_arr = np.zeros((20, 30), complex_dtype)
y_arr = np.zeros((20, 30), complex_dtype)
if with_complex_argument:
a = pystencils.TypedSymbol('a', create_type(complex_dtype))
else:
a = (2j+1)
assignments = AssignmentCollection({
y.center: x.center * (2j+1)
y.center: x.center + a
})
if target == 'gpu':
......@@ -125,4 +131,12 @@ def test_complex_execution(dtype, target):
kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()
kernel(x=x_arr, y=y_arr)
if with_complex_argument:
kernel(x=x_arr, y=y_arr, a=2j+1)
else:
kernel(x=x_arr, y=y_arr)
if target == 'gpu':
y_arr = y_arr.get()
assert np.allclose(y_arr, 2j+1)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment