Skip to content
Snippets Groups Projects
Commit 994bcc76 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

rename PsTypedSymbol -> PsTypedVariable

parent 9f4d89dc
No related merge requests found
......@@ -5,7 +5,7 @@ from abc import ABC
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedSymbol, PsLvalue
from ..typed_expressions import PsTypedVariable, PsLvalue
class PsAstNode(ABC):
......@@ -77,8 +77,8 @@ class PsLvalueExpr(PsExpression):
class PsSymbolExpr(PsLvalueExpr):
"""Wrapper around PsTypedSymbols"""
def __init__(self, symbol: PsTypedSymbol):
if not isinstance(symbol, PsTypedSymbol):
def __init__(self, symbol: PsTypedVariable):
if not isinstance(symbol, PsTypedVariable):
raise TypeError("Not a symbol!")
super(PsLvalueExpr, self).__init__(symbol)
......
......@@ -5,7 +5,7 @@ from typing import Dict
from pymbolic.primitives import Expression
from pymbolic.mapper.substitutor import CachedSubstitutionMapper
from ..typed_expressions import PsTypedSymbol
from ..typed_expressions import PsTypedVariable
from .dispatcher import ast_visitor
from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
......@@ -21,7 +21,7 @@ class PsAstTransformer(ABC):
class PsSymbolsSubstitutor(PsAstTransformer):
def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]):
def __init__(self, subs_dict: Dict[PsTypedVariable, Expression]):
self._subs_dict = subs_dict
self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None))
......@@ -33,7 +33,7 @@ class PsSymbolsSubstitutor(PsAstTransformer):
@visit.case(PsAssignment)
def assignment(self, asm: PsAssignment):
lhs_expr = asm.lhs.expression
if isinstance(lhs_expr, PsTypedSymbol) and lhs_expr in self._subs_dict:
if isinstance(lhs_expr, PsTypedVariable) and lhs_expr in self._subs_dict:
raise ValueError(f"Cannot substitute symbol {lhs_expr} that occurs on a left-hand side of an assignment.")
self.transform_children(asm)
return asm
......
......@@ -7,9 +7,9 @@ import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType
class PsTypedSymbol(pb.Variable):
class PsTypedVariable(pb.Variable):
def __init__(self, name: str, dtype: AbstractType):
super(PsTypedSymbol, self).__init__(name)
super(PsTypedVariable, self).__init__(name)
self._dtype = dtype
@property
......@@ -17,7 +17,7 @@ class PsTypedSymbol(pb.Variable):
return self._dtype
class PsArrayBasePointer(PsTypedSymbol):
class PsArrayBasePointer(PsTypedVariable):
def __init__(self, name: str, base_type: AbstractType):
super(PsArrayBasePointer, self).__init__(name, base_type)
......@@ -27,7 +27,7 @@ class PsArrayAccess(pb.Subscript):
super(PsArrayAccess, self).__init__(base_ptr, index)
PsLvalue: TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
class PsTypedConstant:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment