From 645b742b60935fba0642de3c9fb9bd0b186ef401 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 6 Mar 2024 21:01:50 +0100 Subject: [PATCH] minor fixes --- src/pystencils/backend/arrays.py | 4 ++-- .../backend/kernelcreation/typification.py | 19 ++++++++++--------- src/pystencils/sympyextensions/typed_sympy.py | 6 +++--- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py index b9b8b4cbd..16203ac1c 100644 --- a/src/pystencils/backend/arrays.py +++ b/src/pystencils/backend/arrays.py @@ -48,10 +48,10 @@ from ..types import ( PsPointerType, PsIntegerType, PsUnsignedIntegerType, - PsSignedIntegerType, ) from .symbols import PsSymbol +from ..defaults import DEFAULTS class PsLinearizedArray: @@ -77,7 +77,7 @@ class PsLinearizedArray: element_type: PsType, shape: Sequence[int | EllipsisType], strides: Sequence[int | EllipsisType], - index_dtype: PsIntegerType = PsSignedIntegerType(64), + index_dtype: PsIntegerType = DEFAULTS.index_dtype, ): self._name = name self._element_type = element_type diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 5085dccfb..48e4e6aa0 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TypeVar from .context import KernelCreationContext -from ...types import PsType, PsNumericType, PsStructType, deconstify +from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment from ..ast.expressions import ( PsSymbolExpr, @@ -133,13 +133,6 @@ class Typifier: self.visit_expr(expr, tc) return expr - """ - def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant - - All visitor methods take an expression and the current type context. - They shall return the typified expression, or throw `TypificationError` if typification fails. - """ - def visit(self, node: PsAstNode) -> None: """Recursive processing of structural nodes""" match node: @@ -185,7 +178,15 @@ class Typifier: case PsArrayAccess(_, idx): tc.apply_and_check(expr, expr.dtype) - self.visit_expr(idx, TypeContext(self._ctx.index_dtype)) + + index_tc = TypeContext() + self.visit_expr(idx, index_tc) + if index_tc.target_type is None: + index_tc.apply_and_check(idx, self._ctx.index_dtype) + elif not isinstance(index_tc.target_type, PsIntegerType): + raise TypificationError( + f"Array index is not of integer type: {idx} has type {index_tc.target_type}" + ) case PsLookup(aggr, member_name): aggr_tc = TypeContext(None) diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index 82af3d5e4..accea6d40 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -15,11 +15,11 @@ def assumptions_from_dtype(dtype: PsType): assumptions = dict() if isinstance(dtype, PsNumericType): - if dtype.is_int: + if dtype.is_int(): assumptions.update({"integer": True}) - if dtype.is_uint: + if dtype.is_uint(): assumptions.update({"negative": False}) - if dtype.is_int or dtype.is_float: + if dtype.is_int() or dtype.is_float(): assumptions.update({"real": True}) return assumptions -- GitLab