Skip to content
Snippets Groups Projects
Commit 953a1f93 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Various fixes to constants

parent d2cfa5a0
1 merge request!371Various fixes to constants
from __future__ import annotations
from typing import Any from typing import Any
from ..types import PsNumericType, constify from ..types import PsNumericType, constify
...@@ -5,6 +6,21 @@ from .exceptions import PsInternalCompilerError ...@@ -5,6 +6,21 @@ from .exceptions import PsInternalCompilerError
class PsConstant: class PsConstant:
"""Type-safe representation of typed numerical constants.
This class models constants in the backend representation of kernels.
A constant may be *untyped*, in which case its ``value`` may be any Python object.
If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used
to check the validity of its ``value`` and to convert it into the type's internal representation.
Instances of `PsConstant` are immutable.
Args:
value: The constant's value
dtype: The constant's data type, or ``None`` if untyped.
"""
__match_args__ = ("value", "dtype") __match_args__ = ("value", "dtype")
def __init__(self, value: Any, dtype: PsNumericType | None = None): def __init__(self, value: Any, dtype: PsNumericType | None = None):
...@@ -12,7 +28,30 @@ class PsConstant: ...@@ -12,7 +28,30 @@ class PsConstant:
self._value = value self._value = value
if dtype is not None: if dtype is not None:
self.apply_dtype(dtype) self._dtype = constify(dtype)
self._value = self._dtype.create_constant(self._value)
else:
self._dtype = None
self._value = value
def interpret_as(self, dtype: PsNumericType) -> PsConstant:
"""Interprets this *untyped* constant with the given data type.
If this constant is already typed, raises an error.
"""
if self._dtype is not None:
raise PsInternalCompilerError(
f"Cannot interpret already typed constant {self} with type {dtype}"
)
return PsConstant(self._value, dtype)
def reinterpret_as(self, dtype: PsNumericType) -> PsConstant:
"""Reinterprets this constant with the given data type.
Other than `interpret_as`, this method also works on typed constants.
"""
return PsConstant(self._value, dtype)
@property @property
def value(self) -> Any: def value(self) -> Any:
...@@ -27,15 +66,6 @@ class PsConstant: ...@@ -27,15 +66,6 @@ class PsConstant:
raise PsInternalCompilerError("Data type of constant was not set.") raise PsInternalCompilerError("Data type of constant was not set.")
return self._dtype return self._dtype
def apply_dtype(self, dtype: PsNumericType):
if self._dtype is not None:
raise PsInternalCompilerError(
"Attempt to apply data type to already typed constant."
)
self._dtype = constify(dtype)
self._value = self._dtype.create_constant(self._value)
def __str__(self) -> str: def __str__(self) -> str:
type_str = "<untyped>" if self._dtype is None else str(self._dtype) type_str = "<untyped>" if self._dtype is None else str(self._dtype)
return f"{str(self._value)}: {type_str}" return f"{str(self._value)}: {type_str}"
......
...@@ -127,7 +127,7 @@ class TypeContext: ...@@ -127,7 +127,7 @@ class TypeContext:
f"Can't typify constant with non-numeric type {self._target_type}" f"Can't typify constant with non-numeric type {self._target_type}"
) )
if c.dtype is None: if c.dtype is None:
c.apply_dtype(self._target_type) expr.constant = c.interpret_as(self._target_type)
elif deconstify(c.dtype) != self._target_type: elif deconstify(c.dtype) != self._target_type:
raise TypificationError( raise TypificationError(
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n" f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
......
...@@ -486,9 +486,12 @@ class PsBoolType(PsScalarType): ...@@ -486,9 +486,12 @@ class PsBoolType(PsScalarType):
return np.dtype(PsBoolType.NUMPY_TYPE) return np.dtype(PsBoolType.NUMPY_TYPE)
def create_literal(self, value: Any) -> str: def create_literal(self, value: Any) -> str:
if value in (1, True, np.True_): if not isinstance(value, self.NUMPY_TYPE):
raise PsTypeError(f"Given value {value} is not of required type {self.NUMPY_TYPE}")
if value == np.True_:
return "true" return "true"
elif value in (0, False, np.False_): elif value == np.False_:
return "false" return "false"
else: else:
raise PsTypeError(f"Cannot create boolean literal from {value}") raise PsTypeError(f"Cannot create boolean literal from {value}")
...@@ -560,6 +563,17 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -560,6 +563,17 @@ class PsIntegerType(PsScalarType, ABC):
unsigned_suffix = "" if self.signed else "u" unsigned_suffix = "" if self.signed else "u"
# TODO: cast literal to correct type? # TODO: cast literal to correct type?
return str(value) + unsigned_suffix return str(value) + unsigned_suffix
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, (int, np.integer)):
iinfo = np.iinfo(np_type) # type: ignore
if value < iinfo.min or value > iinfo.max:
raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
return np_type(value)
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIntegerType): if not isinstance(other, PsIntegerType):
...@@ -598,17 +612,6 @@ class PsSignedIntegerType(PsIntegerType): ...@@ -598,17 +612,6 @@ class PsSignedIntegerType(PsIntegerType):
def __init__(self, width: int, const: bool = False): def __init__(self, width: int, const: bool = False):
super().__init__(width, True, const) super().__init__(width, True, const)
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, int):
return np_type(value)
if isinstance(value, np_type):
return value
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
@final @final
class PsUnsignedIntegerType(PsIntegerType): class PsUnsignedIntegerType(PsIntegerType):
...@@ -626,17 +629,6 @@ class PsUnsignedIntegerType(PsIntegerType): ...@@ -626,17 +629,6 @@ class PsUnsignedIntegerType(PsIntegerType):
def __init__(self, width: int, const: bool = False): def __init__(self, width: int, const: bool = False):
super().__init__(width, False, const) super().__init__(width, False, const)
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, int) and value >= 0:
return np_type(value)
if isinstance(value, np_type):
return value
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
@final @final
class PsIeeeFloatType(PsScalarType): class PsIeeeFloatType(PsScalarType):
...@@ -698,12 +690,12 @@ class PsIeeeFloatType(PsScalarType): ...@@ -698,12 +690,12 @@ class PsIeeeFloatType(PsScalarType):
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width] np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, int) or isinstance(value, float): if isinstance(value, (int, float, np.floating)):
finfo = np.finfo(np_type) # type: ignore
if value < finfo.min or value > finfo.max:
raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
return np_type(value) return np_type(value)
if isinstance(value, np_type):
return value
raise PsTypeError(f"Could not interpret {value} as {repr(self)}") raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
......
...@@ -2,10 +2,13 @@ import pytest ...@@ -2,10 +2,13 @@ import pytest
import sympy as sp import sympy as sp
import numpy as np import numpy as np
from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils import Assignment, TypedSymbol, Field, FieldType
from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
from pystencils.backend.constants import PsConstant
from pystencils.types import constify from pystencils.types import constify
from pystencils.types.quick import Fp, create_numeric_type from pystencils.types.quick import Fp, create_numeric_type
from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.context import KernelCreationContext
...@@ -35,6 +38,7 @@ def test_typify_simple(): ...@@ -35,6 +38,7 @@ def test_typify_simple():
assert isinstance(fasm, PsDeclaration) assert isinstance(fasm, PsDeclaration)
def check(expr): def check(expr):
assert expr.dtype == ctx.default_dtype
match expr: match expr:
case PsConstantExpr(cs): case PsConstantExpr(cs):
assert cs.value == 2 assert cs.value == 2
...@@ -83,6 +87,7 @@ def test_contextual_typing(): ...@@ -83,6 +87,7 @@ def test_contextual_typing():
expr = typify(expr) expr = typify(expr)
def check(expr): def check(expr):
assert expr.dtype == ctx.default_dtype
match expr: match expr:
case PsConstantExpr(cs): case PsConstantExpr(cs):
assert cs.value in (2, 3, -4) assert cs.value in (2, 3, -4)
...@@ -184,12 +189,16 @@ def test_typify_integer_binops_in_floating_context(): ...@@ -184,12 +189,16 @@ def test_typify_integer_binops_in_floating_context():
expr = typify(expr) expr = typify(expr)
def test_regression_typify_constants(): def test_typify_constant_clones():
ctx = KernelCreationContext(default_dtype=Fp(32)) ctx = KernelCreationContext(default_dtype=Fp(32))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx) typify = Typifier(ctx)
x, y = sp.symbols("x, y") c = PsConstantExpr(PsConstant(3.0))
expr = (-x - y) ** 2 x = PsSymbolExpr(ctx.get_symbol("x"))
expr = c + x
expr_clone = expr.clone()
typify(freeze(expr)) # just test that no error is raised expr = typify(expr)
assert expr_clone.operand1.dtype is None
assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
# TODO: Reimplement for constant folder
# import pytest
# from pystencils.types.quick import *
# from pystencils.backend.constants import PsConstant
# @pytest.mark.parametrize("width", (8, 16, 32, 64))
# def test_constant_folding_int(width):
# folder = ConstantFoldingMapper()
# expr = pb.Sum(
# (
# PsTypedConstant(13, UInt(width)),
# PsTypedConstant(5, UInt(width)),
# PsTypedConstant(3, UInt(width)),
# )
# )
# assert folder(expr) == PsTypedConstant(21, UInt(width))
# expr = pb.Product(
# (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width)))
# ) - PsTypedConstant(12, SInt(width))
# assert folder(expr) == PsTypedConstant(-53, SInt(width))
import numpy as np
import pytest
from pystencils.types import PsTypeError
from pystencils.backend.constants import PsConstant
from pystencils.types.quick import Fp, Bool, UInt, SInt
from pystencils.backend.exceptions import PsInternalCompilerError
def test_constant_equality():
c1 = PsConstant(1.0, Fp(32))
c2 = PsConstant(1.0, Fp(32))
assert c1 == c2
assert hash(c1) == hash(c2)
c3 = PsConstant(1.0, Fp(64))
assert c1 != c3
assert hash(c1) != hash(c3)
c4 = c1.reinterpret_as(Fp(64))
assert c4 != c1
assert c4 == c3
def test_interpret():
c1 = PsConstant(3.4, Fp(32))
c2 = PsConstant(3.4)
assert c2.interpret_as(Fp(32)) == c1
with pytest.raises(PsInternalCompilerError):
_ = c1.interpret_as(Fp(64))
def test_boolean_constants():
true = PsConstant(True, Bool())
for val in (1, 1.0, True, np.True_):
assert PsConstant(val, Bool()) == true
false = PsConstant(False, Bool())
for val in (0, 0.0, False, np.False_):
assert PsConstant(val, Bool()) == false
with pytest.raises(PsTypeError):
PsConstant(1.1, Bool())
def test_integer_bounds():
# should not throw:
for val in (255, np.uint8(255), np.int16(255), np.int64(255)):
_ = PsConstant(val, UInt(8))
for val in (-128, np.int16(-128), np.int64(-128)):
_ = PsConstant(val, SInt(8))
# should throw:
for val in (256, np.int16(256), np.int64(256)):
with pytest.raises(PsTypeError):
_ = PsConstant(val, UInt(8))
for val in (-42, np.int32(-42)):
with pytest.raises(PsTypeError):
_ = PsConstant(val, UInt(8))
for val in (-129, np.int16(-129), np.int64(-129)):
with pytest.raises(PsTypeError):
_ = PsConstant(val, SInt(8))
def test_floating_bounds():
for val in (5.1e4, -5.9e4):
_ = PsConstant(val, Fp(16))
_ = PsConstant(val, Fp(32))
_ = PsConstant(val, Fp(64))
for val in (8.1e5, -7.6e5):
with pytest.raises(PsTypeError):
_ = PsConstant(val, Fp(16))
# import pytest
# TODO: Re-implement for constant folder
# from pystencils.types.quick import *
# from pystencils.types import PsTypeError
# from pystencils.backend.typed_expressions import PsTypedConstant
# @pytest.mark.parametrize("width", (8, 16, 32, 64))
# def test_integer_constants(width):
# dtype = SInt(width)
# a = PsTypedConstant(42, dtype)
# b = PsTypedConstant(2, dtype)
# assert a + b == PsTypedConstant(44, dtype)
# assert a - b == PsTypedConstant(40, dtype)
# assert a * b == PsTypedConstant(84, dtype)
# assert a - b != PsTypedConstant(-12, dtype)
# # Typed constants only compare to themselves
# assert a + b != 44
# @pytest.mark.parametrize("width", (32, 64))
# def test_float_constants(width):
# a = PsTypedConstant(32.0, Fp(width))
# b = PsTypedConstant(0.5, Fp(width))
# c = PsTypedConstant(2.0, Fp(width))
# assert a + b == PsTypedConstant(32.5, Fp(width))
# assert a * b == PsTypedConstant(16.0, Fp(width))
# assert a - b == PsTypedConstant(31.5, Fp(width))
# assert a / c == PsTypedConstant(16.0, Fp(width))
# def test_illegal_ops():
# # Cannot interpret negative numbers as unsigned types
# with pytest.raises(PsTypeError):
# _ = PsTypedConstant(-3, UInt(32))
# # Mixed ops are illegal
# with pytest.raises(PsTypeError):
# _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32))
# with pytest.raises(PsTypeError):
# _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32))
# with pytest.raises(PsTypeError):
# _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32))
# with pytest.raises(PsTypeError):
# _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32))
# @pytest.mark.parametrize("width", (8, 16, 32, 64))
# def test_unsigned_integer_division(width):
# a = PsTypedConstant(8, UInt(width))
# b = PsTypedConstant(3, UInt(width))
# assert a / b == PsTypedConstant(2, UInt(width))
# assert a % b == PsTypedConstant(2, UInt(width))
# @pytest.mark.parametrize("width", (8, 16, 32, 64))
# def test_signed_integer_division(width):
# five = PsTypedConstant(5, SInt(width))
# two = PsTypedConstant(2, SInt(width))
# assert five / two == PsTypedConstant(2, SInt(width))
# assert five % two == PsTypedConstant(1, SInt(width))
# assert (- five) / two == PsTypedConstant(-2, SInt(width))
# assert (- five) % two == PsTypedConstant(-1, SInt(width))
# assert five / (- two) == PsTypedConstant(-2, SInt(width))
# assert five % (- two) == PsTypedConstant(1, SInt(width))
# assert (- five) / (- two) == PsTypedConstant(2, SInt(width))
# assert (- five) % (- two) == PsTypedConstant(-1, SInt(width))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment