From 1edb192f6a791bddbe343c3268c8dd7b2919549f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Wed, 4 May 2022 10:47:01 +0200 Subject: [PATCH] Introduced a converting function form AssignmentCollection to NodeCollection. --- pystencils/node_collection.py | 7 ++++++- pystencils_tests/test_nodecollection.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 pystencils_tests/test_nodecollection.py diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py index 8e396cc38..b393df0ed 100644 --- a/pystencils/node_collection.py +++ b/pystencils/node_collection.py @@ -6,9 +6,10 @@ import sympy as sp from sympy.codegen import Assignment from sympy.codegen.rewriting import ReplaceOptim, optimize -from pystencils.astnodes import Block, Node +from pystencils.astnodes import Block, Node, SympyAssignment from pystencils.backends.cbackend import CustomCodeNode from pystencils.functions import DivFunc +from pystencils.simp import AssignmentCollection class NodeCollection: @@ -28,6 +29,10 @@ class NodeCollection: self.simplification_hints = {} + @staticmethod + def from_assignment_collection(assignment_collection: AssignmentCollection): + return NodeCollection([SympyAssignment(a.lhs, a.rhs) for a in assignment_collection.all_assignments]) + def evaluate_terms(self): evaluate_constant_terms = ReplaceOptim( lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, diff --git a/pystencils_tests/test_nodecollection.py b/pystencils_tests/test_nodecollection.py new file mode 100644 index 000000000..ab24e58e7 --- /dev/null +++ b/pystencils_tests/test_nodecollection.py @@ -0,0 +1,13 @@ +import sympy as sp + +from pystencils import AssignmentCollection, Assignment +from pystencils.node_collection import NodeCollection +from pystencils.astnodes import SympyAssignment + + +def test_node_collection_from_assignment_collection(): + x = sp.symbols('x') + assignment_collection = AssignmentCollection([Assignment(x, 2)]) + node_collection = NodeCollection.from_assignment_collection(assignment_collection) + + assert node_collection.all_assignments[0] == SympyAssignment(x, 2) -- GitLab