From 6723d8435686e30ddea9c07c9a1a14ba6fc558eb Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 16 Nov 2018 10:29:02 +0100 Subject: [PATCH] Make AssignmentCollection constructible/convertible from/to dict --- simp/assignment_collection.py | 36 ++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/simp/assignment_collection.py b/simp/assignment_collection.py index b86960c40..1a19d0373 100644 --- a/simp/assignment_collection.py +++ b/simp/assignment_collection.py @@ -1,6 +1,6 @@ import sympy as sp from copy import copy -from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable +from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable, Union from pystencils.assignment import Assignment from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically @@ -27,9 +27,17 @@ class AssignmentCollection: # ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- - def __init__(self, main_assignments: List[Assignment], subexpressions: List[Assignment], - simplification_hints: Optional[Dict[str, Any]] = None, - subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: + def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], + subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], + simplification_hints: Optional[Dict[str, Any]]=None, + subexpression_symbol_generator: Iterator[sp.Symbol]=None) -> None: + if isinstance(main_assignments, Dict): + main_assignments = [Assignment(k, v) + for k, v in main_assignments.items()] + if isinstance(subexpressions, Dict): + subexpressions = [Assignment(k, v) + for k, v in subexpressions.items()] + self.main_assignments = main_assignments self.subexpressions = subexpressions @@ -65,7 +73,8 @@ class AssignmentCollection: eq = Assignment(lhs, rhs) self.subexpressions.append(eq) if topological_sort: - self.topological_sort(sort_subexpressions=True, sort_main_assignments=False) + self.topological_sort(sort_subexpressions=True, + sort_main_assignments=False) return lhs def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: @@ -335,9 +344,26 @@ class AssignmentCollection: def __iter__(self): return self.main_assignments.__iter__() + @property + def main_assignments_dict(self): + return {a.lhs: a.rhs for a in self.main_assignments} + + @property + def subexpressions_dict(self): + return {a.lhs: a.rhs for a in self.subexpressions} + + def set_main_assignments_from_dict(self, main_assignments_dict): + self.main_assignments = [Assignment(k, v) + for k, v in main_assignments_dict.items()] + + def set_sub_expressions_from_dict(self, sub_expressions_dict): + self.sub_expressions = [Assignment(k, v) + for k, v in sub_expressions_dict.items()] + class SymbolGen: """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" + def __init__(self, symbol="xi"): self._ctr = 0 self._symbol = symbol -- GitLab