Skip to content
Snippets Groups Projects

Fix AssignmentCollection.{free_symbols,bound_symbols,defined_symbols} for non-Assignments

Compare and
3 files
+ 47
7
Preferences
Compare changes
Files
3
@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import sympy as sp
import pystencils
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
@@ -100,15 +101,29 @@ class AssignmentCollection:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols = set()
for eq in self.all_assignments:
free_symbols.update(eq.rhs.atoms(sp.Symbol))
if isinstance(eq, Assignment):
free_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
free_symbols.update(eq.undefined_symbols)
return free_symbols - self.bound_symbols
@property
def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set([eq.lhs for eq in self.all_assignments])
assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \
bound_symbols_set = set(
[assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
]
)
return bound_symbols_set
@property
@@ -124,7 +139,11 @@ class AssignmentCollection:
@property
def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
return set([assignment.lhs for assignment in self.main_assignments])
return (set(
[assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]
).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance(
assignment, pystencils.astnodes.Node)]
))
@property
def operation_count(self):