From e1c2463c70fed9de4cc1351df46ebd665a92d0cb Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 13 Mar 2024 19:51:09 +0100
Subject: [PATCH] fix: Lookups and Derefs are now Lvalues

---
 src/pystencils/backend/ast/expressions.py     | 16 +++++---------
 src/pystencils/backend/ast/structural.py      | 22 +++++++++++--------
 .../backend/kernelcreation/freeze.py          |  6 +++--
 3 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 7c24b55b2..73f34a1fd 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -60,15 +60,11 @@ class PsExpression(PsAstNode, ABC):
         pass
 
 
-class PsLvalueExpr(PsExpression, ABC):
-    """Base class for all expressions that may occur as an lvalue"""
-
-    @abstractmethod
-    def clone(self) -> PsLvalueExpr:
-        pass
+class PsLvalue(ABC):
+    """Mix-in for all expressions that may occur as an lvalue"""
 
 
-class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr):
+class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression):
     """A single symbol as an expression."""
 
     __match_args__ = ("symbol",)
@@ -124,7 +120,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
         return f"Constant({repr(self._constant)})"
 
 
-class PsSubscript(PsLvalueExpr):
+class PsSubscript(PsLvalue, PsExpression):
     __match_args__ = ("base", "index")
 
     def __init__(self, base: PsExpression, index: PsExpression):
@@ -271,7 +267,7 @@ class PsVectorArrayAccess(PsArrayAccess):
         )
 
 
-class PsLookup(PsExpression):
+class PsLookup(PsExpression, PsLvalue):
     __match_args__ = ("aggregate", "member_name")
 
     def __init__(self, aggregate: PsExpression, member_name: str) -> None:
@@ -384,7 +380,7 @@ class PsNeg(PsUnOp):
         return operator.neg
 
 
-class PsDeref(PsUnOp):
+class PsDeref(PsLvalue, PsUnOp):
     pass
 
 
diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index e5b88891c..441faa606 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -3,7 +3,7 @@ from typing import Sequence, cast
 from types import NoneType
 
 from .astnode import PsAstNode, PsLeafMixIn
-from .expressions import PsExpression, PsLvalueExpr, PsSymbolExpr
+from .expressions import PsExpression, PsLvalue, PsSymbolExpr
 
 from .util import failing_cast
 
@@ -76,16 +76,20 @@ class PsAssignment(PsAstNode):
         "rhs",
     )
 
-    def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
-        self._lhs = lhs
+    def __init__(self, lhs: PsExpression, rhs: PsExpression):
+        if not isinstance(lhs, PsLvalue):
+            raise ValueError("Assignment LHS must be an lvalue")
+        self._lhs: PsExpression = lhs
         self._rhs = rhs
 
     @property
-    def lhs(self) -> PsLvalueExpr:
+    def lhs(self) -> PsExpression:
         return self._lhs
 
     @lhs.setter
-    def lhs(self, lvalue: PsLvalueExpr):
+    def lhs(self, lvalue: PsExpression):
+        if not isinstance(lvalue, PsLvalue):
+            raise ValueError("Assignment LHS must be an lvalue")
         self._lhs = lvalue
 
     @property
@@ -105,7 +109,7 @@ class PsAssignment(PsAstNode):
     def set_child(self, idx: int, c: PsAstNode):
         idx = [0, 1][idx]  # trick to normalize index
         if idx == 0:
-            self._lhs = failing_cast(PsLvalueExpr, c)
+            self.lhs = failing_cast(PsExpression, c)
         elif idx == 1:
             self._rhs = failing_cast(PsExpression, c)
         else:
@@ -125,11 +129,11 @@ class PsDeclaration(PsAssignment):
         super().__init__(lhs, rhs)
 
     @property
-    def lhs(self) -> PsLvalueExpr:
+    def lhs(self) -> PsExpression:
         return self._lhs
 
     @lhs.setter
-    def lhs(self, lvalue: PsLvalueExpr):
+    def lhs(self, lvalue: PsExpression):
         self._lhs = failing_cast(PsSymbolExpr, lvalue)
 
     @property
@@ -146,7 +150,7 @@ class PsDeclaration(PsAssignment):
     def set_child(self, idx: int, c: PsAstNode):
         idx = [0, 1][idx]  # trick to normalize index
         if idx == 0:
-            self._lhs = failing_cast(PsSymbolExpr, c)
+            self.lhs = failing_cast(PsSymbolExpr, c)
         elif idx == 1:
             self._rhs = failing_cast(PsExpression, c)
         else:
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index ecdcf2f94..5d07b9e71 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -132,10 +132,12 @@ class FreezeExpressions:
 
         if isinstance(lhs, PsSymbolExpr):
             return PsDeclaration(lhs, rhs)
-        elif isinstance(lhs, (PsArrayAccess, PsVectorArrayAccess)):  # todo
+        elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)):  # todo
             return PsAssignment(lhs, rhs)
         else:
-            assert False, "That should not have happened."
+            raise FreezeError(
+                f"Encountered unsupported expression on assignment left-hand side: {lhs}"
+            )
 
     def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
         symb = self._ctx.get_symbol(spsym.name)
-- 
GitLab