diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 9df186470086f8115fe0c832917a8676d04aa7bf..916274314b1651aae29619ef2a6a59e15b1d3cc6 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -12,7 +12,7 @@ from ...sympyextensions.typed_sympy import TypedSymbol from ..symbols import PsSymbol from ..arrays import PsLinearizedArray -from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType +from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify from ..constraints import KernelParamsConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError @@ -63,8 +63,8 @@ class KernelCreationContext: default_dtype: PsNumericType = DEFAULTS.numeric_dtype, index_dtype: PsIntegerType = DEFAULTS.index_dtype, ): - self._default_dtype = default_dtype - self._index_dtype = index_dtype + self._default_dtype = deconstify(default_dtype) + self._index_dtype = deconstify(index_dtype) self._symbols: dict[str, PsSymbol] = dict() diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 8ef6edd24d54aa9a9992c04adf51abe620ab4813..fc085e2be99f61204cde92438811a3b4e41c8bf7 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -71,7 +71,9 @@ class TypeContext: """ def __init__( - self, target_type: PsType | None = None, require_nonconst: bool = False + self, + target_type: PsType | None = None, + require_nonconst: bool = False, ): self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] @@ -171,8 +173,10 @@ class TypeContext: ) case PsSymbolExpr(symb): - assert symb.dtype is not None - if not self._compatible(symb.dtype): + if symb.dtype is None: + # Symbols are not forced to constness + symb.dtype = deconstify(self._target_type) + elif not self._compatible(symb.dtype): raise TypificationError( f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" f" Symbol type: {symb.dtype}\n" @@ -262,7 +266,11 @@ class Typifier: The following general rules apply: - - The context's `default_dtype` is applied to all untyped symbols + - The context's ``default_dtype`` is applied to all untyped symbols encountered inside a right-hand side expression + - If an untyped symbol is encountered on an assignment's left-hand side, it will first be attempted to infer its + type from the right-hand side. If that fails, the context's ``default_dtype`` will be applied. + - It is an error if an untyped symbol occurs in the same type context as a typed symbol or constant + with a non-default data type. - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's left-hand side @@ -280,7 +288,12 @@ class Typifier: def __call__(self, node: NodeT) -> NodeT: if isinstance(node, PsExpression): - self.visit_expr(node, TypeContext()) + tc = TypeContext() + self.visit_expr(node, tc) + + if tc.target_type is None: + # no type could be inferred -> take the default + tc.apply_dtype(self._ctx.default_dtype) else: self.visit(node) return node @@ -304,20 +317,44 @@ class Typifier: self.visit(s) case PsDeclaration(lhs, rhs): + # Only if the LHS is an untyped symbol, infer its type from the RHS + infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None + tc = TypeContext() - # LHS defines target type; type context carries it to RHS - self.visit_expr(lhs, tc) - assert tc.target_type is not None + + if infer_lhs: + tc.infer_dtype(lhs) + else: + self.visit_expr(lhs, tc) + assert tc.target_type is not None + self.visit_expr(rhs, tc) + if infer_lhs and tc.target_type is None: + # no type has been inferred -> use the default dtype + tc.apply_dtype(self._ctx.default_dtype) + case PsAssignment(lhs, rhs): + infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None + tc_lhs = TypeContext(require_nonconst=True) - self.visit_expr(lhs, tc_lhs) - assert tc_lhs.target_type is not None - tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False) + if infer_lhs: + tc_lhs.infer_dtype(lhs) + else: + self.visit_expr(lhs, tc_lhs) + assert tc_lhs.target_type is not None + + tc_rhs = TypeContext(target_type=tc_lhs.target_type) self.visit_expr(rhs, tc_rhs) + if infer_lhs: + if tc_rhs.target_type is None: + tc_rhs.apply_dtype(self._ctx.default_dtype) + + assert tc_rhs.target_type is not None + tc_lhs.apply_dtype(deconstify(tc_rhs.target_type)) + case PsConditional(cond, branch_true, branch_false): cond_tc = TypeContext(PsBoolType()) self.visit_expr(cond, cond_tc) @@ -330,6 +367,7 @@ class Typifier: case PsLoop(ctr, start, stop, step, body): if ctr.symbol.dtype is None: ctr.symbol.apply_dtype(self._ctx.index_dtype) + ctr.dtype = ctr.symbol.get_dtype() tc_index = TypeContext(ctr.symbol.dtype) self.visit_expr(start, tc_index) @@ -355,11 +393,10 @@ class Typifier: either ``apply_dtype`` or ``infer_dtype``. """ match expr: - case PsSymbolExpr(_): - if expr.symbol.dtype is None: - expr.symbol.dtype = self._ctx.default_dtype - - tc.apply_dtype(expr.symbol.dtype, expr) + case PsSymbolExpr(symb): + if symb.dtype is None: + symb.dtype = self._ctx.default_dtype + tc.apply_dtype(symb.dtype, expr) case PsConstantExpr(c): if c.dtype is not None: diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 5c2631e1eba1eaa84eb8f7ba6442ec020a191245..d3da7e8881d631266d58ee6cc0d4d3612a2900a1 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -4,7 +4,7 @@ import numpy as np from typing import cast -from pystencils import Assignment, TypedSymbol, Field, FieldType +from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment from pystencils.backend.ast.structural import ( PsDeclaration, @@ -16,6 +16,7 @@ from pystencils.backend.ast.structural import ( from pystencils.backend.ast.expressions import ( PsConstantExpr, PsSymbolExpr, + PsSubscript, PsBinOp, PsAnd, PsOr, @@ -27,12 +28,12 @@ from pystencils.backend.ast.expressions import ( PsGt, PsLt, PsCall, - PsTernary + PsTernary, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp, Int, Bool +from pystencils.types.quick import Fp, Int, Bool, Arr from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -186,7 +187,7 @@ def test_typify_structs(): fasm = typify(fasm) -def test_contextual_typing(): +def test_default_typing(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) @@ -213,6 +214,45 @@ def test_contextual_typing(): check(expr) +def test_lhs_inference(): + ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + q = TypedSymbol("q", np.float32) + w = TypedSymbol("w", np.float16) + + # Type of the LHS is propagated to untyped RHS symbols + + asm = Assignment(x, 3 - q) + fasm = typify(freeze(asm)) + + assert ctx.get_symbol("x").dtype == Fp(32) + assert fasm.lhs.dtype == constify(Fp(32)) + + asm = Assignment(y, 3 - w) + fasm = typify(freeze(asm)) + + assert ctx.get_symbol("y").dtype == Fp(16) + assert fasm.lhs.dtype == constify(Fp(16)) + + fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w)) + fasm = typify(fasm) + + assert ctx.get_symbol("z").dtype == Fp(16) + assert fasm.lhs.dtype == Fp(16) + + fasm = PsDeclaration( + PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q)) + ) + fasm = typify(fasm) + + assert ctx.get_symbol("r").dtype == Bool() + assert fasm.lhs.dtype == constify(Bool()) + assert fasm.rhs.dtype == constify(Bool()) + + def test_erronous_typing(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) @@ -227,16 +267,43 @@ def test_erronous_typing(): with pytest.raises(TypificationError): typify(expr) + # Conflict between LHS and RHS symbols asm = Assignment(q, 3 - w) fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) + # Do not propagate types back from LHS symbols to RHS symbols asm = Assignment(q, 3 - x) fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) + asm = AddAugmentedAssignment(z, 3 - q) + fasm = freeze(asm) + with pytest.raises(TypificationError): + typify(fasm) + + +def test_invalid_indices(): + ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) + typify = Typifier(ctx) + + arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64)))) + x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"] + + # Using default-typed symbols as array indices is illegal when the default type is a float + + fasm = PsAssignment(PsSubscript(arr, x + y), z) + + with pytest.raises(TypificationError): + typify(fasm) + + fasm = PsAssignment(z, PsSubscript(arr, x + y)) + + with pytest.raises(TypificationError): + typify(fasm) + def test_typify_integer_binops(): ctx = KernelCreationContext() @@ -366,7 +433,7 @@ def test_invalid_conditions(): with pytest.raises(TypificationError): typify(cond) - + def test_typify_ternary(): ctx = KernelCreationContext() typify = Typifier(ctx) diff --git a/tests/symbolics/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py index e18ffc56a4b0a95c30ff2e9e2d4affa5567654ac..1dbc88cf45c1d5bd9735cb018bb80b06b5be5f37 100644 --- a/tests/symbolics/test_conditional_field_access.py +++ b/tests/symbolics/test_conditional_field_access.py @@ -51,9 +51,6 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): @pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('with_cse', (False, 'with_cse')) def test_boundary_check(dtype, with_cse): - if with_cse: - pytest.xfail("Doesn't typify correctly yet.") - f, g = ps.fields(f"f, g : {dtype}[2D]") stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)