# test_complex_numbers.py
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
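# Distributed under terms of the GPLv3 license.
"""
Tests that pystencils can generate and compile kernels whose assignments
involve complex-valued fields (complex64 and complex128) on the 'cpu' and
'gpu' targets.
"""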

import itertools

import pytest
import sympy
from sympy.functions import im, re
import numpy as np

import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type

# Fixtures for the single-precision tests: two 2D complex64 fields, two 2D
# float32 fields and three plain (untyped) sympy symbols.
X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2, T = sympy.symbols('S1, S2, T')

# One mapping per kernel under test. Each mapping is handed unchanged to
# AssignmentCollection below, so key order is the order of the assignments.
_COMPLEX64_CASES = (
    # Store a purely imaginary constant into a complex field.
    {X[0, 0]: 1j},
    # Take real and imaginary part of Y into temporaries, recombine into X.
    {S1: re(Y.center), S2: im(Y.center), X[0, 0]: 2j * S1 + S2},
    # Write real and imaginary part of Y into the two float fields.
    {A.center: re(Y.center), B.center: im(Y.center)},
    # Sum of a real part, a complex field value and a complex constant.
    {Y.center: re(Y.center) + X.center + 2j},
    # Divide a complex field value by a complex constant bound to a symbol.
    {T: 2 + 4j, Y.center: X.center / T},
)

TEST_ASSIGNMENTS = [AssignmentCollection(case) for case in _COMPLEX64_CASES]

# NOTE(review): not referenced by any test visible in this file.
SCALAR_DTYPES = ['float32', 'float64']


@pytest.mark.parametrize("assignment, scalar_dtypes",
                         itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
    """Check that a complex-valued kernel is generated and compiled.

    Args:
        assignment: one entry of the TEST_ASSIGNMENTS list that is bound at
            the time this decorator is evaluated.
        scalar_dtypes: value passed to ``create_kernel`` as ``data_type``;
            the parametrization only supplies ``np.float32``.
        target: ``'cpu'`` or ``'gpu'``, forwarded to ``create_kernel``.

    The test passes when the printed kernel code does not contain the text
    "Not supported" and ``ast.compile()`` returns something other than None.
    """
    if target == 'gpu':
        # Skip instead of erroring on machines without a CUDA Python stack.
        # NOTE(review): assumes pystencils' 'gpu' target is implemented on
        # top of pycuda -- confirm for the pystencils version in use.
        pytest.importorskip('pycuda')

    ast = pystencils.create_kernel(assignment,
                                   target=target,
                                   data_type=scalar_dtypes)
    code = str(pystencils.show_code(ast))

    # Printed so the generated code shows up in pytest's captured output.
    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None


# Fixtures for the double-precision tests. X, Y, A, B, S1, S2,
# TEST_ASSIGNMENTS and SCALAR_DTYPES are rebound from here on; decorators
# evaluated earlier in the module have already captured the previous objects.
X, Y = pystencils.fields('x, y: complex128[2d]')
A, B = pystencils.fields('a, b: float64[2d]')
S1, S2 = sympy.symbols('S1, S2')
# Unlike the other symbols, this one carries an explicit complex128 type.
T128 = TypedSymbol('ts', create_type('complex128'))

# One mapping per kernel under test. Each mapping is handed unchanged to
# AssignmentCollection below, so key order is the order of the assignments.
_COMPLEX128_CASES = (
    # Store a purely imaginary constant into a complex field.
    {X[0, 0]: 1j},
    # Take real and imaginary part of Y into temporaries, recombine into X.
    {S1: re(Y.center), S2: im(Y.center), X[0, 0]: 2j * S1 + S2},
    # Write real and imaginary part of Y into the two float fields.
    {A.center: re(Y.center), B.center: im(Y.center)},
    # Sum of a real part, a complex field value and a complex constant.
    {Y.center: re(Y.center) + X.center + 2j},
    # Divide a complex field value by a typed complex constant.
    {T128: 2 + 4j, Y.center: X.center / T128},
)

TEST_ASSIGNMENTS = [AssignmentCollection(case) for case in _COMPLEX128_CASES]

# NOTE(review): not referenced by any test visible in this file.
SCALAR_DTYPES = ['float64']


@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
    """Check that a complex-valued kernel is generated and compiled with
    ``data_type='double'``.

    Args:
        assignment: one entry of the TEST_ASSIGNMENTS list that is bound at
            the time this decorator is evaluated.
        target: ``'cpu'`` or ``'gpu'``, forwarded to ``create_kernel``.

    The test passes when the printed kernel code does not contain the text
    "Not supported" and ``ast.compile()`` returns something other than None.
    """
    if target == 'gpu':
        # Skip instead of erroring on machines without a CUDA Python stack.
        # NOTE(review): assumes pystencils' 'gpu' target is implemented on
        # top of pycuda -- confirm for the pystencils version in use.
        pytest.importorskip('pycuda')

    ast = pystencils.create_kernel(assignment,
                                   target=target,
                                   data_type='double')
    code = str(pystencils.show_code(ast))

    # Printed so the generated code shows up in pytest's captured output.
    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None