Skip to content
Snippets Groups Projects
Commit 6723d843 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make AssignmentCollection constructible/convertible from/to dict

parent d5373ec8
No related merge requests found
import sympy as sp import sympy as sp
from copy import copy from copy import copy
from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable from typing import List, Optional, Dict, Any, Set, Sequence, Iterator, Iterable, Union
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically from pystencils.sympyextensions import fast_subs, count_operations, sort_assignments_topologically
...@@ -27,9 +27,17 @@ class AssignmentCollection: ...@@ -27,9 +27,17 @@ class AssignmentCollection:
# ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: List[Assignment], subexpressions: List[Assignment], def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
simplification_hints: Optional[Dict[str, Any]] = None, subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: simplification_hints: Optional[Dict[str, Any]]=None,
subexpression_symbol_generator: Iterator[sp.Symbol]=None) -> None:
if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()]
if isinstance(subexpressions, Dict):
subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()]
self.main_assignments = main_assignments self.main_assignments = main_assignments
self.subexpressions = subexpressions self.subexpressions = subexpressions
...@@ -65,7 +73,8 @@ class AssignmentCollection: ...@@ -65,7 +73,8 @@ class AssignmentCollection:
eq = Assignment(lhs, rhs) eq = Assignment(lhs, rhs)
self.subexpressions.append(eq) self.subexpressions.append(eq)
if topological_sort: if topological_sort:
self.topological_sort(sort_subexpressions=True, sort_main_assignments=False) self.topological_sort(sort_subexpressions=True,
sort_main_assignments=False)
return lhs return lhs
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
...@@ -335,9 +344,26 @@ class AssignmentCollection: ...@@ -335,9 +344,26 @@ class AssignmentCollection:
def __iter__(self): def __iter__(self):
return self.main_assignments.__iter__() return self.main_assignments.__iter__()
@property
def main_assignments_dict(self):
return {a.lhs: a.rhs for a in self.main_assignments}
@property
def subexpressions_dict(self):
return {a.lhs: a.rhs for a in self.subexpressions}
def set_main_assignments_from_dict(self, main_assignments_dict):
self.main_assignments = [Assignment(k, v)
for k, v in main_assignments_dict.items()]
def set_sub_expressions_from_dict(self, sub_expressions_dict):
self.sub_expressions = [Assignment(k, v)
for k, v in sub_expressions_dict.items()]
class SymbolGen: class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ...""" """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"): def __init__(self, symbol="xi"):
self._ctr = 0 self._ctr = 0
self._symbol = symbol self._symbol = symbol
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment