Commit f9b8ee6e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add complex number support / headers support for sp.Expr

parent 9f76ea1d
...@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union ...@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs from pystencils.sympyextensions import fast_subs
...@@ -555,6 +555,7 @@ class SympyAssignment(Node): ...@@ -555,6 +555,7 @@ class SympyAssignment(Node):
if isinstance(symbol, Field.Access): if isinstance(symbol, Field.Access):
for i in range(len(symbol.offsets)): for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
result.update(loop_counters) result.update(loop_counters)
result.update(self._lhs_symbol.atoms(sp.Symbol)) result.update(self._lhs_symbol.atoms(sp.Symbol))
return result return result
......
...@@ -76,8 +76,8 @@ def get_global_declarations(ast): ...@@ -76,8 +76,8 @@ def get_global_declarations(ast):
global_declarations = [] global_declarations = []
def visit_node(sub_ast): def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"): if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"): if hasattr(sub_ast, "args"):
...@@ -99,7 +99,7 @@ def get_headers(ast_node: Node) -> Set[str]: ...@@ -99,7 +99,7 @@ def get_headers(ast_node: Node) -> Set[str]:
if hasattr(ast_node, 'headers'): if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers) headers.update(ast_node.headers)
for a in ast_node.args: for a in ast_node.args:
if isinstance(a, Node): if isinstance(a, (sp.Expr, Node)):
headers.update(get_headers(a)) headers.update(get_headers(a))
for g in get_global_declarations(ast_node): for g in get_global_declarations(ast_node):
...@@ -230,7 +230,8 @@ class CBackend: ...@@ -230,7 +230,8 @@ class CBackend:
else: else:
prefix = '' prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs)) self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
...@@ -432,6 +433,27 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -432,6 +433,27 @@ class CustomSympyPrinter(CCodePrinter):
_print_Max = C89CodePrinter._print_Max _print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min _print_Min = C89CodePrinter._print_Min
def _print_re(self, expr):
return f"real({self._print(expr.args[0])})"
def _print_im(self, expr):
return f"imag({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr):
return "complex<double>{0,1}"
def _print_TypedImaginaryUnit(self, expr):
if expr.dtype.numpy_dtype == np.complex64:
return "complex<float>{0,1}"
elif expr.dtype.numpy_dtype == np.complex128:
return "complex<double>{0,1}"
else:
raise NotImplementedError(
"only complex64 and complex128 supported")
def _print_Complex(self, expr):
return self._typed_number(expr, np.complex64)
# noinspection PyPep8Naming # noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
...@@ -244,6 +244,22 @@ class TypedSymbol(sp.Symbol): ...@@ -244,6 +244,22 @@ class TypedSymbol(sp.Symbol):
def reversed(self): def reversed(self):
return self return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
return headers
def create_type(specification): def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type. """Creates a subclass of Type according to a string or an object of subclass Type.
...@@ -414,16 +430,27 @@ def peel_off_type(dtype, type_to_peel_off): ...@@ -414,16 +430,27 @@ def peel_off_type(dtype, type_to_peel_off):
return dtype return dtype
def collate_types(types, forbid_collation_to_float=False): def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_float=False):
""" """
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy. Uses the collation rules from numpy.
""" """
if forbid_collation_to_complex:
types = [
t for t in types
if not np.issubdtype(t.numpy_dtype, np.complexfloating)
]
if not types:
return create_type(np.float64)
if forbid_collation_to_float: if forbid_collation_to_float:
types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())] types = [
t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
]
if not types: if not types:
return create_type('int32') return create_type(np.int32)
# Pointer arithmetic case i.e. pointer + integer is allowed # Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types): if any(type(t) is PointerType for t in types):
...@@ -478,6 +505,8 @@ def get_type_of_expression(expr, ...@@ -478,6 +505,8 @@ def get_type_of_expression(expr,
expr = sp.sympify(expr) expr = sp.sympify(expr)
if isinstance(expr, sp.Integer): if isinstance(expr, sp.Integer):
return create_type(default_int_type) return create_type(default_int_type)
elif expr.is_real is False:
return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type) return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess): elif isinstance(expr, ResolvedFieldAccess):
...@@ -504,7 +533,7 @@ def get_type_of_expression(expr, ...@@ -504,7 +533,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed): elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label typed_symbol = expr.base.label
return typed_symbol.dtype.base_type return typed_symbol.dtype.base_type
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool") result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
...@@ -517,7 +546,10 @@ def get_type_of_expression(expr, ...@@ -517,7 +546,10 @@ def get_type_of_expression(expr,
expr: sp.Expr expr: sp.Expr
if expr.args: if expr.args:
types = tuple(get_type(a) for a in expr.args) types = tuple(get_type(a) for a in expr.args)
return collate_types(types) return collate_types(
types,
forbid_collation_to_complex=expr.is_real is True,
forbid_collation_to_float=expr.is_integer is True)
else: else:
if expr.is_integer: if expr.is_integer:
return create_type(default_int_type) return create_type(default_int_type)
...@@ -544,6 +576,10 @@ class BasicType(Type): ...@@ -544,6 +576,10 @@ class BasicType(Type):
return 'double' return 'double'
elif name == 'float32': elif name == 'float32':
return 'float' return 'float'
elif name == 'complex64':
return 'ComplexFloat'
elif name == 'complex128':
return 'ComplexDouble'
elif name.startswith('int'): elif name.startswith('int'):
width = int(name[len("int"):]) width = int(name[len("int"):])
return "int%d_t" % (width,) return "int%d_t" % (width,)
...@@ -755,3 +791,23 @@ class StructType: ...@@ -755,3 +791,23 @@ class StructType:
def __hash__(self): def __hash__(self):
return hash((self.numpy_dtype, self.const)) return hash((self.numpy_dtype, self.const))
class TypedImaginaryUnit(TypedSymbol):
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, dtype, *args, **kwargs):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
return obj
headers = ['"cuda_complex.hpp"']
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
This diff is collapsed.
...@@ -7,14 +7,15 @@ from types import MappingProxyType ...@@ -7,14 +7,15 @@ from types import MappingProxyType
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
import pystencils.astnodes as ast import pystencils.astnodes as ast
import pystencils.integer_functions import pystencils.integer_functions
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type, PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
...@@ -827,6 +828,8 @@ class KernelConstraintsCheck: ...@@ -827,6 +828,8 @@ class KernelConstraintsCheck:
if new_args: if new_args:
rhs.offsets = new_args rhs.offsets = new_args
return rhs return rhs
elif isinstance(rhs, ImaginaryUnit):
return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
elif isinstance(rhs, TypedSymbol): elif isinstance(rhs, TypedSymbol):
return rhs return rhs
elif isinstance(rhs, sp.Symbol): elif isinstance(rhs, sp.Symbol):
...@@ -930,7 +933,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition): ...@@ -930,7 +933,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields, ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols list of equations where symbols have been replaced by typed symbols
""" """
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'): if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol) type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition) check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
...@@ -1090,6 +1093,10 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i ...@@ -1090,6 +1093,10 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
dictionary, mapping symbol name to type dictionary, mapping symbol name to type
""" """
result = defaultdict(lambda: default_type) result = defaultdict(lambda: default_type)
if hasattr(default_type, 'numpy_dtype'):
result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
else:
result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
for eq in eqs: for eq in eqs:
if isinstance(eq, ast.Conditional): if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args)) result.update(typing_from_sympy_inspection(eq.true_block.args))
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import pytest
import sympy
from sympy.functions import im, re
import numpy as np
import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type
X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2, T = sympy.symbols('S1, S2, T')
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
AssignmentCollection({
T: 2 + 4j,
Y.center: X.center / T,
})
]
SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
ast = pystencils.create_kernel(assignment,
target=target,
data_type=scalar_dtypes)
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" not in code
kernel = ast.compile()
assert kernel is not None
X, Y = pystencils.fields('x, y: complex128[2d]')
A, B = pystencils.fields('a, b: float64[2d]')
S1, S2 = sympy.symbols('S1, S2')
T128 = TypedSymbol('ts', create_type('complex128'))
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
AssignmentCollection({
T128: 2 + 4j,
Y.center: X.center / T128,
})
]
SCALAR_DTYPES = [ 'float64']
@pytest.mark.parametrize("assignment",TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
ast = pystencils.create_kernel(assignment,
target=target,
data_type='double')
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" not in code
kernel = ast.compile()
assert kernel is not None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment