Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
ef543c5e
Commit
ef543c5e
authored
Dec 19, 2019
by
Stephan Seitz
Browse files
Fix AssignmentCollection.{free_symbols,bound_symbols,defined_symbols} for non-Assignments
parent
f9e88655
Pipeline
#20661
passed with stage
in 13 minutes and 12 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pystencils/astnodes.py
View file @
ef543c5e
...
...
@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
import
sympy
as
sp
import
pystencils
from
pystencils.data_types
import
TypedImaginaryUnit
,
TypedSymbol
,
cast_func
,
create_type
from
pystencils.field
import
Field
from
pystencils.kernelparameters
import
FieldPointerSymbol
,
FieldShapeSymbol
,
FieldStrideSymbol
...
...
@@ -353,7 +354,10 @@ class Block(Node):
def
symbols_defined
(
self
):
result
=
set
()
for
a
in
self
.
args
:
result
.
update
(
a
.
symbols_defined
)
if
isinstance
(
a
,
pystencils
.
Assignment
):
result
.
update
(
a
.
free_symbols
)
else
:
result
.
update
(
a
.
symbols_defined
)
return
result
@
property
...
...
@@ -361,8 +365,12 @@ class Block(Node):
result
=
set
()
defined_symbols
=
set
()
for
a
in
self
.
args
:
result
.
update
(
a
.
undefined_symbols
)
defined_symbols
.
update
(
a
.
symbols_defined
)
if
isinstance
(
a
,
pystencils
.
Assignment
):
result
.
update
(
a
.
free_symbols
)
defined_symbols
.
update
({
a
.
lhs
})
else
:
result
.
update
(
a
.
undefined_symbols
)
defined_symbols
.
update
(
a
.
symbols_defined
)
return
result
-
defined_symbols
def
__str__
(
self
):
...
...
pystencils/simp/assignment_collection.py
View file @
ef543c5e
...
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import
sympy
as
sp
import
pystencils
from
pystencils.assignment
import
Assignment
from
pystencils.simp.simplifications
import
(
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
...
...
@@ -100,15 +101,29 @@ class AssignmentCollection:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols
=
set
()
for
eq
in
self
.
all_assignments
:
free_symbols
.
update
(
eq
.
rhs
.
atoms
(
sp
.
Symbol
))
if
isinstance
(
eq
,
Assignment
):
free_symbols
.
update
(
eq
.
rhs
.
atoms
(
sp
.
Symbol
))
elif
isinstance
(
eq
,
pystencils
.
astnodes
.
Node
):
free_symbols
.
update
(
eq
.
undefined_symbols
)
return
free_symbols
-
self
.
bound_symbols
@
property
def
bound_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set
=
set
([
eq
.
lhs
for
eq
in
self
.
all_assignments
])
assert
len
(
bound_symbols_set
)
==
len
(
self
.
subexpressions
)
+
len
(
self
.
main_assignments
),
\
bound_symbols_set
=
set
(
[
assignment
.
lhs
for
assignment
in
self
.
all_assignments
if
isinstance
(
assignment
,
Assignment
)]
)
assert
len
(
bound_symbols_set
)
==
len
(
list
(
a
for
a
in
self
.
all_assignments
if
isinstance
(
a
,
Assignment
))),
\
"Not in SSA form - same symbol assigned multiple times"
bound_symbols_set
=
bound_symbols_set
.
union
(
*
[
assignment
.
symbols_defined
for
assignment
in
self
.
all_assignments
if
isinstance
(
assignment
,
pystencils
.
astnodes
.
Node
)
]
)
return
bound_symbols_set
@
property
...
...
@@ -124,7 +139,11 @@ class AssignmentCollection:
@
property
def
defined_symbols
(
self
)
->
Set
[
sp
.
Symbol
]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
return
set
([
assignment
.
lhs
for
assignment
in
self
.
main_assignments
])
return
(
set
(
[
assignment
.
lhs
for
assignment
in
self
.
main_assignments
if
isinstance
(
assignment
,
Assignment
)]
).
union
(
*
[
assignment
.
symbols_defined
for
assignment
in
self
.
main_assignments
if
isinstance
(
assignment
,
pystencils
.
astnodes
.
Node
)]
))
@
property
def
operation_count
(
self
):
...
...
pystencils_tests/test_assignment_collection.py
View file @
ef543c5e
import
sympy
as
sp
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils.astnodes
import
Conditional
from
pystencils.simp.assignment_collection
import
SymbolGen
...
...
@@ -27,3 +28,15 @@ def test_assignment_collection():
assert
'a_0'
in
str
(
ac_inserted
)
assert
'<table'
in
ac_inserted
.
_repr_html_
()
def
test_free_and_defined_symbols
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
a
,
b
=
sp
.
symbols
(
"a b"
)
symbol_gen
=
SymbolGen
(
"a"
)
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
),
Conditional
(
t
>
0
,
Assignment
(
a
,
b
+
1
),
Assignment
(
a
,
b
+
2
))],
[],
subexpression_symbol_generator
=
symbol_gen
)
print
(
ac
)
print
(
ac
.
__repr__
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment