# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.

"""Tests that kernels on complex-valued fields and symbols can be generated and compiled.

Each test builds a kernel with :func:`pystencils.create_kernel`, prints the
generated code, asserts it contains no ``"Not supported"`` marker, and
compiles it.
"""

import itertools

import pytest
import sympy
from sympy.functions import im, re

import numpy as np
import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type


def _make_test_assignments(complex_dtype, real_dtype, t_symbol):
    """Build the assignment collections exercised by the complex-number tests.

    Fields are created fresh on every call, so the single- and
    double-precision variants never shadow each other at module level.

    :param complex_dtype: element type of the complex fields ``x`` and ``y``,
        e.g. ``'complex64'``.
    :param real_dtype: element type of the real fields ``a`` and ``b``,
        e.g. ``'float32'``.
    :param t_symbol: symbol that is assigned the complex constant ``2 + 4j``
        and then used as a divisor in the last collection.
    :return: list of :class:`AssignmentCollection` objects.
    """
    x, y = pystencils.fields('x, y: {}[2d]'.format(complex_dtype))
    a, b = pystencils.fields('a, b: {}[2d]'.format(real_dtype))
    s1, s2 = sympy.symbols('S1, S2')

    return [
        # Write a pure imaginary literal into a complex field.
        AssignmentCollection({x[0, 0]: 1j}),
        # Split a complex value into real/imaginary parts and recombine them.
        AssignmentCollection({
            s1: re(y.center),
            s2: im(y.center),
            x[0, 0]: 2j * s1 + s2
        }),
        # Store real/imaginary parts of a complex field into real fields.
        AssignmentCollection({
            a.center: re(y.center),
            b.center: im(y.center),
        }),
        # Mix real, complex-field and complex-literal operands.
        AssignmentCollection({
            y.center: re(y.center) + x.center + 2j,
        }),
        # Divide a complex field by a complex-valued symbol.
        AssignmentCollection({
            t_symbol: 2 + 4j,
            y.center: x.center / t_symbol,
        }),
    ]


# Single-precision variant: complex64 / float32 fields, untyped symbol ``T``.
TEST_ASSIGNMENTS = _make_test_assignments(
    'complex64', 'float32', sympy.Symbol('T'))

# Double-precision variant: complex128 / float64 fields; the complex constant
# is held in an explicitly typed complex128 symbol.
TEST_ASSIGNMENTS_64 = _make_test_assignments(
    'complex128', 'float64', TypedSymbol('ts', create_type('complex128')))


def _check_kernel(assignment, target, data_type):
    """Generate, print and compile a kernel for *assignment*.

    :param assignment: the :class:`AssignmentCollection` to turn into a kernel.
    :param target: ``'cpu'`` or ``'gpu'``, forwarded to ``create_kernel``.
    :param data_type: forwarded as ``data_type`` to ``create_kernel``.
    """
    if target == 'gpu':
        # NOTE(review): assumes the 'gpu' target is compiled through pycuda;
        # skip instead of failing on machines without a CUDA stack.
        pytest.importorskip('pycuda')

    ast = pystencils.create_kernel(assignment, target=target, data_type=data_type)
    code = str(pystencils.show_code(ast))
    print(code)
    assert "Not supported" not in code

    kernel = ast.compile()
    assert kernel is not None


@pytest.mark.parametrize("assignment, scalar_dtypes",
                         itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
    """Generate and compile each complex64 kernel with ``data_type=np.float32``."""
    _check_kernel(assignment, target, scalar_dtypes)


@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS_64)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
    """Generate and compile each complex128 kernel with ``data_type='double'``."""
    _check_kernel(assignment, target, 'double')