Skip to content
Snippets Groups Projects
Commit 5a5a878c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add execution test for complex numbers

parent f9b8ee6e
Branches
Tags
No related merge requests found
......@@ -358,7 +358,8 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize
if np_dtype.isbuiltin and FieldType.is_generic(field):
if (np_dtype.isbuiltin and FieldType.is_generic(field)
and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
format=field.dtype.numpy_dtype.char)
pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
......
......@@ -9,10 +9,10 @@
import itertools
import numpy as np
import pytest
import sympy
from sympy.functions import im, re
import numpy as np
import pystencils
from pystencils import AssignmentCollection
......@@ -86,10 +86,10 @@ TEST_ASSIGNMENTS = [
})
]
SCALAR_DTYPES = [ 'float64']
SCALAR_DTYPES = ['float64']
@pytest.mark.parametrize("assignment",TEST_ASSIGNMENTS)
@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
ast = pystencils.create_kernel(assignment,
......@@ -102,3 +102,27 @@ def test_complex_numbers_64(assignment, target):
kernel = ast.compile()
assert kernel is not None
@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_execution(dtype, target):
complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')
x_arr = np.zeros((20, 30), complex_dtype)
y_arr = np.zeros((20, 30), complex_dtype)
assignments = AssignmentCollection({
y.center: x.center * (2j+1)
})
if target == 'gpu':
from pycuda.gpuarray import zeros
x_arr = zeros((20, 30), complex_dtype)
y_arr = zeros((20, 30), complex_dtype)
kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()
kernel(x=x_arr, y=y_arr)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment