From 645b742b60935fba0642de3c9fb9bd0b186ef401 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 6 Mar 2024 21:01:50 +0100
Subject: [PATCH] minor fixes

---
 src/pystencils/backend/arrays.py              |  4 ++--
 .../backend/kernelcreation/typification.py    | 19 ++++++++++---------
 src/pystencils/sympyextensions/typed_sympy.py |  6 +++---
 3 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py
index b9b8b4cbd..16203ac1c 100644
--- a/src/pystencils/backend/arrays.py
+++ b/src/pystencils/backend/arrays.py
@@ -48,10 +48,10 @@ from ..types import (
     PsPointerType,
     PsIntegerType,
     PsUnsignedIntegerType,
-    PsSignedIntegerType,
 )
 
 from .symbols import PsSymbol
+from ..defaults import DEFAULTS
 
 
 class PsLinearizedArray:
@@ -77,7 +77,7 @@ class PsLinearizedArray:
         element_type: PsType,
         shape: Sequence[int | EllipsisType],
         strides: Sequence[int | EllipsisType],
-        index_dtype: PsIntegerType = PsSignedIntegerType(64),
+        index_dtype: PsIntegerType = DEFAULTS.index_dtype,
     ):
         self._name = name
         self._element_type = element_type
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 5085dccfb..48e4e6aa0 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import TypeVar
 
 from .context import KernelCreationContext
-from ...types import PsType, PsNumericType, PsStructType, deconstify
+from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify
 from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment
 from ..ast.expressions import (
     PsSymbolExpr,
@@ -133,13 +133,6 @@ class Typifier:
         self.visit_expr(expr, tc)
         return expr
 
-    """
-    def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant
-
-    All visitor methods take an expression and the current type context.
-    They shall return the typified expression, or throw `TypificationError` if typification fails.
-    """
-
     def visit(self, node: PsAstNode) -> None:
         """Recursive processing of structural nodes"""
         match node:
@@ -185,7 +178,15 @@ class Typifier:
 
             case PsArrayAccess(_, idx):
                 tc.apply_and_check(expr, expr.dtype)
-                self.visit_expr(idx, TypeContext(self._ctx.index_dtype))
+                
+                index_tc = TypeContext()
+                self.visit_expr(idx, index_tc)
+                if index_tc.target_type is None:
+                    index_tc.apply_and_check(idx, self._ctx.index_dtype)
+                elif not isinstance(index_tc.target_type, PsIntegerType):
+                    raise TypificationError(
+                        f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
+                    )
 
             case PsLookup(aggr, member_name):
                 aggr_tc = TypeContext(None)
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index 82af3d5e4..accea6d40 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -15,11 +15,11 @@ def assumptions_from_dtype(dtype: PsType):
     assumptions = dict()
 
     if isinstance(dtype, PsNumericType):
-        if dtype.is_int:
+        if dtype.is_int():
             assumptions.update({"integer": True})
-        if dtype.is_uint:
+        if dtype.is_uint():
             assumptions.update({"negative": False})
-        if dtype.is_int or dtype.is_float:
+        if dtype.is_int() or dtype.is_float():
             assumptions.update({"real": True})
 
     return assumptions
-- 
GitLab