Skip to content
Snippets Groups Projects
test_complex_numbers.py 3.74 KiB
Newer Older
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""

"""

import itertools

import numpy as np
import pytest
import sympy
from sympy.functions import im, re

import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type

X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2, T = sympy.symbols('S1, S2, T')

TEST_ASSIGNMENTS = [
    AssignmentCollection({X[0, 0]: 1j}),
    AssignmentCollection({
        S1: re(Y.center),
        S2: im(Y.center),
        X[0, 0]: 2j * S1 + S2
    }),
    AssignmentCollection({
        A.center: re(Y.center),
        B.center: im(Y.center),
    }),
    AssignmentCollection({
        Y.center: re(Y.center) + X.center + 2j,
    }),
    AssignmentCollection({
        T: 2 + 4j,
        Y.center: X.center / T,
    })
]

SCALAR_DTYPES = ['float32', 'float64']


@pytest.mark.parametrize("assignment, scalar_dtypes",
                         itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
    ast = pystencils.create_kernel(assignment,
                                   target=target,
                                   data_type=scalar_dtypes)
    code = str(pystencils.show_code(ast))

    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None


X, Y = pystencils.fields('x, y: complex128[2d]')
A, B = pystencils.fields('a, b: float64[2d]')
S1, S2 = sympy.symbols('S1, S2')
T128 = TypedSymbol('ts', create_type('complex128'))

TEST_ASSIGNMENTS = [
    AssignmentCollection({X[0, 0]: 1j}),
    AssignmentCollection({
        S1: re(Y.center),
        S2: im(Y.center),
        X[0, 0]: 2j * S1 + S2
    }),
    AssignmentCollection({
        A.center: re(Y.center),
        B.center: im(Y.center),
    }),
    AssignmentCollection({
        Y.center: re(Y.center) + X.center + 2j,
    }),
    AssignmentCollection({
        T128: 2 + 4j,
        Y.center: X.center / T128,
    })
]

SCALAR_DTYPES = ['float64']
@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
    ast = pystencils.create_kernel(assignment,
                                   target=target,
                                   data_type='double')
    code = str(pystencils.show_code(ast))

    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None


@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
@pytest.mark.parametrize('with_complex_argument', ('with_complex_argument', False))
def test_complex_execution(dtype, target, with_complex_argument):

    complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
    x, y = pystencils.fields(f'x, y:  {complex_dtype}[2d]')

    x_arr = np.zeros((20, 30), complex_dtype)
    y_arr = np.zeros((20, 30), complex_dtype)

    if with_complex_argument:
        a = pystencils.TypedSymbol('a', create_type(complex_dtype))
    else:
        a = (2j+1)

    assignments = AssignmentCollection({
        y.center: x.center + a
    })

    if target == 'gpu':
        from pycuda.gpuarray import zeros
        x_arr = zeros((20, 30), complex_dtype)
        y_arr = zeros((20, 30), complex_dtype)

    kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()

    if with_complex_argument:
        kernel(x=x_arr, y=y_arr, a=2j+1)
    else:
        kernel(x=x_arr, y=y_arr)

    if target == 'gpu':
        y_arr = y_arr.get()
    assert np.allclose(y_arr, 2j+1)