From ef543c5edc099a3aca59f8886ac42948d920f9b3 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 19 Dec 2019 09:44:29 +0100 Subject: [PATCH] Fix AssignmentCollection.{free_symbols,bound_symbols,defined_symbols} for non-Assignments --- pystencils/astnodes.py | 14 +++++++--- pystencils/simp/assignment_collection.py | 27 ++++++++++++++++--- .../test_assignment_collection.py | 13 +++++++++ 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index cc58ebd30..a9c7c98f2 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -5,6 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp +import pystencils from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol @@ -353,7 +354,10 @@ class Block(Node): def symbols_defined(self): result = set() for a in self.args: - result.update(a.symbols_defined) + if isinstance(a, pystencils.Assignment): + result.update(a.free_symbols) + else: + result.update(a.symbols_defined) return result @property @@ -361,8 +365,12 @@ class Block(Node): result = set() defined_symbols = set() for a in self.args: - result.update(a.undefined_symbols) - defined_symbols.update(a.symbols_defined) + if isinstance(a, pystencils.Assignment): + result.update(a.free_symbols) + defined_symbols.update({a.lhs}) + else: + result.update(a.undefined_symbols) + defined_symbols.update(a.symbols_defined) return result - defined_symbols def __str__(self): diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index d874965c1..3a8bb2bd3 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, import sympy as sp +import pystencils from pystencils.assignment import Assignment from pystencils.simp.simplifications import ( sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) @@ -100,15 +101,29 @@ class AssignmentCollection: """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" free_symbols = set() for eq in self.all_assignments: - free_symbols.update(eq.rhs.atoms(sp.Symbol)) + if isinstance(eq, Assignment): + free_symbols.update(eq.rhs.atoms(sp.Symbol)) + elif isinstance(eq, pystencils.astnodes.Node): + free_symbols.update(eq.undefined_symbols) + return free_symbols - self.bound_symbols @property def bound_symbols(self) -> Set[sp.Symbol]: """All symbols which occur on the left hand side of a main assignment or a subexpression.""" - bound_symbols_set = set([eq.lhs for eq in self.all_assignments]) - assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \ + bound_symbols_set = set( + [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)] + ) + + assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \ "Not in SSA form - same symbol assigned multiple times" + + bound_symbols_set = bound_symbols_set.union(*[ + assignment.symbols_defined for assignment in self.all_assignments + if isinstance(assignment, pystencils.astnodes.Node) + ] + ) + return bound_symbols_set @property @@ -124,7 +139,11 @@ class AssignmentCollection: @property def defined_symbols(self) -> Set[sp.Symbol]: """All symbols which occur as left-hand-sides of one of the main equations""" - return set([assignment.lhs for assignment in self.main_assignments]) + return (set( + [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)] + ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance( + assignment, pystencils.astnodes.Node)] + )) @property def operation_count(self): diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py index 1a8f6f9dc..42e3791ab 100644 --- a/pystencils_tests/test_assignment_collection.py +++ b/pystencils_tests/test_assignment_collection.py @@ -1,6 +1,7 @@ import sympy as sp from pystencils import Assignment, AssignmentCollection +from pystencils.astnodes import Conditional from pystencils.simp.assignment_collection import SymbolGen @@ -27,3 +28,15 @@ def test_assignment_collection(): assert 'a_0' in str(ac_inserted) assert '<table' in ac_inserted._repr_html_() + + +def test_free_and_defined_symbols(): + x, y, z, t = sp.symbols("x y z t") + a, b = sp.symbols("a b") + symbol_gen = SymbolGen("a") + + ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))], + [], subexpression_symbol_generator=symbol_gen) + + print(ac) + print(ac.__repr__) -- GitLab