Skip to content
Snippets Groups Projects
Commit 645b742b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

minor fixes

parent 07af649b
No related merge requests found
Pipeline #63782 failed with stages
in 3 minutes and 11 seconds
...@@ -48,10 +48,10 @@ from ..types import ( ...@@ -48,10 +48,10 @@ from ..types import (
PsPointerType, PsPointerType,
PsIntegerType, PsIntegerType,
PsUnsignedIntegerType, PsUnsignedIntegerType,
PsSignedIntegerType,
) )
from .symbols import PsSymbol from .symbols import PsSymbol
from ..defaults import DEFAULTS
class PsLinearizedArray: class PsLinearizedArray:
...@@ -77,7 +77,7 @@ class PsLinearizedArray: ...@@ -77,7 +77,7 @@ class PsLinearizedArray:
element_type: PsType, element_type: PsType,
shape: Sequence[int | EllipsisType], shape: Sequence[int | EllipsisType],
strides: Sequence[int | EllipsisType], strides: Sequence[int | EllipsisType],
index_dtype: PsIntegerType = PsSignedIntegerType(64), index_dtype: PsIntegerType = DEFAULTS.index_dtype,
): ):
self._name = name self._name = name
self._element_type = element_type self._element_type = element_type
......
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TypeVar from typing import TypeVar
from .context import KernelCreationContext from .context import KernelCreationContext
from ...types import PsType, PsNumericType, PsStructType, deconstify from ...types import PsType, PsNumericType, PsStructType, PsIntegerType, deconstify
from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment
from ..ast.expressions import ( from ..ast.expressions import (
PsSymbolExpr, PsSymbolExpr,
...@@ -133,13 +133,6 @@ class Typifier: ...@@ -133,13 +133,6 @@ class Typifier:
self.visit_expr(expr, tc) self.visit_expr(expr, tc)
return expr return expr
"""
def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant
All visitor methods take an expression and the current type context.
They shall return the typified expression, or throw `TypificationError` if typification fails.
"""
def visit(self, node: PsAstNode) -> None: def visit(self, node: PsAstNode) -> None:
"""Recursive processing of structural nodes""" """Recursive processing of structural nodes"""
match node: match node:
...@@ -185,7 +178,15 @@ class Typifier: ...@@ -185,7 +178,15 @@ class Typifier:
case PsArrayAccess(_, idx): case PsArrayAccess(_, idx):
tc.apply_and_check(expr, expr.dtype) tc.apply_and_check(expr, expr.dtype)
self.visit_expr(idx, TypeContext(self._ctx.index_dtype))
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_and_check(idx, self._ctx.index_dtype)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
)
case PsLookup(aggr, member_name): case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None) aggr_tc = TypeContext(None)
......
...@@ -15,11 +15,11 @@ def assumptions_from_dtype(dtype: PsType): ...@@ -15,11 +15,11 @@ def assumptions_from_dtype(dtype: PsType):
assumptions = dict() assumptions = dict()
if isinstance(dtype, PsNumericType): if isinstance(dtype, PsNumericType):
if dtype.is_int: if dtype.is_int():
assumptions.update({"integer": True}) assumptions.update({"integer": True})
if dtype.is_uint: if dtype.is_uint():
assumptions.update({"negative": False}) assumptions.update({"negative": False})
if dtype.is_int or dtype.is_float: if dtype.is_int() or dtype.is_float():
assumptions.update({"real": True}) assumptions.update({"real": True})
return assumptions return assumptions
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment