From 1edb192f6a791bddbe343c3268c8dd7b2919549f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de>
Date: Wed, 4 May 2022 10:47:01 +0200
Subject: [PATCH] Introduced a converting function form AssignmentCollection to
 NodeCollection.

---
 pystencils/node_collection.py           |  7 ++++++-
 pystencils_tests/test_nodecollection.py | 13 +++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)
 create mode 100644 pystencils_tests/test_nodecollection.py

diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py
index 8e396cc38..b393df0ed 100644
--- a/pystencils/node_collection.py
+++ b/pystencils/node_collection.py
@@ -6,9 +6,10 @@ import sympy as sp
 from sympy.codegen import Assignment
 from sympy.codegen.rewriting import ReplaceOptim, optimize
 
-from pystencils.astnodes import Block, Node
+from pystencils.astnodes import Block, Node, SympyAssignment
 from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.functions import DivFunc
+from pystencils.simp import AssignmentCollection
 
 
 class NodeCollection:
@@ -28,6 +29,10 @@ class NodeCollection:
 
         self.simplification_hints = {}
 
+    @staticmethod
+    def from_assignment_collection(assignment_collection: AssignmentCollection):
+        return NodeCollection([SympyAssignment(a.lhs, a.rhs) for a in assignment_collection.all_assignments])
+
     def evaluate_terms(self):
         evaluate_constant_terms = ReplaceOptim(
             lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
diff --git a/pystencils_tests/test_nodecollection.py b/pystencils_tests/test_nodecollection.py
new file mode 100644
index 000000000..ab24e58e7
--- /dev/null
+++ b/pystencils_tests/test_nodecollection.py
@@ -0,0 +1,13 @@
+import sympy as sp
+
+from pystencils import AssignmentCollection, Assignment
+from pystencils.node_collection import NodeCollection
+from pystencils.astnodes import SympyAssignment
+
+
+def test_node_collection_from_assignment_collection():
+    x = sp.symbols('x')
+    assignment_collection = AssignmentCollection([Assignment(x, 2)])
+    node_collection = NodeCollection.from_assignment_collection(assignment_collection)
+
+    assert node_collection.all_assignments[0] == SympyAssignment(x, 2)
-- 
GitLab