From e1c2463c70fed9de4cc1351df46ebd665a92d0cb Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 13 Mar 2024 19:51:09 +0100 Subject: [PATCH] fix: Lookups and Derefs are now Lvalues --- src/pystencils/backend/ast/expressions.py | 16 +++++--------- src/pystencils/backend/ast/structural.py | 22 +++++++++++-------- .../backend/kernelcreation/freeze.py | 6 +++-- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7c24b55b2..73f34a1fd 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -60,15 +60,11 @@ class PsExpression(PsAstNode, ABC): pass -class PsLvalueExpr(PsExpression, ABC): - """Base class for all expressions that may occur as an lvalue""" - - @abstractmethod - def clone(self) -> PsLvalueExpr: - pass +class PsLvalue(ABC): + """Mix-in for all expressions that may occur as an lvalue""" -class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr): +class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): """A single symbol as an expression.""" __match_args__ = ("symbol",) @@ -124,7 +120,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): return f"Constant({repr(self._constant)})" -class PsSubscript(PsLvalueExpr): +class PsSubscript(PsLvalue, PsExpression): __match_args__ = ("base", "index") def __init__(self, base: PsExpression, index: PsExpression): @@ -271,7 +267,7 @@ class PsVectorArrayAccess(PsArrayAccess): ) -class PsLookup(PsExpression): +class PsLookup(PsExpression, PsLvalue): __match_args__ = ("aggregate", "member_name") def __init__(self, aggregate: PsExpression, member_name: str) -> None: @@ -384,7 +380,7 @@ class PsNeg(PsUnOp): return operator.neg -class PsDeref(PsUnOp): +class PsDeref(PsLvalue, PsUnOp): pass diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index e5b88891c..441faa606 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -3,7 +3,7 @@ from typing import Sequence, cast from types import NoneType from .astnode import PsAstNode, PsLeafMixIn -from .expressions import PsExpression, PsLvalueExpr, PsSymbolExpr +from .expressions import PsExpression, PsLvalue, PsSymbolExpr from .util import failing_cast @@ -76,16 +76,20 @@ class PsAssignment(PsAstNode): "rhs", ) - def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): - self._lhs = lhs + def __init__(self, lhs: PsExpression, rhs: PsExpression): + if not isinstance(lhs, PsLvalue): + raise ValueError("Assignment LHS must be an lvalue") + self._lhs: PsExpression = lhs self._rhs = rhs @property - def lhs(self) -> PsLvalueExpr: + def lhs(self) -> PsExpression: return self._lhs @lhs.setter - def lhs(self, lvalue: PsLvalueExpr): + def lhs(self, lvalue: PsExpression): + if not isinstance(lvalue, PsLvalue): + raise ValueError("Assignment LHS must be an lvalue") self._lhs = lvalue @property @@ -105,7 +109,7 @@ class PsAssignment(PsAstNode): def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] # trick to normalize index if idx == 0: - self._lhs = failing_cast(PsLvalueExpr, c) + self.lhs = failing_cast(PsExpression, c) elif idx == 1: self._rhs = failing_cast(PsExpression, c) else: @@ -125,11 +129,11 @@ class PsDeclaration(PsAssignment): super().__init__(lhs, rhs) @property - def lhs(self) -> PsLvalueExpr: + def lhs(self) -> PsExpression: return self._lhs @lhs.setter - def lhs(self, lvalue: PsLvalueExpr): + def lhs(self, lvalue: PsExpression): self._lhs = failing_cast(PsSymbolExpr, lvalue) @property @@ -146,7 +150,7 @@ class PsDeclaration(PsAssignment): def set_child(self, idx: int, c: PsAstNode): idx = [0, 1][idx] # trick to normalize index if idx == 0: - self._lhs = failing_cast(PsSymbolExpr, c) + self.lhs = failing_cast(PsSymbolExpr, c) elif idx == 1: self._rhs = failing_cast(PsExpression, c) else: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ecdcf2f94..5d07b9e71 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -132,10 +132,12 @@ class FreezeExpressions: if isinstance(lhs, PsSymbolExpr): return PsDeclaration(lhs, rhs) - elif isinstance(lhs, (PsArrayAccess, PsVectorArrayAccess)): # todo + elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)): # todo return PsAssignment(lhs, rhs) else: - assert False, "That should not have happened." + raise FreezeError( + f"Encountered unsupported expression on assignment left-hand side: {lhs}" + ) def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) -- GitLab