Commit f9b8ee6e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add complex number support / headers support for sp.Expr

parent 9f76ea1d
......@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs
......@@ -555,6 +555,7 @@ class SympyAssignment(Node):
if isinstance(symbol, Field.Access):
for i in range(len(symbol.offsets)):
loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
result.update(loop_counters)
result.update(self._lhs_symbol.atoms(sp.Symbol))
return result
......
......@@ -76,8 +76,8 @@ def get_global_declarations(ast):
global_declarations = []
def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"):
......@@ -99,7 +99,7 @@ def get_headers(ast_node: Node) -> Set[str]:
if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers)
for a in ast_node.args:
if isinstance(a, Node):
if isinstance(a, (sp.Expr, Node)):
headers.update(get_headers(a))
for g in get_global_declarations(ast_node):
......@@ -230,7 +230,8 @@ class CBackend:
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
lhs_type = get_type_of_expression(node.lhs)
......@@ -432,6 +433,27 @@ class CustomSympyPrinter(CCodePrinter):
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
def _print_re(self, expr):
return f"real({self._print(expr.args[0])})"
def _print_im(self, expr):
return f"imag({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr):
return "complex<double>{0,1}"
def _print_TypedImaginaryUnit(self, expr):
if expr.dtype.numpy_dtype == np.complex64:
return "complex<float>{0,1}"
elif expr.dtype.numpy_dtype == np.complex128:
return "complex<double>{0,1}"
else:
raise NotImplementedError(
"only complex64 and complex128 supported")
def _print_Complex(self, expr):
return self._typed_number(expr, np.complex64)
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
......@@ -244,6 +244,22 @@ class TypedSymbol(sp.Symbol):
def reversed(self):
return self
@property
def headers(self):
headers = []
try:
if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
try:
if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
headers.append('"cuda_complex.hpp"')
except Exception:
pass
return headers
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......@@ -414,16 +430,27 @@ def peel_off_type(dtype, type_to_peel_off):
return dtype
def collate_types(types, forbid_collation_to_float=False):
def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_float=False):
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
"""
if forbid_collation_to_complex:
types = [
t for t in types
if not np.issubdtype(t.numpy_dtype, np.complexfloating)
]
if not types:
return create_type(np.float64)
if forbid_collation_to_float:
types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
types = [
t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
]
if not types:
return create_type('int32')
return create_type(np.int32)
# Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types):
......@@ -478,6 +505,8 @@ def get_type_of_expression(expr,
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
elif expr.is_real is False:
return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
......@@ -504,7 +533,7 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label
return typed_symbol.dtype.base_type
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
......@@ -517,7 +546,10 @@ def get_type_of_expression(expr,
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
return collate_types(
types,
forbid_collation_to_complex=expr.is_real is True,
forbid_collation_to_float=expr.is_integer is True)
else:
if expr.is_integer:
return create_type(default_int_type)
......@@ -544,6 +576,10 @@ class BasicType(Type):
return 'double'
elif name == 'float32':
return 'float'
elif name == 'complex64':
return 'ComplexFloat'
elif name == 'complex128':
return 'ComplexDouble'
elif name.startswith('int'):
width = int(name[len("int"):])
return "int%d_t" % (width,)
......@@ -755,3 +791,23 @@ class StructType:
def __hash__(self):
return hash((self.numpy_dtype, self.const))
class TypedImaginaryUnit(TypedSymbol):
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(cls, dtype, *args, **kwargs):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
return obj
headers = ['"cuda_complex.hpp"']
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
// An implementation of C++ std::complex for use on CUDA devices.
// Written by John C. Travers <jtravs@gmail.com> (2012)
//
// Missing:
// - long double support (not supported on CUDA)
// - some integral pow functions (due to lack of C++11 support on CUDA)
//
// Heavily derived from the LLVM libcpp project (svn revision 147853).
// Based on libcxx/include/complex.
// The git history contains the complete change history from the original.
// The modifications are licensed as per the original LLVM license below.
//
// -*- C++ -*-
//===--------------------------- complex ----------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is dual licensed under the MIT and the University of Illinois Open
// Source Licenses. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
extern "C++" {
#ifndef CUDA_COMPLEX_HPP
#define CUDA_COMPLEX_HPP
#ifdef __CUDACC__
#define CUDA_CALLABLE_MEMBER __host__ __device__
#else
#define CUDA_CALLABLE_MEMBER
#endif
/*
complex synopsis
template<class T>
class complex
{
public:
typedef T value_type;
complex(const T& re = T(), const T& im = T());
complex(const complex&);
template<class X> complex(const complex<X>&);
T real() const;
T imag() const;
void real(T);
void imag(T);
complex<T>& operator= (const T&);
complex<T>& operator+=(const T&);
complex<T>& operator-=(const T&);
complex<T>& operator*=(const T&);
complex<T>& operator/=(const T&);
complex& operator=(const complex&);
template<class X> complex<T>& operator= (const complex<X>&);
template<class X> complex<T>& operator+=(const complex<X>&);
template<class X> complex<T>& operator-=(const complex<X>&);
template<class X> complex<T>& operator*=(const complex<X>&);
template<class X> complex<T>& operator/=(const complex<X>&);
};
template<>
class complex<float>
{
public:
typedef float value_type;
constexpr complex(float re = 0.0f, float im = 0.0f);
explicit constexpr complex(const complex<double>&);
constexpr float real() const;
void real(float);
constexpr float imag() const;
void imag(float);
complex<float>& operator= (float);
complex<float>& operator+=(float);
complex<float>& operator-=(float);
complex<float>& operator*=(float);
complex<float>& operator/=(float);
complex<float>& operator=(const complex<float>&);
template<class X> complex<float>& operator= (const complex<X>&);
template<class X> complex<float>& operator+=(const complex<X>&);
template<class X> complex<float>& operator-=(const complex<X>&);
template<class X> complex<float>& operator*=(const complex<X>&);
template<class X> complex<float>& operator/=(const complex<X>&);
};
template<>
class complex<double>
{
public:
typedef double value_type;
constexpr complex(double re = 0.0, double im = 0.0);
constexpr complex(const complex<float>&);
constexpr double real() const;
void real(double);
constexpr double imag() const;
void imag(double);
complex<double>& operator= (double);
complex<double>& operator+=(double);
complex<double>& operator-=(double);
complex<double>& operator*=(double);
complex<double>& operator/=(double);
complex<double>& operator=(const complex<double>&);
template<class X> complex<double>& operator= (const complex<X>&);
template<class X> complex<double>& operator+=(const complex<X>&);
template<class X> complex<double>& operator-=(const complex<X>&);
template<class X> complex<double>& operator*=(const complex<X>&);
template<class X> complex<double>& operator/=(const complex<X>&);
};
// 26.3.6 operators:
template<class T> complex<T> operator+(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator+(const complex<T>&, const T&);
template<class T> complex<T> operator+(const T&, const complex<T>&);
template<class T> complex<T> operator-(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator-(const complex<T>&, const T&);
template<class T> complex<T> operator-(const T&, const complex<T>&);
template<class T> complex<T> operator*(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator*(const complex<T>&, const T&);
template<class T> complex<T> operator*(const T&, const complex<T>&);
template<class T> complex<T> operator/(const complex<T>&, const complex<T>&);
template<class T> complex<T> operator/(const complex<T>&, const T&);
template<class T> complex<T> operator/(const T&, const complex<T>&);
template<class T> complex<T> operator+(const complex<T>&);
template<class T> complex<T> operator-(const complex<T>&);
template<class T> bool operator==(const complex<T>&, const complex<T>&);
template<class T> bool operator==(const complex<T>&, const T&);
template<class T> bool operator==(const T&, const complex<T>&);
template<class T> bool operator!=(const complex<T>&, const complex<T>&);
template<class T> bool operator!=(const complex<T>&, const T&);
template<class T> bool operator!=(const T&, const complex<T>&);
template<class T, class charT, class traits>
basic_istream<charT, traits>&
operator>>(basic_istream<charT, traits>&, complex<T>&);
template<class T, class charT, class traits>
basic_ostream<charT, traits>&
operator<<(basic_ostream<charT, traits>&, const complex<T>&);
// 26.3.7 values:
template<class T> T real(const complex<T>&);
double real(double);
template<Integral T> double real(T);
float real(float);
template<class T> T imag(const complex<T>&);
double imag(double);
template<Integral T> double imag(T);
float imag(float);
template<class T> T abs(const complex<T>&);
template<class T> T arg(const complex<T>&);
double arg(double);
template<Integral T> double arg(T);
float arg(float);
template<class T> T norm(const complex<T>&);
double norm(double);
template<Integral T> double norm(T);
float norm(float);
template<class T> complex<T> conj(const complex<T>&);
complex<double> conj(double);
template<Integral T> complex<double> conj(T);
complex<float> conj(float);
template<class T> complex<T> proj(const complex<T>&);
complex<double> proj(double);
template<Integral T> complex<double> proj(T);
complex<float> proj(float);
template<class T> complex<T> polar(const T&, const T& = 0);
// 26.3.8 transcendentals:
template<class T> complex<T> acos(const complex<T>&);
template<class T> complex<T> asin(const complex<T>&);
template<class T> complex<T> atan(const complex<T>&);
template<class T> complex<T> acosh(const complex<T>&);
template<class T> complex<T> asinh(const complex<T>&);
template<class T> complex<T> atanh(const complex<T>&);
template<class T> complex<T> cos (const complex<T>&);
template<class T> complex<T> cosh (const complex<T>&);
template<class T> complex<T> exp (const complex<T>&);
template<class T> complex<T> log (const complex<T>&);
template<class T> complex<T> log10(const complex<T>&);
template<class T> complex<T> pow(const complex<T>&, const T&);
template<class T> complex<T> pow(const complex<T>&, const complex<T>&);
template<class T> complex<T> pow(const T&, const complex<T>&);
template<class T> complex<T> sin (const complex<T>&);
template<class T> complex<T> sinh (const complex<T>&);
template<class T> complex<T> sqrt (const complex<T>&);
template<class T> complex<T> tan (const complex<T>&);
template<class T> complex<T> tanh (const complex<T>&);
template<class T, class charT, class traits>
basic_istream<charT, traits>&
operator>>(basic_istream<charT, traits>& is, complex<T>& x);
template<class T, class charT, class traits>
basic_ostream<charT, traits>&
operator<<(basic_ostream<charT, traits>& o, const complex<T>& x);
*/
#include <math.h>
#include <sstream>
template <class _Tp> class complex;
template <class _Tp>
complex<_Tp> operator*(const complex<_Tp> &__z, const complex<_Tp> &__w);
template <class _Tp>
complex<_Tp> operator/(const complex<_Tp> &__x, const complex<_Tp> &__y);
template <class _Tp> class complex {
public:
typedef _Tp value_type;
private:
value_type __re_;
value_type __im_;
public:
CUDA_CALLABLE_MEMBER
complex(const value_type &__re = value_type(),
const value_type &__im = value_type())
: __re_(__re), __im_(__im) {}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex(const complex<_Xp> &__c)
: __re_(__c.real()), __im_(__c.imag()) {}
CUDA_CALLABLE_MEMBER value_type real() const { return __re_; }
CUDA_CALLABLE_MEMBER value_type imag() const { return __im_; }
CUDA_CALLABLE_MEMBER void real(value_type __re) { __re_ = __re; }
CUDA_CALLABLE_MEMBER void imag(value_type __im) { __im_ = __im; }
CUDA_CALLABLE_MEMBER complex &operator=(const value_type &__re) {
__re_ = __re;
__im_ = value_type();
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator+=(const value_type &__re) {
__re_ += __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator-=(const value_type &__re) {
__re_ -= __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator*=(const value_type &__re) {
__re_ *= __re;
__im_ *= __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator/=(const value_type &__re) {
__re_ /= __re;
__im_ /= __re;
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator=(const complex<_Xp> &__c) {
__re_ = __c.real();
__im_ = __c.imag();
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator+=(const complex<_Xp> &__c) {
__re_ += __c.real();
__im_ += __c.imag();
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator-=(const complex<_Xp> &__c) {
__re_ -= __c.real();
__im_ -= __c.imag();
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator*=(const complex<_Xp> &__c) {
*this = *this * __c;
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator/=(const complex<_Xp> &__c) {
*this = *this / __c;
return *this;
}
};
template <> class complex<double>;
template <> class complex<float> {
float __re_;
float __im_;
public:
typedef float value_type;
/*constexpr*/ CUDA_CALLABLE_MEMBER complex(float __re = 0.0f,
float __im = 0.0f)
: __re_(__re), __im_(__im) {}
explicit /*constexpr*/ complex(const complex<double> &__c);
/*constexpr*/ CUDA_CALLABLE_MEMBER float real() const { return __re_; }
/*constexpr*/ CUDA_CALLABLE_MEMBER float imag() const { return __im_; }
CUDA_CALLABLE_MEMBER void real(value_type __re) { __re_ = __re; }
CUDA_CALLABLE_MEMBER void imag(value_type __im) { __im_ = __im; }
CUDA_CALLABLE_MEMBER complex &operator=(float __re) {
__re_ = __re;
__im_ = value_type();
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator+=(float __re) {
__re_ += __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator-=(float __re) {
__re_ -= __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator*=(float __re) {
__re_ *= __re;
__im_ *= __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator/=(float __re) {
__re_ /= __re;
__im_ /= __re;
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator=(const complex<_Xp> &__c) {
__re_ = __c.real();
__im_ = __c.imag();
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator+=(const complex<_Xp> &__c) {
__re_ += __c.real();
__im_ += __c.imag();
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator-=(const complex<_Xp> &__c) {
__re_ -= __c.real();
__im_ -= __c.imag();
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator*=(const complex<_Xp> &__c) {
*this = *this * __c;
return *this;
}
template <class _Xp>
CUDA_CALLABLE_MEMBER complex &operator/=(const complex<_Xp> &__c) {
*this = *this / __c;
return *this;
}
};
template <> class complex<double> {
double __re_;
double __im_;
public:
typedef double value_type;
/*constexpr*/ CUDA_CALLABLE_MEMBER complex(double __re = 0.0,
double __im = 0.0)
: __re_(__re), __im_(__im) {}
/*constexpr*/ complex(const complex<float> &__c);
/*constexpr*/ CUDA_CALLABLE_MEMBER double real() const { return __re_; }
/*constexpr*/ CUDA_CALLABLE_MEMBER double imag() const { return __im_; }
CUDA_CALLABLE_MEMBER void real(value_type __re) { __re_ = __re; }
CUDA_CALLABLE_MEMBER void imag(value_type __im) { __im_ = __im; }
CUDA_CALLABLE_MEMBER complex &operator=(double __re) {
__re_ = __re;
__im_ = value_type();
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator+=(double __re) {
__re_ += __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator-=(double __re) {
__re_ -= __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator*=(double __re) {
__re_ *= __re;
__im_ *= __re;
return *this;
}
CUDA_CALLABLE_MEMBER complex &operator/=(double __re) {
__re_ /= __re;
__im_ /= __re;
return *this;
}
template <class _Xp>