From ef543c5edc099a3aca59f8886ac42948d920f9b3 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 19 Dec 2019 09:44:29 +0100
Subject: [PATCH] Fix
 AssignmentCollection.{free_symbols,bound_symbols,defined_symbols} for
 non-Assignments

---
 pystencils/astnodes.py                        | 14 +++++++---
 pystencils/simp/assignment_collection.py      | 27 ++++++++++++++++---
 .../test_assignment_collection.py             | 13 +++++++++
 3 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index cc58ebd30..a9c7c98f2 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
 
 import sympy as sp
 
+import pystencils
 from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
 from pystencils.field import Field
 from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
@@ -353,7 +354,10 @@ class Block(Node):
     def symbols_defined(self):
         result = set()
         for a in self.args:
-            result.update(a.symbols_defined)
+            if isinstance(a, pystencils.Assignment):
+                result.update(a.free_symbols)
+            else:
+                result.update(a.symbols_defined)
         return result
 
     @property
@@ -361,8 +365,12 @@ class Block(Node):
         result = set()
         defined_symbols = set()
         for a in self.args:
-            result.update(a.undefined_symbols)
-            defined_symbols.update(a.symbols_defined)
+            if isinstance(a, pystencils.Assignment):
+                result.update(a.free_symbols)
+                defined_symbols.update({a.lhs})
+            else:
+                result.update(a.undefined_symbols)
+                defined_symbols.update(a.symbols_defined)
         return result - defined_symbols
 
     def __str__(self):
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index d874965c1..3a8bb2bd3 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
 
 import sympy as sp
 
+import pystencils
 from pystencils.assignment import Assignment
 from pystencils.simp.simplifications import (
     sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
@@ -100,15 +101,29 @@ class AssignmentCollection:
         """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
         free_symbols = set()
         for eq in self.all_assignments:
-            free_symbols.update(eq.rhs.atoms(sp.Symbol))
+            if isinstance(eq, Assignment):
+                free_symbols.update(eq.rhs.atoms(sp.Symbol))
+            elif isinstance(eq, pystencils.astnodes.Node):
+                free_symbols.update(eq.undefined_symbols)
+
         return free_symbols - self.bound_symbols
 
     @property
     def bound_symbols(self) -> Set[sp.Symbol]:
         """All symbols which occur on the left hand side of a main assignment or a subexpression."""
-        bound_symbols_set = set([eq.lhs for eq in self.all_assignments])
-        assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \
+        bound_symbols_set = set(
+            [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
+        )
+
+        assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
             "Not in SSA form - same symbol assigned multiple times"
+
+        bound_symbols_set = bound_symbols_set.union(*[
+            assignment.symbols_defined for assignment in self.all_assignments
+            if isinstance(assignment, pystencils.astnodes.Node)
+        ]
+        )
+
         return bound_symbols_set
 
     @property
@@ -124,7 +139,11 @@ class AssignmentCollection:
     @property
     def defined_symbols(self) -> Set[sp.Symbol]:
         """All symbols which occur as left-hand-sides of one of the main equations"""
-        return set([assignment.lhs for assignment in self.main_assignments])
+        return (set(
+            [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]
+        ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance(
+                assignment, pystencils.astnodes.Node)]
+                ))
 
     @property
     def operation_count(self):
diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py
index 1a8f6f9dc..42e3791ab 100644
--- a/pystencils_tests/test_assignment_collection.py
+++ b/pystencils_tests/test_assignment_collection.py
@@ -1,6 +1,7 @@
 import sympy as sp
 
 from pystencils import Assignment, AssignmentCollection
+from pystencils.astnodes import Conditional
 from pystencils.simp.assignment_collection import SymbolGen
 
 
@@ -27,3 +28,15 @@ def test_assignment_collection():
 
     assert 'a_0' in str(ac_inserted)
     assert '<table' in ac_inserted._repr_html_()
+
+
+def test_free_and_defined_symbols():
+    x, y, z, t = sp.symbols("x y z t")
+    a, b = sp.symbols("a b")
+    symbol_gen = SymbolGen("a")
+
+    ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))],
+                              [], subexpression_symbol_generator=symbol_gen)
+
+    print(ac)
+    print(ac.__repr__)
-- 
GitLab