diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index cc58ebd301886db63bdf4e37d86285643768e9bd..a9c7c98f2a685112a249b03a7c2bd8be0afcd363 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -5,6 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp +import pystencils from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol @@ -353,7 +354,10 @@ class Block(Node): def symbols_defined(self): result = set() for a in self.args: - result.update(a.symbols_defined) + if isinstance(a, pystencils.Assignment): + result.update(a.free_symbols) + else: + result.update(a.symbols_defined) return result @property @@ -361,8 +365,12 @@ class Block(Node): result = set() defined_symbols = set() for a in self.args: - result.update(a.undefined_symbols) - defined_symbols.update(a.symbols_defined) + if isinstance(a, pystencils.Assignment): + result.update(a.free_symbols) + defined_symbols.update({a.lhs}) + else: + result.update(a.undefined_symbols) + defined_symbols.update(a.symbols_defined) return result - defined_symbols def __str__(self): diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index d874965c1d8600fb1b616afc338d4c62d2506f62..3a8bb2bd3b315e23fb9e11ef6385f264428c7bd5 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, import sympy as sp +import pystencils from pystencils.assignment import Assignment from pystencils.simp.simplifications import ( sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) @@ -100,15 +101,29 @@ class AssignmentCollection: """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" free_symbols = set() for eq in self.all_assignments: - free_symbols.update(eq.rhs.atoms(sp.Symbol)) + if isinstance(eq, Assignment): + free_symbols.update(eq.rhs.atoms(sp.Symbol)) + elif isinstance(eq, pystencils.astnodes.Node): + free_symbols.update(eq.undefined_symbols) + return free_symbols - self.bound_symbols @property def bound_symbols(self) -> Set[sp.Symbol]: """All symbols which occur on the left hand side of a main assignment or a subexpression.""" - bound_symbols_set = set([eq.lhs for eq in self.all_assignments]) - assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \ + bound_symbols_set = set( + [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)] + ) + + assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \ "Not in SSA form - same symbol assigned multiple times" + + bound_symbols_set = bound_symbols_set.union(*[ + assignment.symbols_defined for assignment in self.all_assignments + if isinstance(assignment, pystencils.astnodes.Node) + ] + ) + return bound_symbols_set @property @@ -124,7 +139,11 @@ class AssignmentCollection: @property def defined_symbols(self) -> Set[sp.Symbol]: """All symbols which occur as left-hand-sides of one of the main equations""" - return set([assignment.lhs for assignment in self.main_assignments]) + return (set( + [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)] + ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance( + assignment, pystencils.astnodes.Node)] + )) @property def operation_count(self): diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py index 1a8f6f9dc5b4f7091e302741ff73b9994f4ed3a2..42e3791ab331cc111343bfe410bc59e3f52ba80b 100644 --- a/pystencils_tests/test_assignment_collection.py +++ b/pystencils_tests/test_assignment_collection.py @@ -1,6 +1,7 @@ import sympy as sp from pystencils import Assignment, AssignmentCollection +from pystencils.astnodes import Conditional from pystencils.simp.assignment_collection import SymbolGen @@ -27,3 +28,15 @@ def test_assignment_collection(): assert 'a_0' in str(ac_inserted) assert '<table' in ac_inserted._repr_html_() + + +def test_free_and_defined_symbols(): + x, y, z, t = sp.symbols("x y z t") + a, b = sp.symbols("a b") + symbol_gen = SymbolGen("a") + + ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))], + [], subexpression_symbol_generator=symbol_gen) + + print(ac) + print(ac.__repr__)