Commit 3f5fe292 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'support-complex-numbers' into 'master'

Support complex numbers

See merge request pycodegen/pystencils!72
parents a834955b eda2f772
......@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs
......@@ -569,6 +569,7 @@ class SympyAssignment(Node):
if isinstance(symbol, Field.Access):
for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
result.update(loop_counters)
result.update(self._lhs_symbol.atoms(sp.Symbol))
return result
......
......@@ -80,8 +80,8 @@ def get_global_declarations(ast):
global_declarations = []
def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"):
......@@ -103,7 +103,7 @@ def get_headers(ast_node: Node) -> Set[str]:
if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers)
for a in ast_node.args:
if isinstance(a, Node):
if isinstance(a, (sp.Expr, Node)):
headers.update(get_headers(a))
for g in get_global_declarations(ast_node):
......@@ -234,7 +234,8 @@ class CBackend:
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
......@@ -443,6 +444,27 @@ class CustomSympyPrinter(CCodePrinter):
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
def _print_re(self, expr):
return f"real({self._print(expr.args[0])})"
def _print_im(self, expr):
return f"imag({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr):
return "complex<double>{0,1}"
def _print_TypedImaginaryUnit(self, expr):
if expr.dtype.numpy_dtype == np.complex64:
return "complex<float>{0,1}"
elif expr.dtype.numpy_dtype == np.complex128:
return "complex<double>{0,1}"
else:
raise NotImplementedError(
"only complex64 and complex128 supported")
def _print_Complex(self, expr):
return self._typed_number(expr, np.complex64)
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
......@@ -255,6 +255,8 @@ type_mapping = {
np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
......@@ -265,6 +267,13 @@ if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
......@@ -358,7 +367,8 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize
if np_dtype.isbuiltin and FieldType.is_generic(field):
if (np_dtype.isbuiltin and FieldType.is_generic(field)
and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
format=field.dtype.numpy_dtype.char)
pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
......@@ -395,8 +405,16 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
name=param.symbol.name)
if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
extract_function_imag=extract_function[1],
target_type=target_type,
name=param.symbol.name)
else:
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
target_type=target_type,
name=param.symbol.name)
parameters.append(param.symbol.name)
pre_call_code += equal_size_check(variable_sized_normal_fields)
......
......@@ -4,14 +4,14 @@ from functools import partial
from typing import Tuple
import numpy as np
import sympy as sp
import sympy.codegen.ast
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
import pystencils
import sympy as sp
import sympy.codegen.ast
from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
try:
import llvmlite.ir as ir
......@@ -250,6 +250,22 @@ class TypedSymbol(sp.Symbol):
def reversed(self):
return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
return headers
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......@@ -420,16 +436,29 @@ def peel_off_type(dtype, type_to_peel_off):
return dtype
def collate_types(types, forbid_collation_to_float=False):
def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_float=False,
default_float_type='float64',
default_int_type='int64'):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
if forbid_collation_to_complex:
types = [
t for t in types
if not np.issubdtype(t.numpy_dtype, np.complexfloating)
]
if not types:
return create_type(default_float_type)
if forbid_collation_to_float:
types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
types = [
t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
]
if not types:
return create_type('int32')
return create_type(default_int_type)
# Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types):
......@@ -484,6 +513,8 @@ def get_type_of_expression(expr,
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
elif expr.is_real is False:
return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
......@@ -510,7 +541,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
......@@ -523,7 +554,12 @@ def get_type_of_expression(expr,
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
return collate_types(
types,
forbid_collation_to_complex=expr.is_real is True,
forbid_collation_to_float=expr.is_integer is True,
default_float_type=default_float_type,
default_int_type=default_int_type)
else:
if expr.is_integer:
return create_type(default_int_type)
......@@ -550,6 +586,10 @@ class BasicType(Type):
return 'double'
elif name == 'float32':
return 'float'
elif name == 'complex64':
return 'ComplexFloat'
elif name == 'complex128':
return 'ComplexDouble'
elif name.startswith('int'):
width = int(name[len("int"):])
return "int%d_t" % (width,)
......@@ -761,3 +801,23 @@ class StructType:
def __hash__(self):
return hash((self.numpy_dtype, self.const))
class TypedImaginaryUnit(TypedSymbol):
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, dtype, *args, **kwargs):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
return obj
headers = ['"cuda_complex.hpp"']
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
This diff is collapsed.
......@@ -7,14 +7,15 @@ from types import MappingProxyType
import numpy as np
import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean
import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -830,6 +831,8 @@ class KernelConstraintsCheck:
if new_args:
rhs.offsets = new_args
return rhs
elif isinstance(rhs, ImaginaryUnit):
return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
......@@ -929,7 +932,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
......@@ -1093,6 +1096,10 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
if hasattr(default_type, 'numpy_dtype'):
result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
else:
result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import numpy as np
import pytest
import sympy
from sympy.functions import im, re
import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import TypedSymbol, create_type
X, Y = pystencils.fields('x, y: complex64[2d]')
A, B = pystencils.fields('a, b: float32[2d]')
S1, S2, T = sympy.symbols('S1, S2, T')
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
AssignmentCollection({
T: 2 + 4j,
Y.center: X.center / T,
})
]
SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
itertools.product(TEST_ASSIGNMENTS, (np.float32,)))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers(assignment, scalar_dtypes, target):
ast = pystencils.create_kernel(assignment,
target=target,
data_type=scalar_dtypes)
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" not in code
kernel = ast.compile()
assert kernel is not None
X, Y = pystencils.fields('x, y: complex128[2d]')
A, B = pystencils.fields('a, b: float64[2d]')
S1, S2 = sympy.symbols('S1, S2')
T128 = TypedSymbol('ts', create_type('complex128'))
TEST_ASSIGNMENTS = [
AssignmentCollection({X[0, 0]: 1j}),
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
AssignmentCollection({
T128: 2 + 4j,
Y.center: X.center / T128,
})
]
SCALAR_DTYPES = ['float64']
@pytest.mark.parametrize("assignment", TEST_ASSIGNMENTS)
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
def test_complex_numbers_64(assignment, target):
ast = pystencils.create_kernel(assignment,
target=target,
data_type='double')
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" not in code
kernel = ast.compile()
assert kernel is not None
@pytest.mark.parametrize('dtype', (np.float32, np.float64))
@pytest.mark.parametrize('target', ('cpu', 'gpu'))
@pytest.mark.parametrize('with_complex_argument', ('with_complex_argument', False))
def test_complex_execution(dtype, target, with_complex_argument):
complex_dtype = f'complex{64 if dtype ==np.float32 else 128}'
x, y = pystencils.fields(f'x, y: {complex_dtype}[2d]')
x_arr = np.zeros((20, 30), complex_dtype)
y_arr = np.zeros((20, 30), complex_dtype)
if with_complex_argument:
a = pystencils.TypedSymbol('a', create_type(complex_dtype))
else:
a = (2j+1)
assignments = AssignmentCollection({
y.center: x.center + a
})
if target == 'gpu':
from pycuda.gpuarray import zeros
x_arr = zeros((20, 30), complex_dtype)
y_arr = zeros((20, 30), complex_dtype)
kernel = pystencils.create_kernel(assignments, target=target, data_type=dtype).compile()
if with_complex_argument:
kernel(x=x_arr, y=y_arr, a=2j+1)
else:
kernel(x=x_arr, y=y_arr)
if target == 'gpu':
y_arr = y_arr.get()
assert np.allclose(y_arr, 2j+1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment