diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py index 6dc07842ff12a1aadeedd37b15e3ee8c8185ebe9..125c1149ba7fec7ce279afef089229f70c75eb18 100644 --- a/src/pystencils/backend/constants.py +++ b/src/pystencils/backend/constants.py @@ -59,9 +59,14 @@ class PsConstant: @property def dtype(self) -> PsNumericType | None: + """This constant's data type, or ``None`` if it is untyped. + + The data type of a constant always has ``const == True``. + """ return self._dtype def get_dtype(self) -> PsNumericType: + """Retrieve this constant's data type, throwing an exception if the constant is untyped.""" if self._dtype is None: raise PsInternalCompilerError("Data type of constant was not set.") return self._dtype diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index d2c93e22109c4970ad5bb8915e7c6c8cc84173d6..190cd9e23b715ac5bca2a4bc6fd119013376cb49 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -12,7 +12,7 @@ from ...types import ( PsDereferencableType, PsPointerType, PsBoolType, - deconstify, + constify, ) from ..ast.structural import ( PsAstNode, @@ -21,6 +21,7 @@ from ..ast.structural import ( PsConditional, PsExpression, PsAssignment, + PsDeclaration, ) from ..ast.expressions import ( PsArrayAccess, @@ -54,20 +55,48 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: - def __init__(self, target_type: PsType | None = None): - self._target_type = deconstify(target_type) if target_type is not None else None + """Typing context, with support for type inference and checking. + + Instances of this class are used to propagate and check data types across expression subtrees + of the AST. Each type context has: + + - A target type `target_type`, which shall be applied to all expressions it covers + - A set of restrictions on the target type: + - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides + - Additional restrictions may be added in the future. + """ + + def __init__( + self, target_type: PsType | None = None, require_nonconst: bool = False + ): + self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] - def apply_dtype(self, expr: PsExpression | None, dtype: PsType): - """Applies the given ``dtype`` to the given expression inside this type context. + self._target_type = ( + self._fix_constness(target_type) if target_type is not None else None + ) + + @property + def target_type(self) -> PsType | None: + return self._target_type + + @property + def require_nonconst(self) -> bool: + return self._require_nonconst + + def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): + """Applies the given ``dtype`` to this type context, and optionally to the given expression. - The given expression will be covered by this type context. If the context's target_type is already known, it must be compatible with the given dtype. If the target type is still unknown, target_type is set to dtype and retroactively applied to all deferred expressions. + + If an expression is specified, it will be covered by the type context. + If the expression already has a data type set, it must be compatible with the target type + and will be replaced by it. """ - dtype = deconstify(dtype) + dtype = self._fix_constness(dtype) if self._target_type is not None and dtype != self._target_type: raise TypificationError( @@ -80,14 +109,7 @@ class TypeContext: self._propagate_target_type() if expr is not None: - if expr.dtype is None: - self._apply_target_type(expr) - elif deconstify(expr.dtype) != self._target_type: - raise TypificationError( - "Type conflict: Predefined expression type did not match the context's target type\n" - f" Expression type: {dtype}\n" - f" Target type: {self._target_type}" - ) + self._apply_target_type(expr) def infer_dtype(self, expr: PsExpression): """Infer the data type for the given expression. @@ -96,7 +118,8 @@ class TypeContext: Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is called on this context. - If the expression already has a data type set, it must be equal to the inferred type. + If the expression already has a data type set, it must be compatible with the target type + and will be replaced by it. """ if self._target_type is None: @@ -113,7 +136,7 @@ class TypeContext: assert self._target_type is not None if expr.dtype is not None: - if deconstify(expr.dtype) != self.target_type: + if not self._compatible(expr.dtype): raise TypificationError( f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" f" Expression type: {expr.dtype}\n" @@ -128,7 +151,7 @@ class TypeContext: ) if c.dtype is None: expr.constant = c.interpret_as(self._target_type) - elif deconstify(c.dtype) != self._target_type: + elif not self._compatible(c.dtype): raise TypificationError( f"Type mismatch at constant {c}: Constant type did not match the context's target type\n" f" Constant type: {c.dtype}\n" @@ -136,7 +159,13 @@ class TypeContext: ) case PsSymbolExpr(symb): - symb.apply_dtype(self._target_type) + assert symb.dtype is not None + if not self._compatible(symb.dtype): + raise TypificationError( + f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" + f" Symbol type: {symb.dtype}\n" + f" Target type: {self._target_type}" + ) case ( PsIntDiv() @@ -151,18 +180,42 @@ class TypeContext: f" Expression: {expr}" f" Type Context: {self._target_type}" ) - - expr.dtype = self._target_type # endif + expr.dtype = self._target_type - @property - def target_type(self) -> PsType | None: - return self._target_type + def _compatible(self, dtype: PsType): + """Checks whether the given data type is compatible with the context's target type. + + If the target type is ``const``, they must be equal up to const qualification; + if the target type is not ``const``, `dtype` must match it exactly. + """ + assert self._target_type is not None + if self._target_type.const: + return constify(dtype) == self._target_type + else: + return dtype == self._target_type + + def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None): + if self._require_nonconst: + if dtype.const: + if expr is None: + raise TypificationError( + f"Type mismatch: Encountered {dtype} in non-constant context." + ) + else: + raise TypificationError( + f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context." + ) + return dtype + else: + return constify(dtype) class Typifier: """Apply data types to expressions. + **Contextual Typing** + The Typifier will traverse the AST and apply a contextual typing scheme to figure out the data types of all encountered expressions. To this end, it covers each expression tree with a set of disjoint typing contexts. @@ -183,6 +236,21 @@ class Typifier: in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon as the target type is fixed. + **Typing Rules** + + The following general rules apply: + + - The context's `default_dtype` is applied to all untyped symbols + - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's + left-hand side + + **Typing of symbol expressions** + + Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but + not necessarily their const-qualification. + A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, + and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`, + but not vice versa. """ def __init__(self, ctx: KernelCreationContext): @@ -213,13 +281,21 @@ class Typifier: for s in statements: self.visit(s) - case PsAssignment(lhs, rhs): + case PsDeclaration(lhs, rhs): tc = TypeContext() # LHS defines target type; type context carries it to RHS self.visit_expr(lhs, tc) assert tc.target_type is not None self.visit_expr(rhs, tc) + case PsAssignment(lhs, rhs): + tc_lhs = TypeContext(require_nonconst=True) + self.visit_expr(lhs, tc_lhs) + assert tc_lhs.target_type is not None + + tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False) + self.visit_expr(rhs, tc_rhs) + case PsConditional(cond, branch_true, branch_false): cond_tc = TypeContext(PsBoolType(const=True)) self.visit_expr(cond, cond_tc) @@ -233,10 +309,10 @@ class Typifier: if ctr.symbol.dtype is None: ctr.symbol.apply_dtype(self._ctx.index_dtype) - tc = TypeContext(ctr.symbol.dtype) - self.visit_expr(start, tc) - self.visit_expr(stop, tc) - self.visit_expr(step, tc) + tc_index = TypeContext(ctr.symbol.dtype) + self.visit_expr(start, tc_index) + self.visit_expr(stop, tc_index) + self.visit_expr(step, tc_index) self.visit(body) @@ -244,24 +320,35 @@ class Typifier: raise NotImplementedError(f"Can't typify {node}") def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None: - """Recursive processing of expression nodes""" + """Recursive processing of expression nodes. + + This method opens, expands, and closes typing contexts according to the respective expression's + typing rules. It may add or check restrictions only when opening or closing a type context. + + The actual type inference and checking during context expansion are performed by the methods + of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling + either ``apply_dtype`` or ``infer_dtype``. + """ match expr: case PsSymbolExpr(_): - if expr.dtype is None: - tc.apply_dtype(expr, self._ctx.default_dtype) - else: - tc.apply_dtype(expr, expr.dtype) + if expr.symbol.dtype is None: + expr.symbol.dtype = self._ctx.default_dtype - case PsConstantExpr(_): - tc.infer_dtype(expr) + tc.apply_dtype(expr.symbol.dtype, expr) + + case PsConstantExpr(c): + if c.dtype is not None: + tc.apply_dtype(c.dtype, expr) + else: + tc.infer_dtype(expr) case PsArrayAccess(bptr, idx): - tc.apply_dtype(expr, bptr.array.element_type) + tc.apply_dtype(bptr.array.element_type, expr) index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: - index_tc.apply_dtype(idx, self._ctx.index_dtype) + index_tc.apply_dtype(self._ctx.index_dtype, idx) elif not isinstance(index_tc.target_type, PsIntegerType): raise TypificationError( f"Array index is not of integer type: {idx} has type {index_tc.target_type}" @@ -276,12 +363,12 @@ class Typifier: "Type of subscript base is not subscriptable." ) - tc.apply_dtype(expr, arr_tc.target_type.base_type) + tc.apply_dtype(arr_tc.target_type.base_type, expr) index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: - index_tc.apply_dtype(idx, self._ctx.index_dtype) + index_tc.apply_dtype(self._ctx.index_dtype, idx) elif not isinstance(index_tc.target_type, PsIntegerType): raise TypificationError( f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" @@ -296,7 +383,7 @@ class Typifier: "Type of argument to a Deref is not dereferencable" ) - tc.apply_dtype(expr, ptr_tc.target_type.base_type) + tc.apply_dtype(ptr_tc.target_type.base_type, expr) case PsAddressOf(arg): arg_tc = TypeContext() @@ -308,10 +395,11 @@ class Typifier: ) ptr_type = PsPointerType(arg_tc.target_type, True) - tc.apply_dtype(expr, ptr_type) + tc.apply_dtype(ptr_type, expr) case PsLookup(aggr, member_name): - aggr_tc = TypeContext(None) + # Members of a struct type inherit the struct type's `const` qualifier + aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst) self.visit_expr(aggr, aggr_tc) aggr_type = aggr_tc.target_type @@ -326,7 +414,11 @@ class Typifier: f"Aggregate of type {aggr_type} does not have a member {member}." ) - tc.apply_dtype(expr, member.dtype) + member_type = member.dtype + if aggr_type.const: + member_type = constify(member_type) + + tc.apply_dtype(member_type, expr) case PsBinOp(op1, op2): self.visit_expr(op1, tc) @@ -365,14 +457,14 @@ class Typifier: f"{len(items)} items as {tc.target_type}" ) else: - items_tc.apply_dtype(None, tc.target_type.base_type) + items_tc.apply_dtype(tc.target_type.base_type) else: arr_type = PsArrayType(items_tc.target_type, len(items)) - tc.apply_dtype(expr, arr_type) + tc.apply_dtype(arr_type, expr) case PsCast(dtype, arg): self.visit_expr(arg, TypeContext()) - tc.apply_dtype(expr, dtype) + tc.apply_dtype(dtype, expr) case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index ef746c614901a2a15a3b63e117d0c7cec61b9676..d9cc5f9cef240e923a8d80a7651947b1ca762ff4 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -6,11 +6,11 @@ from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType -from pystencils.backend.ast.structural import PsDeclaration +from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.constants import PsConstant from pystencils.types import constify -from pystencils.types.quick import Fp, create_numeric_type +from pystencils.types.quick import Fp, create_type, create_numeric_type from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -38,7 +38,7 @@ def test_typify_simple(): assert isinstance(fasm, PsDeclaration) def check(expr): - assert expr.dtype == ctx.default_dtype + assert expr.dtype == constify(ctx.default_dtype) match expr: case PsConstantExpr(cs): assert cs.value == 2 @@ -56,6 +56,89 @@ def test_typify_simple(): check(fasm.rhs) +def test_rhs_constness(): + default_type = Fp(32) + ctx = KernelCreationContext(default_dtype=default_type) + + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + f = Field.create_generic( + "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM + ) + f_const = Field.create_generic( + "f_const", + 1, + index_shape=(1,), + dtype=constify(default_type), + field_type=FieldType.CUSTOM, + ) + + x, y, z = sp.symbols("x, y, z") + + # Right-hand sides should always get const types + asm = typify(freeze(Assignment(x, f.absolute_access([0], [0])))) + assert asm.rhs.get_dtype().const + + asm = typify( + freeze( + Assignment( + f.absolute_access([0], [0]), + f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y, + ) + ) + ) + assert asm.rhs.get_dtype().const + + +def test_lhs_constness(): + default_type = Fp(32) + ctx = KernelCreationContext(default_dtype=default_type) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + f = Field.create_generic( + "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM + ) + f_const = Field.create_generic( + "f_const", + 1, + index_shape=(1,), + dtype=constify(default_type), + field_type=FieldType.CUSTOM, + ) + + x, y, z = sp.symbols("x, y, z") + + # Assignment RHS may not be const + asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) + assert not asm.lhs.get_dtype().const + + # Cannot assign to const left-hand side + with pytest.raises(TypificationError): + _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y))) + + np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) + struct_type = constify(create_type(np_struct)) + struct_field = Field.create_generic( + "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM + ) + + with pytest.raises(TypificationError): + _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x))) + + # Const LHS is only OK in declarations + + q = ctx.get_symbol("q", Fp(32, const=True)) + ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q)) + ast = typify(ast) + assert ast.lhs.dtype == Fp(32, const=True) + + ast = PsAssignment(PsExpression.make(q), PsExpression.make(q)) + with pytest.raises(TypificationError): + typify(ast) + + def test_typify_structs(): ctx = KernelCreationContext(default_dtype=Fp(32)) freeze = FreezeExpressions(ctx) @@ -70,6 +153,10 @@ def test_typify_structs(): fasm = freeze(asm) fasm = typify(fasm) + asm = Assignment(f.absolute_access((0,), "data"), x) + fasm = freeze(asm) + fasm = typify(fasm) + # Bad asm = Assignment(x, f.absolute_access((0,), "size")) fasm = freeze(asm) @@ -87,7 +174,7 @@ def test_contextual_typing(): expr = typify(expr) def check(expr): - assert expr.dtype == ctx.default_dtype + assert expr.dtype == constify(ctx.default_dtype) match expr: case PsConstantExpr(cs): assert cs.value in (2, 3, -4) @@ -199,6 +286,6 @@ def test_typify_constant_clones(): expr_clone = expr.clone() expr = typify(expr) - + assert expr_clone.operand1.dtype is None assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None