diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index b9b8b4cbd15683243a2d787b043e2f50e0248afe..16203ac1cffc467d0e8b4f5f56a4b75ae0aaa791 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -48,10 +48,10 @@ from ..types import ( PsPointerType, PsIntegerType, PsUnsignedIntegerType, - PsSignedIntegerType, ) from .symbols import PsSymbol +from ..defaults import DEFAULTS class PsLinearizedArray: @@ -77,7 +77,7 @@ class PsLinearizedArray: element_type: PsType, shape: Sequence[int | EllipsisType], strides: Sequence[int | EllipsisType], - index_dtype: PsIntegerType = PsSignedIntegerType(64), + index_dtype: PsIntegerType = DEFAULTS.index_dtype, ): self._name = name self._element_type = element_type diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 5085dccfb1cd933f47ca4d9a7b4d2e71fbd9b3ba..48e4e6aa0fd4fa7d31b37af420acb1f718b514a8 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TypeVar from .context import KernelCreationContext -from ...types import PsType, PsNumericType, PsStructType, deconstify +from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment from ..ast.expressions import ( PsSymbolExpr, @@ -133,13 +133,6 @@ class Typifier: self.visit_expr(expr, tc) return expr - """ - def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant - - All visitor methods take an expression and the current type context. - They shall return the typified expression, or throw `TypificationError` if typification fails. - """ - def visit(self, node: PsAstNode) -> None: """Recursive processing of structural nodes""" match node: @@ -185,7 +178,15 @@ class Typifier: case PsArrayAccess(_, idx): tc.apply_and_check(expr, expr.dtype) - self.visit_expr(idx, TypeContext(self._ctx.index_dtype)) + + index_tc = TypeContext() + self.visit_expr(idx, index_tc) + if index_tc.target_type is None: + index_tc.apply_and_check(idx, self._ctx.index_dtype) + elif not isinstance(index_tc.target_type, PsIntegerType): + raise TypificationError( + f"Array index is not of integer type: {idx} has type {index_tc.target_type}" + ) case PsLookup(aggr, member_name): aggr_tc = TypeContext(None) diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 82af3d5e46136796db96188c69d4de8d325aa607..accea6d40b596326edecfd54938876f67df88991 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -15,11 +15,11 @@ def assumptions_from_dtype(dtype: PsType): assumptions = dict() if isinstance(dtype, PsNumericType): - if dtype.is_int: + if dtype.is_int(): assumptions.update({"integer": True}) - if dtype.is_uint: + if dtype.is_uint(): assumptions.update({"negative": False}) - if dtype.is_int or dtype.is_float: + if dtype.is_int() or dtype.is_float(): assumptions.update({"real": True}) return assumptions