From beb226547d066979718db723f9e972861645e253 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Thu, 23 Jan 2020 10:57:52 +0100
Subject: [PATCH] correct type conversion in generated complex code

---
 pystencils/cpu/cpujit.py            |  3 ++-
 pystencils/include/cuda_complex.hpp | 16 ++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 6376ffb85..8e8242065 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -270,7 +270,7 @@ if( PyErr_Occurred() ) {{ return NULL; }}
 template_extract_complex = """
 PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
 if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
-{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
+{target_type} {name}{{ ({real_type}) {extract_function_real}( obj_{name} ), ({real_type}) {extract_function_imag}( obj_{name} ) }};
 if( PyErr_Occurred() ) {{ return NULL; }}
 """
 
@@ -409,6 +409,7 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
                 pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
                                                                  extract_function_imag=extract_function[1],
                                                                  target_type=target_type,
+                                                                 real_type="float" if target_type == "ComplexFloat" else "double",
                                                                  name=param.symbol.name)
             else:
                 pre_call_code += template_extract_scalar.format(extract_function=extract_function,
diff --git a/pystencils/include/cuda_complex.hpp b/pystencils/include/cuda_complex.hpp
index ad555264a..535aa52e3 100644
--- a/pystencils/include/cuda_complex.hpp
+++ b/pystencils/include/cuda_complex.hpp
@@ -1173,53 +1173,53 @@ operator<<(std::basic_ostream<_CharT, _Traits> &__os, const complex<_Tp> &__x) {
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator*(const complex<U> &complexNumber,
                                     const V &scalar) -> complex<U> {
-  return complex<U>{real(complexNumber) * scalar, imag(complexNumber) * scalar};
+  return complex<U>(real(complexNumber) * scalar, imag(complexNumber) * scalar);
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator*(const V &scalar,
                                     const complex<U> &complexNumber)
     -> complex<U> {
-  return complex<U>{real(complexNumber) * scalar, imag(complexNumber) * scalar};
+  return complex<U>(real(complexNumber) * scalar, imag(complexNumber) * scalar);
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator+(const complex<U> &complexNumber,
                                     const V &scalar) -> complex<U> {
-  return complex<U>{real(complexNumber) + scalar, imag(complexNumber)};
+  return complex<U>(real(complexNumber) + scalar, imag(complexNumber));
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator+(const V &scalar,
                                     const complex<U> &complexNumber)
     -> complex<U> {
-  return complex<U>{real(complexNumber) + scalar, imag(complexNumber)};
+  return complex<U>(real(complexNumber) + scalar, imag(complexNumber));
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator-(const complex<U> &complexNumber,
                                     const V &scalar) -> complex<U> {
-  return complex<U>{real(complexNumber) - scalar, imag(complexNumber)};
+  return complex<U>(real(complexNumber) - scalar, imag(complexNumber));
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator-(const V &scalar,
                                     const complex<U> &complexNumber)
     -> complex<U> {
-  return complex<U>{scalar - real(complexNumber), imag(complexNumber)};
+  return complex<U>(scalar - real(complexNumber), imag(complexNumber));
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator/(const complex<U> &complexNumber,
                                     const V scalar) -> complex<U> {
-  return complex<U>{real(complexNumber) / scalar, imag(complexNumber) / scalar};
+  return complex<U>(real(complexNumber) / scalar, imag(complexNumber) / scalar);
 }
 
 template <class U, class V>
 CUDA_CALLABLE_MEMBER auto operator/(const V scalar,
                                     const complex<U> &complexNumber)
     -> complex<U> {
-  return complex<U>{scalar, 0} / complexNumber;
+  return complex<U>(scalar, 0) / complexNumber;
 }
 
 using ComplexDouble = complex<double>;
-- 
GitLab