Commit f9b8ee6e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add complex number support / headers support for sp.Expr

parent 9f76ea1d
......@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs
......@@ -555,6 +555,7 @@ class SympyAssignment(Node):
if isinstance(symbol, Field.Access):
for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
result.update(loop_counters)
result.update(self._lhs_symbol.atoms(sp.Symbol))
return result
......
......@@ -76,8 +76,8 @@ def get_global_declarations(ast):
global_declarations = []
def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"):
......@@ -99,7 +99,7 @@ def get_headers(ast_node: Node) -> Set[str]:
if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers)
for a in ast_node.args:
if isinstance(a, Node):
if isinstance(a, (sp.Expr, Node)):
headers.update(get_headers(a))
for g in get_global_declarations(ast_node):
......@@ -230,7 +230,8 @@ class CBackend:
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
......@@ -432,6 +433,27 @@ class CustomSympyPrinter(CCodePrinter):
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
def _print_re(self, expr):
return f"real({self._print(expr.args[0])})"
def _print_im(self, expr):
return f"imag({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr):
return "complex<double>{0,1}"
def _print_TypedImaginaryUnit(self, expr):
if expr.dtype.numpy_dtype == np.complex64:
return "complex<float>{0,1}"
elif expr.dtype.numpy_dtype == np.complex128:
return "complex<double>{0,1}"
else:
raise NotImplementedError(
"only complex64 and complex128 supported")
def _print_Complex(self, expr):
return self._typed_number(expr, np.complex64)
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
......@@ -244,6 +244,22 @@ class TypedSymbol(sp.Symbol):
def reversed(self):
return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
return headers
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......@@ -414,16 +430,27 @@ def peel_off_type(dtype, type_to_peel_off):
return dtype
def collate_types(types, forbid_collation_to_float=False):
def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_float=False):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
if forbid_collation_to_complex:
types = [
t for t in types
if not np.issubdtype(t.numpy_dtype, np.complexfloating)
]
if not types:
return create_type(np.float64)
if forbid_collation_to_float:
types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
types = [
t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
]
if not types:
return create_type('int32')
return create_type(np.int32)
# Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types):
......@@ -478,6 +505,8 @@ def get_type_of_expression(expr,
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
elif expr.is_real is False:
return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
......@@ -504,7 +533,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
......@@ -517,7 +546,10 @@ def get_type_of_expression(expr,
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
return collate_types(
types,
forbid_collation_to_complex=expr.is_real is True,
forbid_collation_to_float=expr.is_integer is True)
else:
if expr.is_integer:
return create_type(default_int_type)
......@@ -544,6 +576,10 @@ class BasicType(Type):
return 'double'
elif name == 'float32':
return 'float'
elif name == 'complex64':
return 'ComplexFloat'
elif name == 'complex128':
return 'ComplexDouble'
elif name.startswith('int'):
width = int(name[len("int"):])
return "int%d_t" % (width,)
......@@ -755,3 +791,23 @@ class StructType:
def __hash__(self):
return hash((self.numpy_dtype, self.const))
class TypedImaginaryUnit(TypedSymbol):
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, dtype, *args, **kwargs):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
return obj
headers = ['"cuda_complex.hpp"']
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
This diff is collapsed.
......@@ -7,14 +7,15 @@ from types import MappingProxyType
import numpy as np
import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean
import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -827,6 +828,8 @@ class KernelConstraintsCheck:
if new_args:
rhs.offsets = new_args
return rhs
elif isinstance(rhs, ImaginaryUnit):
return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
......@@ -930,7 +933,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
......@@ -1090,6 +1093,10 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
if hasattr(default_type, 'numpy_dtype'):
result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
else:
result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import pytest
import sympy
from sympy.functions import im, re
import numpy as np
import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type
X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2, T = sympy.symbols('S1, S2, T')
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
AssignmentCollection({
T: 2 + 4j,
Y.center: X.center / T,
})
]
SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
ast = pystencils.create_kernel(assignment,
target=target,
data_type=scalar_dtypes)
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" not in code
kernel = ast.compile()
assert kernel is not None
X, Y = pystencils.fields('x, y: complex128[2d]')
A, B = pystencils.fields('a, b: float64[2d]')
S1, S2 = sympy.symbols('S1, S2')
T128 = TypedSymbol('ts', create_type('complex128'))
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
AssignmentCollection({
T128: 2 + 4j,
Y.center: X.center / T128,
})
]
SCALAR_DTYPES = [ 'float64']
@pytest.mark.parametrize("assignment",TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
ast = pystencils.create_kernel(assignment,
target=target,
data_type='double')
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" not in code
kernel = ast.compile()
assert kernel is not None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment