Skip to content
Snippets Groups Projects
Commit 645b742b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

minor fixes

parent 07af649b
Branches
Tags
No related merge requests found
Pipeline #63782 failed with stages
in 3 minutes and 11 seconds
......@@ -48,10 +48,10 @@ from ..types import (
PsPointerType,
PsIntegerType,
PsUnsignedIntegerType,
PsSignedIntegerType,
)
from .symbols import PsSymbol
from ..defaults import DEFAULTS
class PsLinearizedArray:
......@@ -77,7 +77,7 @@ class PsLinearizedArray:
element_type: PsType,
shape: Sequence[int | EllipsisType],
strides: Sequence[int | EllipsisType],
index_dtype: PsIntegerType = PsSignedIntegerType(64),
index_dtype: PsIntegerType = DEFAULTS.index_dtype,
):
self._name = name
self._element_type = element_type
......
......@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TypeVar
from .context import KernelCreationContext
from ...types import PsType, PsNumericType, PsStructType, deconstify
from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify
from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment
from ..ast.expressions import (
PsSymbolExpr,
......@@ -133,13 +133,6 @@ class Typifier:
self.visit_expr(expr, tc)
return expr
"""
def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant
All visitor methods take an expression and the current type context.
They shall return the typified expression, or throw `TypificationError` if typification fails.
"""
def visit(self, node: PsAstNode) -> None:
"""Recursive processing of structural nodes"""
match node:
......@@ -185,7 +178,15 @@ class Typifier:
case PsArrayAccess(_, idx):
tc.apply_and_check(expr, expr.dtype)
self.visit_expr(idx, TypeContext(self._ctx.index_dtype))
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_and_check(idx, self._ctx.index_dtype)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
)
case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None)
......
......@@ -15,11 +15,11 @@ def assumptions_from_dtype(dtype: PsType):
assumptions = dict()
if isinstance(dtype, PsNumericType):
if dtype.is_int:
if dtype.is_int():
assumptions.update({"integer": True})
if dtype.is_uint:
if dtype.is_uint():
assumptions.update({"negative": False})
if dtype.is_int or dtype.is_float:
if dtype.is_int() or dtype.is_float():
assumptions.update({"real": True})
return assumptions
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment