Skip to content
Snippets Groups Projects
Commit 94e18cf3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some refactoring of PsTypedConstant and tests

parent fe808e6b
No related merge requests found
Pipeline #59880 failed with stages
in 2 minutes and 49 seconds
from __future__ import annotations from __future__ import annotations
from functools import reduce from functools import reduce
from typing import TypeAlias, Union, Any, Tuple, Callable from typing import TypeAlias, Union, Any, Tuple
import operator
import pymbolic.primitives as pb import pymbolic.primitives as pb
...@@ -117,6 +116,24 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] ...@@ -117,6 +116,24 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
class PsTypedConstant: class PsTypedConstant:
"""Represents typed constants occuring in the pystencils AST. """Represents typed constants occuring in the pystencils AST.
Internal Representation of Constants
------------------------------------
Each `PsNumericType` acts as a factory for the code generator's internal representation of that type's
constants. The `PsTypedConstant` class embedds these into the expression trees.
Upon construction, this class's constructor attempts to interpret the given value in the given data type
by passing it to the data type's factory, which in turn may throw an exception if the value's type does
not match.
Operations and Constant Folding
-------------------------------
The `PsTypedConstant` class overrides the basic arithmetic operations for use during a constant folding pass.
Their implementations are very strict regarding types: No implicit conversions take place, and both operands
must always have the exact same type.
The only exception to this rule are the values `0`, `1`, and `-1`, which are promoted to `PsTypedConstant`
(pymbolic injects those at times).
A Note On Divisions A Note On Divisions
------------------- -------------------
...@@ -171,77 +188,74 @@ class PsTypedConstant: ...@@ -171,77 +188,74 @@ class PsTypedConstant:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )" return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
@staticmethod def _fix(self, v: Any) -> PsTypedConstant:
def _fix(v: Any, dtype: PsNumericType) -> PsTypedConstant: """In binary operations, checks for type equality and, if necessary, promotes the values
if not isinstance(v, PsTypedConstant): `0`, `1` and `-1` to `PsTypedConstant`."""
return PsTypedConstant(v, dtype) if not isinstance(v, PsTypedConstant) and v in (0, 1, -1):
else: return PsTypedConstant(v, self._dtype)
return v elif v._dtype != self._dtype:
@staticmethod
def _bin_op(
lhs: PsTypedConstant, rhs: PsTypedConstant, op: Callable[[Any, Any], Any]
) -> PsTypedConstant:
"""Backend for binary operators. Never call directly!"""
if lhs._dtype != rhs._dtype:
raise PsTypeError( raise PsTypeError(
f"Incompatible operand types in constant folding: {lhs._dtype} and {rhs._dtype}" f"Incompatible operand types in constant folding: {self._dtype} and {v._dtype}"
) )
else:
return v
try: def _rfix(self, v: Any) -> PsTypedConstant:
return PsTypedConstant(op(lhs._value, rhs._value), lhs._dtype) """Same as `_fix`, but for use with the `r...` versions of the binary ops. Only changes the order of the
except PsTypeError: types in the exception string."""
if not isinstance(v, PsTypedConstant) and v in (0, 1, -1):
return PsTypedConstant(v, self._dtype)
elif v._dtype != self._dtype:
raise PsTypeError( raise PsTypeError(
f"Invalid operation in constant folding: {op.__name__}( {repr(lhs)}, {repr(rhs)} )" f"Incompatible operand types in constant folding: {v._dtype} and {self._dtype}"
) )
else:
return v
def __add__(self, other: Any): def __add__(self, other: Any):
return PsTypedConstant._bin_op( return PsTypedConstant(self._value + self._fix(other)._value, self._dtype)
self, PsTypedConstant._fix(other, self._dtype), operator.add
)
def __radd__(self, other: Any): def __radd__(self, other: Any):
return PsTypedConstant._bin_op( return PsTypedConstant(self._rfix(other)._value + self._value, self._dtype)
PsTypedConstant._fix(other, self._dtype), self, operator.add
)
def __mul__(self, other: Any): def __mul__(self, other: Any):
return PsTypedConstant._bin_op( return PsTypedConstant(self._value * self._fix(other)._value, self._dtype)
self, PsTypedConstant._fix(other, self._dtype), operator.mul
)
def __rmul__(self, other: Any): def __rmul__(self, other: Any):
return PsTypedConstant._bin_op( return PsTypedConstant(self._rfix(other)._value * self._value, self._dtype)
PsTypedConstant._fix(other, self._dtype), self, operator.mul
)
def __sub__(self, other: Any): def __sub__(self, other: Any):
return PsTypedConstant._bin_op( return PsTypedConstant(self._value - self._fix(other)._value, self._dtype)
self, PsTypedConstant._fix(other, self._dtype), operator.sub
)
def __rsub__(self, other: Any): def __rsub__(self, other: Any):
return PsTypedConstant._bin_op( return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
PsTypedConstant._fix(other, self._dtype), self, operator.sub
)
def __truediv__(self, other: Any): def __truediv__(self, other: Any):
other2 = PsTypedConstant._fix(other, self._dtype)
if self._dtype.is_float(): if self._dtype.is_float():
return PsTypedConstant._bin_op(self, other2, operator.truediv) return PsTypedConstant(self._value / self._fix(other)._value, self._dtype)
else: elif self._dtype.is_uint():
# For unsigned integers, `//` does the correct thing
return PsTypedConstant(self._value // self._fix(other)._value, self._dtype)
elif self._dtype.is_sint():
return NotImplemented # todo: C integer division return NotImplemented # todo: C integer division
else:
return NotImplemented
def __rtruediv__(self, other: Any): def __rtruediv__(self, other: Any):
other2 = PsTypedConstant._fix(other, self._dtype)
if self._dtype.is_float(): if self._dtype.is_float():
return PsTypedConstant._bin_op(other2, self, operator.truediv) return PsTypedConstant(self._rfix(other)._value / self._value, self._dtype)
else: elif self._dtype.is_uint():
return PsTypedConstant(self._rfix(other)._value // self._value, self._dtype)
elif self._dtype.is_sint():
return NotImplemented # todo: C integer division return NotImplemented # todo: C integer division
else:
return NotImplemented
def __mod__(self, other: Any): def __mod__(self, other: Any):
return NotImplemented # todo: C integer division if self._dtype.is_uint():
return PsTypedConstant(self._value % self._fix(other)._value, self._dtype)
else:
return NotImplemented # todo: C integer division
def __neg__(self): def __neg__(self):
return PsTypedConstant(-self._value, self._dtype) return PsTypedConstant(-self._value, self._dtype)
......
...@@ -27,6 +27,14 @@ def test_constant_folding_int(width): ...@@ -27,6 +27,14 @@ def test_constant_folding_int(width):
assert folder(expr) == PsTypedConstant(-53, SInt(width)) assert folder(expr) == PsTypedConstant(-53, SInt(width))
@pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_constant_folding_product(width):
"""
The pymbolic constant folder shows inconsistent behaviour when folding products.
This test both describes the required behaviour and serves as a reminder to fix it.
"""
folder = ConstantFoldingMapper()
expr = pb.Product( expr = pb.Product(
( (
PsTypedConstant(2, SInt(width)), PsTypedConstant(2, SInt(width)),
......
...@@ -32,7 +32,7 @@ def test_float_constants(width): ...@@ -32,7 +32,7 @@ def test_float_constants(width):
assert a - b == PsTypedConstant(31.5, Fp(width)) assert a - b == PsTypedConstant(31.5, Fp(width))
assert a / c == PsTypedConstant(16.0, Fp(width)) assert a / c == PsTypedConstant(16.0, Fp(width))
def test_illegal_ops(): def test_illegal_ops():
# Cannot interpret negative numbers as unsigned types # Cannot interpret negative numbers as unsigned types
with pytest.raises(PsTypeError): with pytest.raises(PsTypeError):
...@@ -53,7 +53,16 @@ def test_illegal_ops(): ...@@ -53,7 +53,16 @@ def test_illegal_ops():
@pytest.mark.parametrize("width", (8, 16, 32, 64)) @pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_integer_division(width): def test_unsigned_integer_division(width):
a = PsTypedConstant(8, UInt(width))
b = PsTypedConstant(3, UInt(width))
assert a / b == PsTypedConstant(2, UInt(width))
assert a % b == PsTypedConstant(2, UInt(width))
@pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_signed_integer_division(width):
a = PsTypedConstant(-5, SInt(width)) a = PsTypedConstant(-5, SInt(width))
b = PsTypedConstant(2, SInt(width)) b = PsTypedConstant(2, SInt(width))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment