From 9a20fdcd8e176f644cbc835b923acbecce61ea65 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Wed, 22 May 2024 09:56:51 +0200
Subject: [PATCH] Add support for PsConditional to UndefinedSymbolsCollector
 and PsStatement to CanonicalClone

---
 src/pystencils/backend/ast/analysis.py                    | 8 ++++++++
 src/pystencils/backend/transformations/canonical_clone.py | 4 ++++
 2 files changed, 12 insertions(+)

diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py
index 0ea13c563..040c61678 100644
--- a/src/pystencils/backend/ast/analysis.py
+++ b/src/pystencils/backend/ast/analysis.py
@@ -6,6 +6,7 @@ from .structural import (
     PsAstNode,
     PsBlock,
     PsComment,
+    PsConditional,
     PsDeclaration,
     PsExpression,
     PsLoop,
@@ -56,6 +57,12 @@ class UndefinedSymbolsCollector:
                 undefined_vars.discard(ctr.symbol)
                 return undefined_vars
 
+            case PsConditional(cond, branch_true, branch_false):
+                undefined_vars = self(cond) | self(branch_true)
+                if branch_false is not None:
+                    undefined_vars |= self(branch_false)
+                return undefined_vars
+
             case PsComment():
                 return set()
 
@@ -86,6 +93,7 @@ class UndefinedSymbolsCollector:
                 PsAssignment()
                 | PsBlock()
                 | PsComment()
+                | PsConditional()
                 | PsExpression()
                 | PsLoop()
                 | PsStatement()
diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py
index 538bb2779..7c040d304 100644
--- a/src/pystencils/backend/transformations/canonical_clone.py
+++ b/src/pystencils/backend/transformations/canonical_clone.py
@@ -12,6 +12,7 @@ from ..ast.structural import (
     PsDeclaration,
     PsAssignment,
     PsComment,
+    PsStatement,
 )
 from ..ast.expressions import PsExpression, PsSymbolExpr
 
@@ -99,6 +100,9 @@ class CanonicalClone:
                 self._replace_symbols(expr_clone, cc)
                 return cast(Node_T, expr_clone)
 
+            case PsStatement(expr):
+                return cast(Node_T, PsStatement(self.visit(expr, cc)))
+
             case _:
                 raise PsInternalCompilerError(
                     f"Don't know how to canonically clone {type(node)}"
-- 
GitLab