Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Julian Hammer
pystencils
Commits
0d6780b8
Commit
0d6780b8
authored
Aug 10, 2020
by
Jan Hönig
Browse files
Merge branch 'Extend_testsuite' into 'master'
Extend testsuite See merge request
!168
parents
4f41d979
20118400
Changes
32
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
0d6780b8
...
...
@@ -298,7 +298,7 @@ class CBackend:
return
node
.
get_code
(
self
.
_dialect
,
self
.
_vector_instruction_set
)
def
_print_SourceCodeComment
(
self
,
node
):
return
"/*
"
+
node
.
text
+
"
*/"
return
f
"/*
{
node
.
text
}
*/"
def
_print_EmptyLine
(
self
,
node
):
return
""
...
...
@@ -316,7 +316,7 @@ class CBackend:
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
if
node
.
false_block
:
false_block
=
self
.
_print_Block
(
node
.
false_block
)
result
+=
"else
"
+
false_block
result
+=
f
"else
{
false_block
}
"
return
result
...
...
@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
return
self
.
_typed_number
(
expr
.
evalf
(),
get_type_of_expression
(
expr
))
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"(
"
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
"
)"
return
f
"(
{
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
}
)"
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
([
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
else
:
...
...
@@ -589,9 +589,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
instruction_set
[
'&'
].
format
(
result
,
item
)
return
result
def
_print_Max
(
self
,
expr
):
return
"test"
def
_print_Or
(
self
,
expr
):
result
=
self
.
_scalarFallback
(
'_print_Or'
,
expr
)
if
result
:
...
...
pystencils/backends/cuda_backend.py
View file @
0d6780b8
...
...
@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
assert
len
(
expr
.
args
)
==
1
,
f
"__fsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
return
f
"__fsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
print
(
len
(
expr
.
args
)
==
1
)
assert
len
(
expr
.
args
)
==
1
,
f
"__frsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
return
f
"__frsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
return
super
().
_print_Function
(
expr
)
pystencils/datahandling/datahandling_interface.py
View file @
0d6780b8
...
...
@@ -86,6 +86,13 @@ class DataHandling(ABC):
Args:
description (str): String description of the fields to add
dtype: data type of the array as numpy data type
ghost_layers: number of ghost layers - if not specified a default value specified in the constructor
is used
layout: memory layout of array, either structure of arrays 'SoA' or array of structures 'AoS'.
this is only important if values_per_cell > 1
cpu: allocate field on the CPU
gpu: allocate field on the GPU, if None, a GPU field is allocated if default_target is 'gpu'
alignment: either False for no alignment, or the number of bytes to align to
Returns:
Fields representing the just created arrays
"""
...
...
@@ -200,6 +207,10 @@ class DataHandling(ABC):
directly passed to the kernel function and override possible parameters from the DataHandling
"""
@
abstractmethod
def
get_kernel_kwargs
(
self
,
kernel_function
,
**
kwargs
):
"""Returns the input arguments of a kernel"""
@
abstractmethod
def
swap
(
self
,
name1
,
name2
,
gpu
=
False
):
"""Swaps data of two arrays"""
...
...
pystencils/datahandling/serial_datahandling.py
View file @
0d6780b8
...
...
@@ -266,10 +266,10 @@ class SerialDataHandling(DataHandling):
return
name
in
self
.
gpu_arrays
def
synchronization_function_cpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
return
self
.
synchronization_function
(
names
,
stencil_name
,
'cpu'
)
return
self
.
synchronization_function
(
names
,
stencil_name
,
target
=
'cpu'
)
def
synchronization_function_gpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
return
self
.
synchronization_function
(
names
,
stencil_name
,
'gpu'
)
return
self
.
synchronization_function
(
names
,
stencil_name
,
target
=
'gpu'
)
def
synchronization_function
(
self
,
names
,
stencil
=
None
,
target
=
None
,
**
_
):
if
target
is
None
:
...
...
@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
np
.
savez_compressed
(
file
,
**
self
.
cpu_arrays
)
def
load_all
(
self
,
file
):
if
'.npz'
not
in
file
:
file
+=
'.npz'
file_contents
=
np
.
load
(
file
)
for
arr_name
,
arr_contents
in
self
.
cpu_arrays
.
items
():
if
arr_name
not
in
file_contents
:
print
(
f
"Skipping read data
{
arr_name
}
because there is no data with this name in data handling"
)
continue
if
file_contents
[
arr_name
].
shape
!=
arr_contents
.
shape
:
print
(
"Skipping read data {} because shapes don't match. "
"Read array shape {}, existing array shape {}"
.
format
(
arr_name
,
file_contents
[
arr_name
].
shape
,
arr_contents
.
shape
))
print
(
f
"Skipping read data
{
arr_name
}
because shapes don't match. "
f
"Read array shape
{
file_contents
[
arr_name
].
shape
}
, existing array shape
{
arr_contents
.
shape
}
"
)
continue
np
.
copyto
(
arr_contents
,
file_contents
[
arr_name
])
pystencils/fd/derivative.py
View file @
0d6780b8
...
...
@@ -228,7 +228,9 @@ def diff_terms(expr):
Example:
>>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) )
>>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
>>> diff_terms( diff(x, 0, 0) + y )
{Diff(Diff(x, 0, -1), 0, -1)}
"""
result
=
set
()
...
...
pystencils/simp/__init__.py
View file @
0d6780b8
from
.assignment_collection
import
AssignmentCollection
from
.simplifications
import
(
add_subexpressions_for_divisions
,
add_subexpressions_for_field_reads
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
add_subexpressions_for_sums
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
subexpression_substitution_in_existing_subexpressions
,
subexpression_substitution_in_main_assignments
,
sympy_cse
,
sympy_cse_on_assignment_list
)
from
.simplificationstrategy
import
SimplificationStrategy
...
...
@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse'
,
'sympy_cse_on_assignment_list'
,
'apply_to_all_assignments'
,
'apply_on_all_subexpressions'
,
'subexpression_substitution_in_existing_subexpressions'
,
'subexpression_substitution_in_main_assignments'
,
'add_subexpressions_for_divisions'
,
'add_subexpressions_for_field_reads'
]
'add_subexpressions_for_sums'
,
'add_subexpressions_for_field_reads'
]
pystencils/simp/assignment_collection.py
View file @
0d6780b8
...
...
@@ -6,8 +6,7 @@ import sympy as sp
import
pystencils
from
pystencils.assignment
import
Assignment
from
pystencils.simp.simplifications
import
(
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
from
pystencils.simp.simplifications
import
(
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
from
pystencils.sympyextensions
import
count_operations
,
fast_subs
...
...
@@ -263,7 +262,7 @@ class AssignmentCollection:
own_definitions
=
set
([
e
.
lhs
for
e
in
self
.
main_assignments
])
other_definitions
=
set
([
e
.
lhs
for
e
in
other
.
main_assignments
])
assert
len
(
own_definitions
.
intersection
(
other_definitions
))
==
0
,
\
"Cannot
new_
merge
d
, since both
collection
define the same symbols"
"Cannot merge
collections
, since both define the same symbols"
own_subexpression_symbols
=
{
e
.
lhs
:
e
.
rhs
for
e
in
self
.
subexpressions
}
substitution_dict
=
{}
...
...
@@ -334,7 +333,7 @@ class AssignmentCollection:
kept_subexpressions
=
[]
if
self
.
subexpressions
[
0
].
lhs
in
subexpressions_to_keep
:
substitution_dict
=
{}
kept_subexpressions
=
self
.
subexpressions
[
0
]
kept_subexpressions
.
append
(
self
.
subexpressions
[
0
]
)
else
:
substitution_dict
=
{
self
.
subexpressions
[
0
].
lhs
:
self
.
subexpressions
[
0
].
rhs
}
...
...
pystencils/simp/simplifications.py
View file @
0d6780b8
...
...
@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif
isinstance
(
e1
,
Node
):
symbols
=
e1
.
symbols_defined
else
:
raise
NotImplementedError
(
"Cannot sort topologically. Object of type
"
+
type
(
e1
)
+
"
cannot be handled."
)
raise
NotImplementedError
(
f
"Cannot sort topologically. Object of type
{
type
(
e1
)
}
cannot be handled."
)
for
lhs
in
symbols
:
for
c2
,
e2
in
enumerate
(
assignments
):
...
...
@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
addends
=
[]
def
contains_sum
(
term
):
if
term
.
func
==
sp
.
add
.
Add
:
if
term
.
func
==
sp
.
Add
:
return
True
if
term
.
is_Atom
:
return
False
return
any
([
contains_sum
(
a
)
for
a
in
term
.
args
])
def
search_addends
(
term
):
if
term
.
func
==
sp
.
add
.
Add
:
if
term
.
func
==
sp
.
Add
:
if
all
([
not
contains_sum
(
a
)
for
a
in
term
.
args
]):
addends
.
extend
(
term
.
args
)
for
a
in
term
.
args
:
...
...
pystencils/stencil.py
View file @
0d6780b8
...
...
@@ -34,6 +34,8 @@ def is_valid(stencil, max_neighborhood=None):
True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
"""
expected_dim
=
len
(
stencil
[
0
])
for
d
in
stencil
:
...
...
@@ -67,8 +69,11 @@ def have_same_entries(s1, s2):
Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2)
True
>>> have_same_entries(stencil1, stencil3)
False
"""
if
len
(
s1
)
!=
len
(
s2
):
return
False
...
...
pystencils/sympyextensions.py
View file @
0d6780b8
...
...
@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def
replace_second_order_products
(
expr
:
sp
.
Expr
,
search_symbols
:
Iterable
[
sp
.
Symbol
],
positive
:
Optional
[
bool
]
=
None
,
replace_mixed
:
Optional
[
List
[
Assignment
]]
=
None
)
->
sp
.
Expr
:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ).
"""Replaces second order mixed terms like
4*
x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions
...
...
@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if
expr
.
is_Mul
:
distinct_search_symbols
=
set
()
nr_of_search_terms
=
0
other_factors
=
1
other_factors
=
sp
.
Integer
(
1
)
for
t
in
expr
.
args
:
if
t
in
search_symbols
:
nr_of_search_terms
+=
1
...
...
@@ -509,13 +509,14 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
if
t
.
exp
>=
0
:
result
[
'muls'
]
+=
int
(
t
.
exp
)
-
1
else
:
result
[
'muls'
]
-=
1
if
result
[
'muls'
]
>
0
:
result
[
'muls'
]
-=
1
result
[
'divs'
]
+=
1
result
[
'muls'
]
+=
(
-
int
(
t
.
exp
))
-
1
elif
sp
.
nsimplify
(
t
.
exp
)
==
sp
.
Rational
(
1
,
2
):
result
[
'sqrts'
]
+=
1
else
:
warnings
.
warn
(
"Cannot handle exponent
"
,
t
.
exp
,
"
of sp.Pow node"
)
warnings
.
warn
(
f
"Cannot handle exponent
{
t
.
exp
}
of sp.Pow node"
)
else
:
warnings
.
warn
(
"Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate"
)
...
...
@@ -526,7 +527,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif
isinstance
(
t
,
sp
.
Rel
):
pass
else
:
warnings
.
warn
(
"Unknown sympy node of type
"
+
str
(
t
.
func
)
+
"
counting will be inaccurate"
)
warnings
.
warn
(
f
"Unknown sympy node of type
{
str
(
t
.
func
)
}
counting will be inaccurate"
)
if
visit_children
:
for
a
in
t
.
args
:
...
...
pystencils/transformations.py
View file @
0d6780b8
...
...
@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node):
return
reversed
(
result
)
def
get_loop_counter_symbol_hierarchy
(
ast
N
ode
):
def
get_loop_counter_symbol_hierarchy
(
ast
_n
ode
):
"""Determines the loop counter symbols around a given AST node.
:param ast
N
ode: the AST node
:param ast
_n
ode: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
"""
result
=
[]
node
=
ast
N
ode
node
=
ast
_n
ode
while
node
is
not
None
:
node
=
get_next_parent_of_type
(
node
,
ast
.
LoopOverCoordinate
)
if
node
:
...
...
pystencils/utils.py
View file @
0d6780b8
import
os
import
itertools
from
collections
import
Counter
from
contextlib
import
contextmanager
from
tempfile
import
NamedTemporaryFile
...
...
@@ -96,16 +97,21 @@ def fully_contains(l1, l2):
def
boolean_array_bounding_box
(
boolean_array
):
"""Returns bounding box around "true" area of boolean array"""
dim
=
len
(
boolean_array
.
shape
)
"""Returns bounding box around "true" area of boolean array
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a)
[(1, 3), (1, 3)]
"""
dim
=
boolean_array
.
ndim
shape
=
boolean_array
.
shape
assert
0
not
in
shape
,
"Shape must not contain zero"
bounds
=
[]
for
i
in
range
(
dim
):
for
j
in
range
(
dim
):
if
i
!=
j
:
arr_1d
=
np
.
any
(
boolean_array
,
axis
=
j
)
begin
=
np
.
argmax
(
arr_1d
)
end
=
begin
+
np
.
argmin
(
arr_1d
[
begin
:])
bounds
.
append
((
begin
,
end
))
for
ax
in
itertools
.
combinations
(
reversed
(
range
(
dim
)),
dim
-
1
):
nonzero
=
np
.
any
(
boolean_array
,
axis
=
ax
)
t
=
np
.
where
(
nonzero
)[
0
][[
0
,
-
1
]]
bounds
.
append
((
t
[
0
],
t
[
1
]
+
1
))
return
bounds
...
...
@@ -217,7 +223,8 @@ class LinearEquationSystem:
return
'multiple'
def
solution
(
self
):
"""Solves the system if it has a single solution. Returns a dictionary mapping symbol to solution value."""
"""Solves the system. Under- and overdetermined systems are supported.
Returns a dictionary mapping symbol to solution value."""
return
sp
.
solve_linear_system
(
self
.
_matrix
,
*
self
.
unknowns
)
def
_resize_if_necessary
(
self
,
new_rows
=
1
):
...
...
@@ -233,8 +240,3 @@ class LinearEquationSystem:
break
result
-=
1
self
.
next_zero_row
=
result
def
find_unique_solutions_with_zeros
(
system
:
LinearEquationSystem
):
if
not
system
.
solution_structure
()
!=
'multiple'
:
raise
ValueError
(
"Function works only for underdetermined systems"
)
pystencils_tests/test_assignment_collection.py
View file @
0d6780b8
import
pytest
import
sympy
as
sp
import
pystencils
as
ps
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils.astnodes
import
Conditional
from
pystencils.simp.assignment_collection
import
SymbolGen
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
symbol_gen
=
SymbolGen
(
"a"
)
f
=
ps
.
fields
(
"f(2) : [2D]"
)
d
=
ps
.
fields
(
"d(2) : [2D]"
)
def
test_assignment_collection
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
symbol_gen
=
SymbolGen
(
"a"
)
def
test_assignment_collection
():
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
)],
[],
subexpression_symbol_generator
=
symbol_gen
)
...
...
@@ -32,10 +36,6 @@ def test_assignment_collection():
def
test_free_and_defined_symbols
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
a
,
b
=
sp
.
symbols
(
"a b"
)
symbol_gen
=
SymbolGen
(
"a"
)
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
),
Conditional
(
t
>
0
,
Assignment
(
a
,
b
+
1
),
Assignment
(
a
,
b
+
2
))],
[],
subexpression_symbol_generator
=
symbol_gen
)
...
...
@@ -45,35 +45,128 @@ def test_free_and_defined_symbols():
def
test_vector_assignments
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
assignments
=
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
]))
assignments
=
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
]))
print
(
assignments
)
def
test_wrong_vector_assignments
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
=
sp
.
symbols
(
"a b"
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
):
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
]),
sp
.
Matrix
([
1
,
2
,
3
]))
match
=
r
'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
):
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
]),
sp
.
Matrix
([
1
,
2
,
3
]))
def
test_vector_assignment_collection
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
y
,
x
=
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
])
assignments
=
ps
.
AssignmentCollection
({
y
:
x
})
y_m
,
x_m
=
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
])
assignments
=
ps
.
AssignmentCollection
({
y_m
:
x_m
})
print
(
assignments
)
assignments
=
ps
.
AssignmentCollection
([
ps
.
Assignment
(
y
,
x
)])
assignments
=
ps
.
AssignmentCollection
([
ps
.
Assignment
(
y
_m
,
x_m
)])
print
(
assignments
)
def
test_new_with_substitutions
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
subs_dict
=
{
f
[
0
,
0
](
0
):
d
[
0
,
0
](
0
),
f
[
0
,
0
](
1
):
d
[
0
,
0
](
1
)}
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
True
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
lhs
==
d
[
0
,
0
](
0
)
assert
subs_ac
.
main_assignments
[
1
].
lhs
==
d
[
0
,
0
](
1
)
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
False
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
lhs
==
f
[
0
,
0
](
0
)
assert
subs_ac
.
main_assignments
[
1
].
lhs
==
f
[
0
,
0
](
1
)
subs_dict
=
{
a
*
b
:
sp
.
symbols
(
'xi'
)}
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
False
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
rhs
==
sp
.
symbols
(
'xi'
)
assert
len
(
subs_ac
.
subexpressions
)
==
0
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
True
,
substitute_on_lhs
=
False
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
rhs
==
sp
.
symbols
(
'xi'
)
assert
len
(
subs_ac
.
subexpressions
)
==
1
assert
subs_ac
.
subexpressions
[
0
].
lhs
==
sp
.
symbols
(
'xi'
)
def
test_copy
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
ac2
=
ac
.
copy
()
assert
ac2
==
ac
def
test_set_expressions
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
ac
.
set_main_assignments_from_dict
({
d
[
0
,
0
](
0
):
b
*
c
})
assert
len
(
ac
.
main_assignments
)
==
1
assert
ac
.
main_assignments
[
0
]
==
ps
.
Assignment
(
d
[
0
,
0
](
0
),
b
*
c
)
ac
.
set_sub_expressions_from_dict
({
sp
.
symbols
(
'xi'
):
a
*
b
})
assert
len
(
ac
.
subexpressions
)
==
1
assert
ac
.
subexpressions
[
0
]
==
ps
.
Assignment
(
sp
.
symbols
(
'xi'
),
a
*
b
)
ac
=
ac
.
new_without_subexpressions
(
subexpressions_to_keep
=
{
sp
.
symbols
(
'xi'
)})
assert
ac
.
subexpressions
[
0
]
==
ps
.
Assignment
(
sp
.
symbols
(
'xi'
),
a
*
b
)
ac
=
ac
.
new_without_unused_subexpressions
()
assert
len
(
ac
.
subexpressions
)
==
0
ac2
=
ac
.
new_without_subexpressions
()
assert
ac
==
ac2
def
test_free_and_bound_symbols
():
a1
=
ps
.
Assignment
(
a
,
d
[
0
,
0
](
0
))
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a2
],
subexpressions
=
[
a1
])
assert
f
[
0
,
0
](
1
)
in
ac
.
bound_symbols
assert
d
[
0
,
0
](
0
)
in
ac
.
free_symbols
def
test_new_merged
():
a1
=
ps
.
Assignment
(
a
,
b
*
c
)
a2
=
ps
.
Assignment
(
a
,
x
*
y
)
a3
=
ps
.
Assignment
(
t
,
x
**
2
)
# main assignments
a4
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
)
a5
=
ps
.
Assignment
(
d
[
0
,
0
](
0
),
a
)
ac
=
ps
.
AssignmentCollection
([
a4
],
subexpressions
=
[
a1
])
ac2
=
ps
.
AssignmentCollection
([
a5
],
subexpressions
=
[
a2
,
a3
])
merged_ac
=
ac
.
new_merged
(
ac2
)
assert
len
(
merged_ac
.
subexpressions
)
==
3
assert
len
(
merged_ac
.
main_assignments
)
==
2
assert
ps
.
Assignment
(
sp
.
symbols
(
'xi_0'
),
x
*
y
)
in
merged_ac
.
subexpressions
assert
ps
.
Assignment
(
d
[
0
,
0
](
0
),
sp
.
symbols
(
'xi_0'
))
in
merged_ac
.
main_assignments
assert
a1
in
merged_ac
.
subexpressions
assert
a3
in
merged_ac
.
subexpressions
pystencils_tests/test_astnodes.py
0 → 100644
View file @
0d6780b8
import
sympy
as
sp
import
pystencils
as
ps
from
pystencils
import
Assignment
from
pystencils.astnodes
import
Block
,
SkipIteration
,
LoopOverCoordinate
,
SympyAssignment
from
sympy.codegen.rewriting
import
optims_c99
dst
=
ps
.
fields
(
'dst(8): double[2D]'
)
s
=
sp
.
symbols
(
's_:8'
)
x
=
sp
.
symbols
(
'x'
)
y
=
sp
.
symbols
(
'y'
)
def
test_kernel_function
():
assignments
=
[
Assignment
(
dst
[
0
,
0
](
0
),
s
[
0
]),
Assignment
(
x
,
dst
[
0
,
0
](
2
))
]
ast_node
=
ps
.
create_kernel
(
assignments
)
assert
ast_node
.
target
==
'cpu'
assert
ast_node
.
backend
==
'c'
# symbols_defined and undefined_symbols will always return an emtpy set
assert
ast_node
.
symbols_defined
==
set
()
assert
ast_node
.
undefined_symbols
==
set
()
assert
ast_node
.
fields_written
==
{
dst
}
assert
ast_node
.
fields_read
==
{
dst
}
def
test_skip_iteration
():
# skip iteration is an object which should give back empty data structures.
skipped
=
SkipIteration
()
assert
skipped
.
args
==
[]
assert
skipped
.
symbols_defined
==
set
()
assert
skipped
.
undefined_symbols
==
set
()
def
test_block
():
assignments
=
[
Assignment
(
dst
[
0
,
0
](
0
),
s
[
0
]),
Assignment
(
x
,
dst
[
0
,
0
](
2
))
]
bl
=
Block
(
assignments
)
assert
bl
.
symbols_defined
==
{
dst
[
0
,
0
](
0
),
dst
[
0
,
0
](
2
),
s
[
0
],
x
}
bl
.
append
([
Assignment
(
y
,
10
)])
assert
bl
.
symbols_defined
==
{
dst
[
0
,
0
](
0
),
dst
[
0
,
0
](
2
),
s
[
0
],
x
,
y
}
assert
len
(
bl
.
args
)
==
3
list_iterator
=
iter
([
Assignment
(
s
[
1
],
11
)])
bl
.
insert_front
(
list_iterator
)
assert
bl
.
args
[
0
]
==
Assignment
(
s
[
1
],
11
)
def
test_loop_over_coordinate
():
assignments
=
[
Assignment
(
dst
[
0
,
0
](
0
),
s
[
0
]),
Assignment
(
x
,
dst
[
0
,
0
](
2
))
]
body
=
Block
(
assignments
)
loop
=
LoopOverCoordinate
(
body
,
coordinate_to_loop_over
=
0
,
start
=
0
,
stop
=
10
,
step
=
1
)
assert
loop
.
body
==
body
new_body
=
Block
([
assignments
[
0
]])
loop
=
loop
.
new_loop_with_different_body
(
new_body
)
assert
loop
.
body
==
new_body
assert
loop
.
start
==
0
assert
loop
.
stop
==
10
assert
loop
.
step
==
1
loop
.
replace
(
loop
.
start
,
2
)
loop
.
replace
(
loop
.
stop
,
20
)
loop
.
replace
(
loop
.
step
,
2
)
assert
loop
.
start
==
2
assert
loop
.
stop
==
20
assert
loop
.
step
==
2
def
test_sympy_assignment
():