diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index c3e4488e8c4239832b925dfd4ef102ffefb637ea..f9497ce898e452383e6e24fb1c1ca28920173bd1 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -1,8 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TypeAlias, Union, Any, Tuple, Callable -import operator +from typing import TypeAlias, Union, Any, Tuple import pymbolic.primitives as pb @@ -117,6 +116,24 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] class PsTypedConstant: """Represents typed constants occuring in the pystencils AST. + Internal Representation of Constants + ------------------------------------ + + Each `PsNumericType` acts as a factory for the code generator's internal representation of that type's + constants. The `PsTypedConstant` class embedds these into the expression trees. + Upon construction, this class's constructor attempts to interpret the given value in the given data type + by passing it to the data type's factory, which in turn may throw an exception if the value's type does + not match. + + Operations and Constant Folding + ------------------------------- + + The `PsTypedConstant` class overrides the basic arithmetic operations for use during a constant folding pass. + Their implementations are very strict regarding types: No implicit conversions take place, and both operands + must always have the exact same type. + The only exception to this rule are the values `0`, `1`, and `-1`, which are promoted to `PsTypedConstant` + (pymbolic injects those at times). + A Note On Divisions ------------------- @@ -171,77 +188,74 @@ class PsTypedConstant: def __repr__(self) -> str: return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )" - @staticmethod - def _fix(v: Any, dtype: PsNumericType) -> PsTypedConstant: - if not isinstance(v, PsTypedConstant): - return PsTypedConstant(v, dtype) - else: - return v - - @staticmethod - def _bin_op( - lhs: PsTypedConstant, rhs: PsTypedConstant, op: Callable[[Any, Any], Any] - ) -> PsTypedConstant: - """Backend for binary operators. Never call directly!""" - - if lhs._dtype != rhs._dtype: + def _fix(self, v: Any) -> PsTypedConstant: + """In binary operations, checks for type equality and, if necessary, promotes the values + `0`, `1` and `-1` to `PsTypedConstant`.""" + if not isinstance(v, PsTypedConstant) and v in (0, 1, -1): + return PsTypedConstant(v, self._dtype) + elif v._dtype != self._dtype: raise PsTypeError( - f"Incompatible operand types in constant folding: {lhs._dtype} and {rhs._dtype}" + f"Incompatible operand types in constant folding: {self._dtype} and {v._dtype}" ) + else: + return v - try: - return PsTypedConstant(op(lhs._value, rhs._value), lhs._dtype) - except PsTypeError: + def _rfix(self, v: Any) -> PsTypedConstant: + """Same as `_fix`, but for use with the `r...` versions of the binary ops. Only changes the order of the + types in the exception string.""" + if not isinstance(v, PsTypedConstant) and v in (0, 1, -1): + return PsTypedConstant(v, self._dtype) + elif v._dtype != self._dtype: raise PsTypeError( - f"Invalid operation in constant folding: {op.__name__}( {repr(lhs)}, {repr(rhs)} )" + f"Incompatible operand types in constant folding: {v._dtype} and {self._dtype}" ) + else: + return v def __add__(self, other: Any): - return PsTypedConstant._bin_op( - self, PsTypedConstant._fix(other, self._dtype), operator.add - ) + return PsTypedConstant(self._value + self._fix(other)._value, self._dtype) def __radd__(self, other: Any): - return PsTypedConstant._bin_op( - PsTypedConstant._fix(other, self._dtype), self, operator.add - ) + return PsTypedConstant(self._rfix(other)._value + self._value, self._dtype) def __mul__(self, other: Any): - return PsTypedConstant._bin_op( - self, PsTypedConstant._fix(other, self._dtype), operator.mul - ) + return PsTypedConstant(self._value * self._fix(other)._value, self._dtype) def __rmul__(self, other: Any): - return PsTypedConstant._bin_op( - PsTypedConstant._fix(other, self._dtype), self, operator.mul - ) + return PsTypedConstant(self._rfix(other)._value * self._value, self._dtype) def __sub__(self, other: Any): - return PsTypedConstant._bin_op( - self, PsTypedConstant._fix(other, self._dtype), operator.sub - ) + return PsTypedConstant(self._value - self._fix(other)._value, self._dtype) def __rsub__(self, other: Any): - return PsTypedConstant._bin_op( - PsTypedConstant._fix(other, self._dtype), self, operator.sub - ) + return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype) def __truediv__(self, other: Any): - other2 = PsTypedConstant._fix(other, self._dtype) if self._dtype.is_float(): - return PsTypedConstant._bin_op(self, other2, operator.truediv) - else: + return PsTypedConstant(self._value / self._fix(other)._value, self._dtype) + elif self._dtype.is_uint(): + # For unsigned integers, `//` does the correct thing + return PsTypedConstant(self._value // self._fix(other)._value, self._dtype) + elif self._dtype.is_sint(): return NotImplemented # todo: C integer division + else: + return NotImplemented def __rtruediv__(self, other: Any): - other2 = PsTypedConstant._fix(other, self._dtype) if self._dtype.is_float(): - return PsTypedConstant._bin_op(other2, self, operator.truediv) - else: + return PsTypedConstant(self._rfix(other)._value / self._value, self._dtype) + elif self._dtype.is_uint(): + return PsTypedConstant(self._rfix(other)._value // self._value, self._dtype) + elif self._dtype.is_sint(): return NotImplemented # todo: C integer division + else: + return NotImplemented def __mod__(self, other: Any): - return NotImplemented # todo: C integer division + if self._dtype.is_uint(): + return PsTypedConstant(self._value % self._fix(other)._value, self._dtype) + else: + return NotImplemented # todo: C integer division def __neg__(self): return PsTypedConstant(-self._value, self._dtype) diff --git a/pystencils_tests/nbackend/test_constant_folding.py b/pystencils_tests/nbackend/test_constant_folding.py index 7f395aaf7259f98b99fc03bacdff264256645402..12bc15b69cb58181f306755d44c194bcb8cd6fc2 100644 --- a/pystencils_tests/nbackend/test_constant_folding.py +++ b/pystencils_tests/nbackend/test_constant_folding.py @@ -27,6 +27,14 @@ def test_constant_folding_int(width): assert folder(expr) == PsTypedConstant(-53, SInt(width)) +@pytest.mark.parametrize("width", (8, 16, 32, 64)) +def test_constant_folding_product(width): + """ + The pymbolic constant folder shows inconsistent behaviour when folding products. + This test both describes the required behaviour and serves as a reminder to fix it. + """ + folder = ConstantFoldingMapper() + expr = pb.Product( ( PsTypedConstant(2, SInt(width)), diff --git a/pystencils_tests/nbackend/types/test_constants.py b/pystencils_tests/nbackend/types/test_constants.py index f166a157050f9c9582c8a0a42e13af7c1a6a1f36..222db6043b36e3df27b794bb50bbfd9439972afe 100644 --- a/pystencils_tests/nbackend/types/test_constants.py +++ b/pystencils_tests/nbackend/types/test_constants.py @@ -32,7 +32,7 @@ def test_float_constants(width): assert a - b == PsTypedConstant(31.5, Fp(width)) assert a / c == PsTypedConstant(16.0, Fp(width)) - + def test_illegal_ops(): # Cannot interpret negative numbers as unsigned types with pytest.raises(PsTypeError): @@ -53,7 +53,16 @@ def test_illegal_ops(): @pytest.mark.parametrize("width", (8, 16, 32, 64)) -def test_integer_division(width): +def test_unsigned_integer_division(width): + a = PsTypedConstant(8, UInt(width)) + b = PsTypedConstant(3, UInt(width)) + + assert a / b == PsTypedConstant(2, UInt(width)) + assert a % b == PsTypedConstant(2, UInt(width)) + + +@pytest.mark.parametrize("width", (8, 16, 32, 64)) +def test_signed_integer_division(width): a = PsTypedConstant(-5, SInt(width)) b = PsTypedConstant(2, SInt(width))