From be1a46b65733f3c6490ce6067c8fee7655babba3 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 14 Jan 2024 18:56:05 +0100
Subject: [PATCH] implement C-style integer division and mod

---
 pystencils/nbackend/typed_expressions.py      | 23 ++++++++++++++++---
 pystencils/nbackend/types/basic_types.py      |  2 ++
 .../nbackend/types/test_constants.py          | 17 ++++++++++----
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py
index f9497ce89..4fe705615 100644
--- a/pystencils/nbackend/typed_expressions.py
+++ b/pystencils/nbackend/typed_expressions.py
@@ -229,6 +229,14 @@ class PsTypedConstant:
 
     def __rsub__(self, other: Any):
         return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
+    
+    @staticmethod
+    def _divrem(dividend, divisor):
+        quotient =  abs(dividend) // abs(divisor)
+        quotient = quotient if (dividend * divisor > 0) else (- quotient)
+        rem = abs(dividend) % abs(divisor)
+        rem = rem if dividend >= 0 else (- rem)
+        return quotient, rem
 
     def __truediv__(self, other: Any):
         if self._dtype.is_float():
@@ -237,7 +245,10 @@ class PsTypedConstant:
             #   For unsigned integers, `//` does the correct thing
             return PsTypedConstant(self._value // self._fix(other)._value, self._dtype)
         elif self._dtype.is_sint():
-            return NotImplemented  # todo: C integer division
+            dividend = self._value
+            divisor = self._fix(other)._value
+            quotient, _ = self._divrem(dividend, divisor)
+            return PsTypedConstant(quotient, self._dtype)
         else:
             return NotImplemented
 
@@ -247,7 +258,10 @@ class PsTypedConstant:
         elif self._dtype.is_uint():
             return PsTypedConstant(self._rfix(other)._value // self._value, self._dtype)
         elif self._dtype.is_sint():
-            return NotImplemented  # todo: C integer division
+            dividend = self._fix(other)._value
+            divisor = self._value
+            quotient, _ = self._divrem(dividend, divisor)
+            return PsTypedConstant(quotient, self._dtype)
         else:
             return NotImplemented
 
@@ -255,7 +269,10 @@ class PsTypedConstant:
         if self._dtype.is_uint():
             return PsTypedConstant(self._value % self._fix(other)._value, self._dtype)
         else:
-            return NotImplemented  # todo: C integer division
+            dividend = self._value
+            divisor = self._fix(other)._value
+            _, rem = self._divrem(dividend, divisor)
+            return PsTypedConstant(rem, self._dtype)
 
     def __neg__(self):
         return PsTypedConstant(-self._value, self._dtype)
diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py
index ada5bdd47..412189968 100644
--- a/pystencils/nbackend/types/basic_types.py
+++ b/pystencils/nbackend/types/basic_types.py
@@ -16,6 +16,8 @@ class PsAbstractType(ABC):
 
     **Type Equality:** Subclasses must implement `__eq__`, but may rely on `_base_equal` to implement
     type equality checks.
+
+    **Hashing:** Each subclass that implements `__eq__` must also implement `__hash__`.
     """
 
     def __init__(self, const: bool = False):
diff --git a/pystencils_tests/nbackend/types/test_constants.py b/pystencils_tests/nbackend/types/test_constants.py
index 222db6043..4eaf4060a 100644
--- a/pystencils_tests/nbackend/types/test_constants.py
+++ b/pystencils_tests/nbackend/types/test_constants.py
@@ -63,8 +63,17 @@ def test_unsigned_integer_division(width):
 
 @pytest.mark.parametrize("width", (8, 16, 32, 64))
 def test_signed_integer_division(width):
-    a = PsTypedConstant(-5, SInt(width))
-    b = PsTypedConstant(2, SInt(width))
+    five = PsTypedConstant(5, SInt(width))
+    two = PsTypedConstant(2, SInt(width))
 
-    assert a / b == PsTypedConstant(-2, SInt(width))
-    assert a % b == PsTypedConstant(-1, SInt(width))
+    assert five / two == PsTypedConstant(2, SInt(width))
+    assert five % two == PsTypedConstant(1, SInt(width))
+
+    assert (- five) / two == PsTypedConstant(-2, SInt(width))
+    assert (- five) % two == PsTypedConstant(-1, SInt(width))
+
+    assert five / (- two) == PsTypedConstant(-2, SInt(width))
+    assert five % (- two) == PsTypedConstant(1, SInt(width))
+
+    assert (- five) / (- two) == PsTypedConstant(2, SInt(width))
+    assert (- five) % (- two) == PsTypedConstant(-1, SInt(width))
-- 
GitLab