From d5c24d0fabea69e1b59ee1d64f17244a87da2f80 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 16 Aug 2019 10:59:13 +0200
Subject: [PATCH] Fix: printing of double/float constants

- for example: sqrt(cast_func(3, double)) previously printed "sqrt(3)"
  instead of "sqrt(3.0)" leading to errors in CUDA code
---
 pystencils/backends/cbackend.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index f8b7e6d..2e27374 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -1,6 +1,7 @@
 from collections import namedtuple
 from typing import Set
 
+import numpy as np
 import sympy as sp
 from sympy.core import S
 from sympy.printing.ccode import C89CodePrinter
@@ -360,13 +361,10 @@ class CustomSympyPrinter(CCodePrinter):
 
     def _typed_number(self, number, dtype):
         res = self._print(number)
-        if dtype.is_float():
-            if dtype == self._float_type:
-                if '.' not in res:
-                    res += ".0f"
-                else:
-                    res += "f"
-            return res
+        if dtype.numpy_dtype == np.float32:
+            return res + '.0f' if '.' not in res else res + 'f'
+        elif dtype.numpy_dtype == np.float64:
+            return res + '.0' if '.' not in res else res
         else:
             return res
 
-- 
GitLab