Skip to content
Snippets Groups Projects
Commit 5a5a878c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add execution test for complex numbers

parent f9b8ee6e
1 merge request!72Support complex numbers
......@@ -358,7 +358,8 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize
if np_dtype.isbuiltin and FieldType.is_generic(field):
if (np_dtype.isbuiltin and FieldType.is_generic(field)
and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
format=field.dtype.numpy_dtype.char)
pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
......
......@@ -9,10 +9,10 @@
import itertools
import numpy as np
import pytest
import sympy
from sympy.functions import im, re
import numpy as np
import pystencils
from pystencils import AssignmentCollection
......@@ -86,10 +86,10 @@ TEST_ASSIGNMENTS = [
})
]
SCALAR_DTYPES = [ 'float64']
SCALAR_DTYPES = ['float64']
@pytest.mark.parametrize("assignment",TEST_ASSIGNMENTS)
@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
ast = pystencils.create_kernel(assignment,
......@@ -102,3 +102,27 @@ def test_complex_numbers_64(assignment, target):
kernel = ast.compile()
assert kernel is not None
@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_execution(dtype, target):
complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')
x_arr = np.zeros((20, 30), complex_dtype)
y_arr = np.zeros((20, 30), complex_dtype)
assignments = AssignmentCollection({
y.center: x.center * (2j+1)
})
if target == 'gpu':
from pycuda.gpuarray import zeros
x_arr = zeros((20, 30), complex_dtype)
y_arr = zeros((20, 30), complex_dtype)
kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()
kernel(x=x_arr, y=y_arr)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment