Skip to content
Snippets Groups Projects
Commit ef543c5e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix AssignmentCollection.{free_symbols,bound_symbols,defined_symbols} for non-Assignments

parent f9e88655
Branches
Tags
No related merge requests found
...@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union ...@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
import pystencils
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
...@@ -353,7 +354,10 @@ class Block(Node): ...@@ -353,7 +354,10 @@ class Block(Node):
def symbols_defined(self): def symbols_defined(self):
result = set() result = set()
for a in self.args: for a in self.args:
result.update(a.symbols_defined) if isinstance(a, pystencils.Assignment):
result.update(a.free_symbols)
else:
result.update(a.symbols_defined)
return result return result
@property @property
...@@ -361,8 +365,12 @@ class Block(Node): ...@@ -361,8 +365,12 @@ class Block(Node):
result = set() result = set()
defined_symbols = set() defined_symbols = set()
for a in self.args: for a in self.args:
result.update(a.undefined_symbols) if isinstance(a, pystencils.Assignment):
defined_symbols.update(a.symbols_defined) result.update(a.free_symbols)
defined_symbols.update({a.lhs})
else:
result.update(a.undefined_symbols)
defined_symbols.update(a.symbols_defined)
return result - defined_symbols return result - defined_symbols
def __str__(self): def __str__(self):
......
...@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, ...@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import sympy as sp import sympy as sp
import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.simp.simplifications import ( from pystencils.simp.simplifications import (
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
...@@ -100,15 +101,29 @@ class AssignmentCollection: ...@@ -100,15 +101,29 @@ class AssignmentCollection:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols = set() free_symbols = set()
for eq in self.all_assignments: for eq in self.all_assignments:
free_symbols.update(eq.rhs.atoms(sp.Symbol)) if isinstance(eq, Assignment):
free_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
free_symbols.update(eq.undefined_symbols)
return free_symbols - self.bound_symbols return free_symbols - self.bound_symbols
@property @property
def bound_symbols(self) -> Set[sp.Symbol]: def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression.""" """All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set([eq.lhs for eq in self.all_assignments]) bound_symbols_set = set(
assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \ [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times" "Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
]
)
return bound_symbols_set return bound_symbols_set
@property @property
...@@ -124,7 +139,11 @@ class AssignmentCollection: ...@@ -124,7 +139,11 @@ class AssignmentCollection:
@property @property
def defined_symbols(self) -> Set[sp.Symbol]: def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations""" """All symbols which occur as left-hand-sides of one of the main equations"""
return set([assignment.lhs for assignment in self.main_assignments]) return (set(
[assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]
).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance(
assignment, pystencils.astnodes.Node)]
))
@property @property
def operation_count(self): def operation_count(self):
......
import sympy as sp import sympy as sp
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
from pystencils.astnodes import Conditional
from pystencils.simp.assignment_collection import SymbolGen from pystencils.simp.assignment_collection import SymbolGen
...@@ -27,3 +28,15 @@ def test_assignment_collection(): ...@@ -27,3 +28,15 @@ def test_assignment_collection():
assert 'a_0' in str(ac_inserted) assert 'a_0' in str(ac_inserted)
assert '<table' in ac_inserted._repr_html_() assert '<table' in ac_inserted._repr_html_()
def test_free_and_defined_symbols():
x, y, z, t = sp.symbols("x y z t")
a, b = sp.symbols("a b")
symbol_gen = SymbolGen("a")
ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))],
[], subexpression_symbol_generator=symbol_gen)
print(ac)
print(ac.__repr__)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment