From 953a1f9387f40e38058ee60ebba08b74cece919c Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 28 Mar 2024 16:03:35 +0100 Subject: [PATCH] Various fixes to constants --- src/pystencils/backend/constants.py | 50 +++++++++--- .../backend/kernelcreation/typification.py | 2 +- src/pystencils/types/basic_types.py | 48 +++++------ .../kernelcreation/test_typification.py | 19 +++-- tests/nbackend/test_constant_folding.py | 26 ------ tests/nbackend/test_constants.py | 79 ++++++++++++++++++ tests/nbackend/types/test_constants.py | 80 ------------------- 7 files changed, 154 insertions(+), 150 deletions(-) delete mode 100644 tests/nbackend/test_constant_folding.py create mode 100644 tests/nbackend/test_constants.py delete mode 100644 tests/nbackend/types/test_constants.py diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py index 6e76f6dbb..6dc07842f 100644 --- a/src/pystencils/backend/constants.py +++ b/src/pystencils/backend/constants.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Any from ..types import PsNumericType, constify @@ -5,6 +6,21 @@ from .exceptions import PsInternalCompilerError class PsConstant: + """Type-safe representation of typed numerical constants. + + This class models constants in the backend representation of kernels. + A constant may be *untyped*, in which case its ``value`` may be any Python object. + + If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used + to check the validity of its ``value`` and to convert it into the type's internal representation. + + Instances of `PsConstant` are immutable. + + Args: + value: The constant's value + dtype: The constant's data type, or ``None`` if untyped. + """ + __match_args__ = ("value", "dtype") def __init__(self, value: Any, dtype: PsNumericType | None = None): @@ -12,7 +28,30 @@ class PsConstant: self._value = value if dtype is not None: - self.apply_dtype(dtype) + self._dtype = constify(dtype) + self._value = self._dtype.create_constant(self._value) + else: + self._dtype = None + self._value = value + + def interpret_as(self, dtype: PsNumericType) -> PsConstant: + """Interprets this *untyped* constant with the given data type. + + If this constant is already typed, raises an error. + """ + if self._dtype is not None: + raise PsInternalCompilerError( + f"Cannot interpret already typed constant {self} with type {dtype}" + ) + + return PsConstant(self._value, dtype) + + def reinterpret_as(self, dtype: PsNumericType) -> PsConstant: + """Reinterprets this constant with the given data type. + + Other than `interpret_as`, this method also works on typed constants. + """ + return PsConstant(self._value, dtype) @property def value(self) -> Any: @@ -27,15 +66,6 @@ class PsConstant: raise PsInternalCompilerError("Data type of constant was not set.") return self._dtype - def apply_dtype(self, dtype: PsNumericType): - if self._dtype is not None: - raise PsInternalCompilerError( - "Attempt to apply data type to already typed constant." - ) - - self._dtype = constify(dtype) - self._value = self._dtype.create_constant(self._value) - def __str__(self) -> str: type_str = "<untyped>" if self._dtype is None else str(self._dtype) return f"{str(self._value)}: {type_str}" diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 9ef649b31..d2c93e221 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -127,7 +127,7 @@ class TypeContext: f"Can't typify constant with non-numeric type {self._target_type}" ) if c.dtype is None: - c.apply_dtype(self._target_type) + expr.constant = c.interpret_as(self._target_type) elif deconstify(c.dtype) != self._target_type: raise TypificationError( f"Type mismatch at constant {c}: Constant type did not match the context's target type\n" diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index b83b6d7d6..3678ea126 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -486,9 +486,12 @@ class PsBoolType(PsScalarType): return np.dtype(PsBoolType.NUMPY_TYPE) def create_literal(self, value: Any) -> str: - if value in (1, True, np.True_): + if not isinstance(value, self.NUMPY_TYPE): + raise PsTypeError(f"Given value {value} is not of required type {self.NUMPY_TYPE}") + + if value == np.True_: return "true" - elif value in (0, False, np.False_): + elif value == np.False_: return "false" else: raise PsTypeError(f"Cannot create boolean literal from {value}") @@ -560,6 +563,17 @@ class PsIntegerType(PsScalarType, ABC): unsigned_suffix = "" if self.signed else "u" # TODO: cast literal to correct type? return str(value) + unsigned_suffix + + def create_constant(self, value: Any) -> Any: + np_type = self.NUMPY_TYPES[self._width] + + if isinstance(value, (int, np.integer)): + iinfo = np.iinfo(np_type) # type: ignore + if value < iinfo.min or value > iinfo.max: + raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.") + return np_type(value) + + raise PsTypeError(f"Could not interpret {value} as {repr(self)}") def __eq__(self, other: object) -> bool: if not isinstance(other, PsIntegerType): @@ -598,17 +612,6 @@ class PsSignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, True, const) - def create_constant(self, value: Any) -> Any: - np_type = self.NUMPY_TYPES[self._width] - - if isinstance(value, int): - return np_type(value) - - if isinstance(value, np_type): - return value - - raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - @final class PsUnsignedIntegerType(PsIntegerType): @@ -626,17 +629,6 @@ class PsUnsignedIntegerType(PsIntegerType): def __init__(self, width: int, const: bool = False): super().__init__(width, False, const) - def create_constant(self, value: Any) -> Any: - np_type = self.NUMPY_TYPES[self._width] - - if isinstance(value, int) and value >= 0: - return np_type(value) - - if isinstance(value, np_type): - return value - - raise PsTypeError(f"Could not interpret {value} as {repr(self)}") - @final class PsIeeeFloatType(PsScalarType): @@ -698,12 +690,12 @@ class PsIeeeFloatType(PsScalarType): def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] - if isinstance(value, int) or isinstance(value, float): + if isinstance(value, (int, float, np.floating)): + finfo = np.finfo(np_type) # type: ignore + if value < finfo.min or value > finfo.max: + raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.") return np_type(value) - if isinstance(value, np_type): - return value - raise PsTypeError(f"Could not interpret {value} as {repr(self)}") def __eq__(self, other: object) -> bool: diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index cb7e5561f..ef746c614 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -2,10 +2,13 @@ import pytest import sympy as sp import numpy as np +from typing import cast + from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp +from pystencils.backend.constants import PsConstant from pystencils.types import constify from pystencils.types.quick import Fp, create_numeric_type from pystencils.backend.kernelcreation.context import KernelCreationContext @@ -35,6 +38,7 @@ def test_typify_simple(): assert isinstance(fasm, PsDeclaration) def check(expr): + assert expr.dtype == ctx.default_dtype match expr: case PsConstantExpr(cs): assert cs.value == 2 @@ -83,6 +87,7 @@ def test_contextual_typing(): expr = typify(expr) def check(expr): + assert expr.dtype == ctx.default_dtype match expr: case PsConstantExpr(cs): assert cs.value in (2, 3, -4) @@ -184,12 +189,16 @@ def test_typify_integer_binops_in_floating_context(): expr = typify(expr) -def test_regression_typify_constants(): +def test_typify_constant_clones(): ctx = KernelCreationContext(default_dtype=Fp(32)) - freeze = FreezeExpressions(ctx) typify = Typifier(ctx) - x, y = sp.symbols("x, y") - expr = (-x - y) ** 2 + c = PsConstantExpr(PsConstant(3.0)) + x = PsSymbolExpr(ctx.get_symbol("x")) + expr = c + x + expr_clone = expr.clone() - typify(freeze(expr)) # just test that no error is raised + expr = typify(expr) + + assert expr_clone.operand1.dtype is None + assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None diff --git a/tests/nbackend/test_constant_folding.py b/tests/nbackend/test_constant_folding.py deleted file mode 100644 index ee214ff53..000000000 --- a/tests/nbackend/test_constant_folding.py +++ /dev/null @@ -1,26 +0,0 @@ -# TODO: Reimplement for constant folder -# import pytest - -# from pystencils.types.quick import * -# from pystencils.backend.constants import PsConstant - - -# @pytest.mark.parametrize("width", (8, 16, 32, 64)) -# def test_constant_folding_int(width): -# folder = ConstantFoldingMapper() - -# expr = pb.Sum( -# ( -# PsTypedConstant(13, UInt(width)), -# PsTypedConstant(5, UInt(width)), -# PsTypedConstant(3, UInt(width)), -# ) -# ) - -# assert folder(expr) == PsTypedConstant(21, UInt(width)) - -# expr = pb.Product( -# (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width))) -# ) - PsTypedConstant(12, SInt(width)) - -# assert folder(expr) == PsTypedConstant(-53, SInt(width)) diff --git a/tests/nbackend/test_constants.py b/tests/nbackend/test_constants.py new file mode 100644 index 000000000..93c772e60 --- /dev/null +++ b/tests/nbackend/test_constants.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest + +from pystencils.types import PsTypeError +from pystencils.backend.constants import PsConstant +from pystencils.types.quick import Fp, Bool, UInt, SInt +from pystencils.backend.exceptions import PsInternalCompilerError + + +def test_constant_equality(): + c1 = PsConstant(1.0, Fp(32)) + c2 = PsConstant(1.0, Fp(32)) + + assert c1 == c2 + assert hash(c1) == hash(c2) + + c3 = PsConstant(1.0, Fp(64)) + assert c1 != c3 + assert hash(c1) != hash(c3) + + c4 = c1.reinterpret_as(Fp(64)) + assert c4 != c1 + assert c4 == c3 + + +def test_interpret(): + c1 = PsConstant(3.4, Fp(32)) + c2 = PsConstant(3.4) + + assert c2.interpret_as(Fp(32)) == c1 + + with pytest.raises(PsInternalCompilerError): + _ = c1.interpret_as(Fp(64)) + + +def test_boolean_constants(): + true = PsConstant(True, Bool()) + for val in (1, 1.0, True, np.True_): + assert PsConstant(val, Bool()) == true + + false = PsConstant(False, Bool()) + for val in (0, 0.0, False, np.False_): + assert PsConstant(val, Bool()) == false + + with pytest.raises(PsTypeError): + PsConstant(1.1, Bool()) + + +def test_integer_bounds(): + # should not throw: + for val in (255, np.uint8(255), np.int16(255), np.int64(255)): + _ = PsConstant(val, UInt(8)) + + for val in (-128, np.int16(-128), np.int64(-128)): + _ = PsConstant(val, SInt(8)) + + # should throw: + for val in (256, np.int16(256), np.int64(256)): + with pytest.raises(PsTypeError): + _ = PsConstant(val, UInt(8)) + + for val in (-42, np.int32(-42)): + with pytest.raises(PsTypeError): + _ = PsConstant(val, UInt(8)) + + for val in (-129, np.int16(-129), np.int64(-129)): + with pytest.raises(PsTypeError): + _ = PsConstant(val, SInt(8)) + + +def test_floating_bounds(): + for val in (5.1e4, -5.9e4): + _ = PsConstant(val, Fp(16)) + _ = PsConstant(val, Fp(32)) + _ = PsConstant(val, Fp(64)) + + for val in (8.1e5, -7.6e5): + with pytest.raises(PsTypeError): + _ = PsConstant(val, Fp(16)) diff --git a/tests/nbackend/types/test_constants.py b/tests/nbackend/types/test_constants.py deleted file mode 100644 index 4d948e4e3..000000000 --- a/tests/nbackend/types/test_constants.py +++ /dev/null @@ -1,80 +0,0 @@ -# import pytest - -# TODO: Re-implement for constant folder -# from pystencils.types.quick import * -# from pystencils.types import PsTypeError -# from pystencils.backend.typed_expressions import PsTypedConstant - - -# @pytest.mark.parametrize("width", (8, 16, 32, 64)) -# def test_integer_constants(width): -# dtype = SInt(width) -# a = PsTypedConstant(42, dtype) -# b = PsTypedConstant(2, dtype) - -# assert a + b == PsTypedConstant(44, dtype) -# assert a - b == PsTypedConstant(40, dtype) -# assert a * b == PsTypedConstant(84, dtype) - -# assert a - b != PsTypedConstant(-12, dtype) - -# # Typed constants only compare to themselves -# assert a + b != 44 - - -# @pytest.mark.parametrize("width", (32, 64)) -# def test_float_constants(width): -# a = PsTypedConstant(32.0, Fp(width)) -# b = PsTypedConstant(0.5, Fp(width)) -# c = PsTypedConstant(2.0, Fp(width)) - -# assert a + b == PsTypedConstant(32.5, Fp(width)) -# assert a * b == PsTypedConstant(16.0, Fp(width)) -# assert a - b == PsTypedConstant(31.5, Fp(width)) -# assert a / c == PsTypedConstant(16.0, Fp(width)) - - -# def test_illegal_ops(): -# # Cannot interpret negative numbers as unsigned types -# with pytest.raises(PsTypeError): -# _ = PsTypedConstant(-3, UInt(32)) - -# # Mixed ops are illegal -# with pytest.raises(PsTypeError): -# _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32)) - -# with pytest.raises(PsTypeError): -# _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32)) - -# with pytest.raises(PsTypeError): -# _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32)) - -# with pytest.raises(PsTypeError): -# _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32)) - - -# @pytest.mark.parametrize("width", (8, 16, 32, 64)) -# def test_unsigned_integer_division(width): -# a = PsTypedConstant(8, UInt(width)) -# b = PsTypedConstant(3, UInt(width)) - -# assert a / b == PsTypedConstant(2, UInt(width)) -# assert a % b == PsTypedConstant(2, UInt(width)) - - -# @pytest.mark.parametrize("width", (8, 16, 32, 64)) -# def test_signed_integer_division(width): -# five = PsTypedConstant(5, SInt(width)) -# two = PsTypedConstant(2, SInt(width)) - -# assert five / two == PsTypedConstant(2, SInt(width)) -# assert five % two == PsTypedConstant(1, SInt(width)) - -# assert (- five) / two == PsTypedConstant(-2, SInt(width)) -# assert (- five) % two == PsTypedConstant(-1, SInt(width)) - -# assert five / (- two) == PsTypedConstant(-2, SInt(width)) -# assert five % (- two) == PsTypedConstant(1, SInt(width)) - -# assert (- five) / (- two) == PsTypedConstant(2, SInt(width)) -# assert (- five) % (- two) == PsTypedConstant(-1, SInt(width)) -- GitLab