Commit 684ef359 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Correctly determine complex dtype of symbols and imaginary unit

parent b9423f80
Pipeline #17710 failed with stage
in 5 minutes and 1 second
......@@ -431,11 +431,15 @@ def collate_types(types,
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
default_complex_type='complex128',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
# TODO: determine more general
if default_float_type == 'double' or default_float_type == 'float64':
default_complex_type = 'complex128'
else:
default_complex_type = 'complex64'
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
......@@ -443,7 +447,6 @@ def get_type_of_expression(expr,
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
default_complex_type=default_complex_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
......
......@@ -12,8 +12,8 @@ from sympy.logic.boolalg import Boolean
import pystencils.astnodes as ast
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -898,6 +898,11 @@ class KernelConstraintsCheck:
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.numbers.ImaginaryUnit):
return TypedImaginaryUnit(self._type_for_symbol['_ImaginaryUnit'])
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, np.generic):
......@@ -1167,6 +1172,11 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
if default_type == 'double' or default_type == 'float64': # todo: fix
result['_ImaginaryUnit'] = create_type('complex128')
else:
result['_ImaginaryUnit'] = create_type('complex64')
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
......
......@@ -20,7 +20,8 @@ from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, create_type
X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2 = sympy.symbols('S1, S2')
T64 = TypedSymbol('t', create_type('complex64'))
# T64 = TypedSymbol('t', create_type('complex64'))
T64 = sympy.Symbol('t')
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
......@@ -48,11 +49,9 @@ SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
def test_complex_numbers(assignment, scalar_dtypes):
ast = pystencils.create_kernel(assignment.subs(
{sympy.sympify(1j).args[1]:
TypedImaginaryUnit(create_type('complex64'))}),
target='cpu',
data_type=scalar_dtypes)
ast = pystencils.create_kernel(assignment,
target='cpu',
data_type='float32')
code = str(pystencils.show_code(ast))
print(code)
......@@ -94,11 +93,9 @@ SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
def test_complex_numbers_64(assignment, scalar_dtypes):
ast = pystencils.create_kernel(assignment.subs(
{sympy.sympify(1j).args[1]:
TypedImaginaryUnit(create_type('complex128'))}),
target='cpu',
data_type=scalar_dtypes)
ast = pystencils.create_kernel(assignment,
target='cpu',
data_type='double')
code = str(pystencils.show_code(ast))
print(code)
......@@ -113,5 +110,8 @@ def test_get_data_type():
from pystencils.data_types import get_type_of_expression
i = TypedImaginaryUnit(create_type('complex128'))
# assert get_type_of_expression(i+3).numpy_dtype == np.complex128
assert get_type_of_expression(i+3).numpy_dtype == np.complex128
assert get_type_of_expression(i+3.).numpy_dtype == np.complex128
i = TypedImaginaryUnit(create_type('complex64'))
assert get_type_of_expression(i+3, default_float_type='float32').numpy_dtype == np.complex64
assert get_type_of_expression(i+3., default_float_type='float32').numpy_dtype == np.complex64
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment