From c78e6088ed55249ae77678472accfd7a4300ad45 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Wed, 22 May 2024 11:01:34 +0200
Subject: [PATCH] Cast integer literals to target type

---
 src/pystencils/types/types.py     | 10 ++++++----
 tests/nbackend/test_extensions.py |  6 +++---
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index 8e51f9397..ae0a8829d 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -483,8 +483,7 @@ class PsIntegerType(PsScalarType, ABC):
         if not isinstance(value, np_dtype):
             raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
         unsigned_suffix = "" if self.signed else "u"
-        #   TODO: cast literal to correct type?
-        return str(value) + unsigned_suffix
+        return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
 
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
@@ -499,9 +498,12 @@ class PsIntegerType(PsScalarType, ABC):
 
         raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
 
-    def c_string(self) -> str:
+    def _c_type_without_const(self) -> str:
         prefix = "" if self._signed else "u"
-        return f"{self._const_string()}{prefix}int{self._width}_t"
+        return f"{prefix}int{self._width}_t"
+
+    def c_string(self) -> str:
+        return f"{self._const_string()}{self._c_type_without_const()}"
 
     def __repr__(self) -> str:
         return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )"
diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py
index 75726a351..8d600ef76 100644
--- a/tests/nbackend/test_extensions.py
+++ b/tests/nbackend/test_extensions.py
@@ -54,6 +54,6 @@ def test_literals():
     print(code)
 
     assert "const double x = C;" in code
-    assert "CELLS[0]" in code
-    assert "CELLS[1]" in code
-    assert "CELLS[2]" in code
+    assert "CELLS[((int64_t) 0)]" in code
+    assert "CELLS[((int64_t) 1)]" in code
+    assert "CELLS[((int64_t) 2)]" in code
-- 
GitLab