From 6723d8435686e30ddea9c07c9a1a14ba6fc558eb Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 16 Nov 2018 10:29:02 +0100
Subject: [PATCH] Make AssignmentCollection constructible/convertible from/to
 dict

---
 simp/assignment_collection.py | 36 ++++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/simp/assignment_collection.py b/simp/assignment_collection.py
index b86960c40..1a19d0373 100644
--- a/simp/assignment_collection.py
+++ b/simp/assignment_collection.py
@@ -1,6 +1,6 @@
 import sympy as sp
 from copy import copy
-from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable
+from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable, Union
 from pystencils.assignment import Assignment
 from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically
 
@@ -27,9 +27,17 @@ class AssignmentCollection:
 
     # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
 
-    def __init__(self, main_assignments: List[Assignment], subexpressions: List[Assignment],
-                 simplification_hints: Optional[Dict[str, Any]] = None,
-                 subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
+    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
+                 subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
+                 simplification_hints: Optional[Dict[str, Any]]=None,
+                 subexpression_symbol_generator: Iterator[sp.Symbol]=None) -> None:
+        if isinstance(main_assignments, Dict):
+            main_assignments = [Assignment(k, v)
+                                for k, v in main_assignments.items()]
+        if isinstance(subexpressions, Dict):
+            subexpressions = [Assignment(k, v)
+                              for k, v in subexpressions.items()]
+
         self.main_assignments = main_assignments
         self.subexpressions = subexpressions
 
@@ -65,7 +73,8 @@ class AssignmentCollection:
         eq = Assignment(lhs, rhs)
         self.subexpressions.append(eq)
         if topological_sort:
-            self.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
+            self.topological_sort(sort_subexpressions=True,
+                                  sort_main_assignments=False)
         return lhs
 
     def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
@@ -335,9 +344,26 @@ class AssignmentCollection:
     def __iter__(self):
         return self.main_assignments.__iter__()
 
+    @property
+    def main_assignments_dict(self):
+        return {a.lhs: a.rhs for a in self.main_assignments}
+
+    @property
+    def subexpressions_dict(self):
+        return {a.lhs: a.rhs for a in self.subexpressions}
+
+    def set_main_assignments_from_dict(self, main_assignments_dict):
+        self.main_assignments = [Assignment(k, v)
+                                 for k, v in main_assignments_dict.items()]
+
+    def set_sub_expressions_from_dict(self, sub_expressions_dict):
+        self.sub_expressions = [Assignment(k, v)
+                                for k, v in sub_expressions_dict.items()]
+
 
 class SymbolGen:
     """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
+
     def __init__(self, symbol="xi"):
         self._ctr = 0
         self._symbol = symbol
-- 
GitLab