Skip to content
Snippets Groups Projects
Commit 0460532f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add call code for complex scalars as arguments

parent 5a5a878c
1 merge request!72Support complex numbers
......@@ -255,6 +255,8 @@ type_mapping = {
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
......@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
......@@ -396,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
name=param.symbol.name)
if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
extract_function_imag=extract_function[1],
target_type=target_type,
name=param.symbol.name)
else:
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
target_type=target_type,
name=param.symbol.name)
parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields)
......
......@@ -106,7 +106,8 @@ def test_complex_numbers_64(assignment, target):
@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_execution(dtype, target):
@pytest.mark.parametrize('with_complex_argument', ('with_complex_argument', False))
def test_complex_execution(dtype, target, with_complex_argument):
complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')
......@@ -114,8 +115,13 @@ def test_complex_execution(dtype, target):
x_arr = np.zeros((20, 30), complex_dtype)
y_arr = np.zeros((20, 30), complex_dtype)
if with_complex_argument:
a = pystencils.TypedSymbol('a', create_type(complex_dtype))
else:
a = (2j+1)
assignments = AssignmentCollection({
y.center: x.center * (2j+1)
y.center: x.center + a
})
if target == 'gpu':
......@@ -125,4 +131,12 @@ def test_complex_execution(dtype, target):
kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()
kernel(x=x_arr, y=y_arr)
if with_complex_argument:
kernel(x=x_arr, y=y_arr, a=2j+1)
else:
kernel(x=x_arr, y=y_arr)
if target == 'gpu':
y_arr = y_arr.get()
assert np.allclose(y_arr, 2j+1)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment