From 8f65a3a778db4a2b5bace272f0d84900e093df4b Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Thu, 28 Mar 2024 14:36:46 +0100 Subject: [PATCH] Fix typing of constants --- .../backend/kernelcreation/typification.py | 13 ++++++++++--- src/pystencils/backend/symbols.py | 4 +++- src/pystencils/types/basic_types.py | 2 +- .../kernelcreation/test_typification.py | 19 ++++++++++++++----- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index dcfb0f548..9ef649b31 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -96,7 +96,7 @@ class TypeContext: Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is called on this context. - It the expression already has a data type set, it must be equal to the inferred type. + If the expression already has a data type set, it must be equal to the inferred type. """ if self._target_type is None: @@ -126,8 +126,15 @@ class TypeContext: raise TypificationError( f"Can't typify constant with non-numeric type {self._target_type}" ) - c.apply_dtype(self._target_type) - + if c.dtype is None: + c.apply_dtype(self._target_type) + elif deconstify(c.dtype) != self._target_type: + raise TypificationError( + f"Type mismatch at constant {c}: Constant type did not match the context's target type\n" + f" Constant type: {c.dtype}\n" + f" Target type: {self._target_type}" + ) + case PsSymbolExpr(symb): symb.apply_dtype(self._target_type) diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py index e1db5a930..b007e3fcf 100644 --- a/src/pystencils/backend/symbols.py +++ b/src/pystencils/backend/symbols.py @@ -42,7 +42,9 @@ class PsSymbol: def get_dtype(self) -> PsType: if self._dtype is None: - raise PsInternalCompilerError("Symbol had no type assigned yet") + raise PsInternalCompilerError( + f"Symbol {self.name} had no type assigned yet" + ) return self._dtype def __str__(self) -> str: diff --git a/src/pystencils/types/basic_types.py b/src/pystencils/types/basic_types.py index 7565ea9a5..b83b6d7d6 100644 --- a/src/pystencils/types/basic_types.py +++ b/src/pystencils/types/basic_types.py @@ -655,7 +655,7 @@ class PsIeeeFloatType(PsScalarType): def __init__(self, width: int, const: bool = False): if width not in self.SUPPORTED_WIDTHS: raise ValueError( - f"Invalid integer width; must be one of {self.SUPPORTED_WIDTHS}." + f"Invalid integer width {width}; must be one of {self.SUPPORTED_WIDTHS}." ) super().__init__(const) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 9ff18623e..cb7e5561f 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -131,13 +131,11 @@ def test_typify_integer_binops(): ctx.get_symbol("x", ctx.index_dtype) ctx.get_symbol("y", ctx.index_dtype) - ctx.get_symbol("z", ctx.index_dtype) - x, y, z = sp.symbols("x, y, z") + x, y = sp.symbols("x, y") expr = bit_shift_left( - bit_shift_right(bitwise_and(x, 2), bitwise_or(y, z)), bitwise_xor(2, 2) - ) # ^ - # TODO: x can not be a constant here, because then the typifier can not check that the arguments are integer. + bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2) + ) expr = freeze(expr) expr = typify(expr) @@ -184,3 +182,14 @@ def test_typify_integer_binops_in_floating_context(): with pytest.raises(TypificationError): expr = typify(expr) + + +def test_regression_typify_constants(): + ctx = KernelCreationContext(default_dtype=Fp(32)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y = sp.symbols("x, y") + expr = (-x - y) ** 2 + + typify(freeze(expr)) # just test that no error is raised -- GitLab