From d5c24d0fabea69e1b59ee1d64f17244a87da2f80 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 16 Aug 2019 10:59:13 +0200 Subject: [PATCH] Fix: printing of double/float constants - for example: sqrt(cast_func(3, double)) previously printed "sqrt(3)" instead of "sqrt(3.0)" leading to errors in CUDA code --- pystencils/backends/cbackend.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index f8b7e6d..2e27374 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -1,6 +1,7 @@ from collections import namedtuple from typing import Set +import numpy as np import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter @@ -360,13 +361,10 @@ class CustomSympyPrinter(CCodePrinter): def _typed_number(self, number, dtype): res = self._print(number) - if dtype.is_float(): - if dtype == self._float_type: - if '.' not in res: - res += ".0f" - else: - res += "f" - return res + if dtype.numpy_dtype == np.float32: + return res + '.0f' if '.' not in res else res + 'f' + elif dtype.numpy_dtype == np.float64: + return res + '.0' if '.' not in res else res else: return res -- GitLab