diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py index 8e396cc38ced5759cb26e6d40e60a32f49cd1cad..b393df0ede2598638707974e19533a2bef735fcd 100644 --- a/pystencils/node_collection.py +++ b/pystencils/node_collection.py @@ -6,9 +6,10 @@ import sympy as sp from sympy.codegen import Assignment from sympy.codegen.rewriting import ReplaceOptim, optimize -from pystencils.astnodes import Block, Node +from pystencils.astnodes import Block, Node, SympyAssignment from pystencils.backends.cbackend import CustomCodeNode from pystencils.functions import DivFunc +from pystencils.simp import AssignmentCollection class NodeCollection: @@ -28,6 +29,10 @@ class NodeCollection: self.simplification_hints = {} + @staticmethod + def from_assignment_collection(assignment_collection: AssignmentCollection): + return NodeCollection([SympyAssignment(a.lhs, a.rhs) for a in assignment_collection.all_assignments]) + def evaluate_terms(self): evaluate_constant_terms = ReplaceOptim( lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, diff --git a/pystencils_tests/test_nodecollection.py b/pystencils_tests/test_nodecollection.py new file mode 100644 index 0000000000000000000000000000000000000000..ab24e58e7dc4a70985e6a9c631b9085b39a3f00f --- /dev/null +++ b/pystencils_tests/test_nodecollection.py @@ -0,0 +1,13 @@ +import sympy as sp + +from pystencils import AssignmentCollection, Assignment +from pystencils.node_collection import NodeCollection +from pystencils.astnodes import SympyAssignment + + +def test_node_collection_from_assignment_collection(): + x = sp.symbols('x') + assignment_collection = AssignmentCollection([Assignment(x, 2)]) + node_collection = NodeCollection.from_assignment_collection(assignment_collection) + + assert node_collection.all_assignments[0] == SympyAssignment(x, 2)