Skip to content
Snippets Groups Projects
Commit d5c24d0f authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix: printing of double/float constants

- for example: sqrt(cast_func(3, double)) previously printed "sqrt(3)"
  instead of "sqrt(3.0)" leading to errors in CUDA code
parent 440866e2
Branches
Tags
No related merge requests found
from collections import namedtuple from collections import namedtuple
from typing import Set from typing import Set
import numpy as np
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
...@@ -360,13 +361,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -360,13 +361,10 @@ class CustomSympyPrinter(CCodePrinter):
def _typed_number(self, number, dtype): def _typed_number(self, number, dtype):
res = self._print(number) res = self._print(number)
if dtype.is_float(): if dtype.numpy_dtype == np.float32:
if dtype == self._float_type: return res + '.0f' if '.' not in res else res + 'f'
if '.' not in res: elif dtype.numpy_dtype == np.float64:
res += ".0f" return res + '.0' if '.' not in res else res
else:
res += "f"
return res
else: else:
return res return res
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment