Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jean-Noël Grad
pystencils
Commits
0d6780b8
Commit
0d6780b8
authored
Aug 10, 2020
by
Jan Hönig
Browse files
Merge branch 'Extend_testsuite' into 'master'
Extend testsuite See merge request
pycodegen/pystencils!168
parents
4f41d979
20118400
Changes
32
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
0d6780b8
...
...
@@ -298,7 +298,7 @@ class CBackend:
return
node
.
get_code
(
self
.
_dialect
,
self
.
_vector_instruction_set
)
def
_print_SourceCodeComment
(
self
,
node
):
return
"/*
"
+
node
.
text
+
"
*/"
return
f
"/*
{
node
.
text
}
*/"
def
_print_EmptyLine
(
self
,
node
):
return
""
...
...
@@ -316,7 +316,7 @@ class CBackend:
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
if
node
.
false_block
:
false_block
=
self
.
_print_Block
(
node
.
false_block
)
result
+=
"else
"
+
false_block
result
+=
f
"else
{
false_block
}
"
return
result
...
...
@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
return
self
.
_typed_number
(
expr
.
evalf
(),
get_type_of_expression
(
expr
))
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"(
"
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
"
)"
return
f
"(
{
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
}
)"
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
([
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
else
:
...
...
@@ -589,9 +589,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
instruction_set
[
'&'
].
format
(
result
,
item
)
return
result
def
_print_Max
(
self
,
expr
):
return
"test"
def
_print_Or
(
self
,
expr
):
result
=
self
.
_scalarFallback
(
'_print_Or'
,
expr
)
if
result
:
...
...
pystencils/backends/cuda_backend.py
View file @
0d6780b8
...
...
@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
assert
len
(
expr
.
args
)
==
1
,
f
"__fsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
return
f
"__fsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
print
(
len
(
expr
.
args
)
==
1
)
assert
len
(
expr
.
args
)
==
1
,
f
"__frsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
return
f
"__frsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
return
super
().
_print_Function
(
expr
)
pystencils/datahandling/datahandling_interface.py
View file @
0d6780b8
...
...
@@ -86,6 +86,13 @@ class DataHandling(ABC):
Args:
description (str): String description of the fields to add
dtype: data type of the array as numpy data type
ghost_layers: number of ghost layers - if not specified a default value specified in the constructor
is used
layout: memory layout of array, either structure of arrays 'SoA' or array of structures 'AoS'.
this is only important if values_per_cell > 1
cpu: allocate field on the CPU
gpu: allocate field on the GPU, if None, a GPU field is allocated if default_target is 'gpu'
alignment: either False for no alignment, or the number of bytes to align to
Returns:
Fields representing the just created arrays
"""
...
...
@@ -200,6 +207,10 @@ class DataHandling(ABC):
directly passed to the kernel function and override possible parameters from the DataHandling
"""
@
abstractmethod
def
get_kernel_kwargs
(
self
,
kernel_function
,
**
kwargs
):
"""Returns the input arguments of a kernel"""
@
abstractmethod
def
swap
(
self
,
name1
,
name2
,
gpu
=
False
):
"""Swaps data of two arrays"""
...
...
pystencils/datahandling/serial_datahandling.py
View file @
0d6780b8
...
...
@@ -266,10 +266,10 @@ class SerialDataHandling(DataHandling):
return
name
in
self
.
gpu_arrays
def
synchronization_function_cpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
return
self
.
synchronization_function
(
names
,
stencil_name
,
'cpu'
)
return
self
.
synchronization_function
(
names
,
stencil_name
,
target
=
'cpu'
)
def
synchronization_function_gpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
return
self
.
synchronization_function
(
names
,
stencil_name
,
'gpu'
)
return
self
.
synchronization_function
(
names
,
stencil_name
,
target
=
'gpu'
)
def
synchronization_function
(
self
,
names
,
stencil
=
None
,
target
=
None
,
**
_
):
if
target
is
None
:
...
...
@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
np
.
savez_compressed
(
file
,
**
self
.
cpu_arrays
)
def
load_all
(
self
,
file
):
if
'.npz'
not
in
file
:
file
+=
'.npz'
file_contents
=
np
.
load
(
file
)
for
arr_name
,
arr_contents
in
self
.
cpu_arrays
.
items
():
if
arr_name
not
in
file_contents
:
print
(
f
"Skipping read data
{
arr_name
}
because there is no data with this name in data handling"
)
continue
if
file_contents
[
arr_name
].
shape
!=
arr_contents
.
shape
:
print
(
"Skipping read data {} because shapes don't match. "
"Read array shape {}, existing array shape {}"
.
format
(
arr_name
,
file_contents
[
arr_name
].
shape
,
arr_contents
.
shape
))
print
(
f
"Skipping read data
{
arr_name
}
because shapes don't match. "
f
"Read array shape
{
file_contents
[
arr_name
].
shape
}
, existing array shape
{
arr_contents
.
shape
}
"
)
continue
np
.
copyto
(
arr_contents
,
file_contents
[
arr_name
])
pystencils/fd/derivative.py
View file @
0d6780b8
...
...
@@ -228,7 +228,9 @@ def diff_terms(expr):
Example:
>>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) )
>>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
>>> diff_terms( diff(x, 0, 0) + y )
{Diff(Diff(x, 0, -1), 0, -1)}
"""
result
=
set
()
...
...
pystencils/simp/__init__.py
View file @
0d6780b8
from
.assignment_collection
import
AssignmentCollection
from
.simplifications
import
(
add_subexpressions_for_divisions
,
add_subexpressions_for_field_reads
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
add_subexpressions_for_sums
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
subexpression_substitution_in_existing_subexpressions
,
subexpression_substitution_in_main_assignments
,
sympy_cse
,
sympy_cse_on_assignment_list
)
from
.simplificationstrategy
import
SimplificationStrategy
...
...
@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse'
,
'sympy_cse_on_assignment_list'
,
'apply_to_all_assignments'
,
'apply_on_all_subexpressions'
,
'subexpression_substitution_in_existing_subexpressions'
,
'subexpression_substitution_in_main_assignments'
,
'add_subexpressions_for_divisions'
,
'add_subexpressions_for_field_reads'
]
'add_subexpressions_for_sums'
,
'add_subexpressions_for_field_reads'
]
pystencils/simp/assignment_collection.py
View file @
0d6780b8
...
...
@@ -6,8 +6,7 @@ import sympy as sp
import
pystencils
from
pystencils.assignment
import
Assignment
from
pystencils.simp.simplifications
import
(
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
from
pystencils.simp.simplifications
import
(
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
from
pystencils.sympyextensions
import
count_operations
,
fast_subs
...
...
@@ -263,7 +262,7 @@ class AssignmentCollection:
own_definitions
=
set
([
e
.
lhs
for
e
in
self
.
main_assignments
])
other_definitions
=
set
([
e
.
lhs
for
e
in
other
.
main_assignments
])
assert
len
(
own_definitions
.
intersection
(
other_definitions
))
==
0
,
\
"Cannot
new_
merge
d
, since both
collection
define the same symbols"
"Cannot merge
collections
, since both define the same symbols"
own_subexpression_symbols
=
{
e
.
lhs
:
e
.
rhs
for
e
in
self
.
subexpressions
}
substitution_dict
=
{}
...
...
@@ -334,7 +333,7 @@ class AssignmentCollection:
kept_subexpressions
=
[]
if
self
.
subexpressions
[
0
].
lhs
in
subexpressions_to_keep
:
substitution_dict
=
{}
kept_subexpressions
=
self
.
subexpressions
[
0
]
kept_subexpressions
.
append
(
self
.
subexpressions
[
0
]
)
else
:
substitution_dict
=
{
self
.
subexpressions
[
0
].
lhs
:
self
.
subexpressions
[
0
].
rhs
}
...
...
pystencils/simp/simplifications.py
View file @
0d6780b8
...
...
@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif
isinstance
(
e1
,
Node
):
symbols
=
e1
.
symbols_defined
else
:
raise
NotImplementedError
(
"Cannot sort topologically. Object of type
"
+
type
(
e1
)
+
"
cannot be handled."
)
raise
NotImplementedError
(
f
"Cannot sort topologically. Object of type
{
type
(
e1
)
}
cannot be handled."
)
for
lhs
in
symbols
:
for
c2
,
e2
in
enumerate
(
assignments
):
...
...
@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
addends
=
[]
def
contains_sum
(
term
):
if
term
.
func
==
sp
.
add
.
Add
:
if
term
.
func
==
sp
.
Add
:
return
True
if
term
.
is_Atom
:
return
False
return
any
([
contains_sum
(
a
)
for
a
in
term
.
args
])
def
search_addends
(
term
):
if
term
.
func
==
sp
.
add
.
Add
:
if
term
.
func
==
sp
.
Add
:
if
all
([
not
contains_sum
(
a
)
for
a
in
term
.
args
]):
addends
.
extend
(
term
.
args
)
for
a
in
term
.
args
:
...
...
pystencils/stencil.py
View file @
0d6780b8
...
...
@@ -34,6 +34,8 @@ def is_valid(stencil, max_neighborhood=None):
True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
"""
expected_dim
=
len
(
stencil
[
0
])
for
d
in
stencil
:
...
...
@@ -67,8 +69,11 @@ def have_same_entries(s1, s2):
Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2)
True
>>> have_same_entries(stencil1, stencil3)
False
"""
if
len
(
s1
)
!=
len
(
s2
):
return
False
...
...
pystencils/sympyextensions.py
View file @
0d6780b8
...
...
@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def
replace_second_order_products
(
expr
:
sp
.
Expr
,
search_symbols
:
Iterable
[
sp
.
Symbol
],
positive
:
Optional
[
bool
]
=
None
,
replace_mixed
:
Optional
[
List
[
Assignment
]]
=
None
)
->
sp
.
Expr
:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ).
"""Replaces second order mixed terms like
4*
x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions
...
...
@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if
expr
.
is_Mul
:
distinct_search_symbols
=
set
()
nr_of_search_terms
=
0
other_factors
=
1
other_factors
=
sp
.
Integer
(
1
)
for
t
in
expr
.
args
:
if
t
in
search_symbols
:
nr_of_search_terms
+=
1
...
...
@@ -509,13 +509,14 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
if
t
.
exp
>=
0
:
result
[
'muls'
]
+=
int
(
t
.
exp
)
-
1
else
:
result
[
'muls'
]
-=
1
if
result
[
'muls'
]
>
0
:
result
[
'muls'
]
-=
1
result
[
'divs'
]
+=
1
result
[
'muls'
]
+=
(
-
int
(
t
.
exp
))
-
1
elif
sp
.
nsimplify
(
t
.
exp
)
==
sp
.
Rational
(
1
,
2
):
result
[
'sqrts'
]
+=
1
else
:
warnings
.
warn
(
"Cannot handle exponent
"
,
t
.
exp
,
"
of sp.Pow node"
)
warnings
.
warn
(
f
"Cannot handle exponent
{
t
.
exp
}
of sp.Pow node"
)
else
:
warnings
.
warn
(
"Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate"
)
...
...
@@ -526,7 +527,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif
isinstance
(
t
,
sp
.
Rel
):
pass
else
:
warnings
.
warn
(
"Unknown sympy node of type
"
+
str
(
t
.
func
)
+
"
counting will be inaccurate"
)
warnings
.
warn
(
f
"Unknown sympy node of type
{
str
(
t
.
func
)
}
counting will be inaccurate"
)
if
visit_children
:
for
a
in
t
.
args
:
...
...
pystencils/transformations.py
View file @
0d6780b8
...
...
@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node):
return
reversed
(
result
)
def
get_loop_counter_symbol_hierarchy
(
ast
N
ode
):
def
get_loop_counter_symbol_hierarchy
(
ast
_n
ode
):
"""Determines the loop counter symbols around a given AST node.
:param ast
N
ode: the AST node
:param ast
_n
ode: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
"""
result
=
[]
node
=
ast
N
ode
node
=
ast
_n
ode
while
node
is
not
None
:
node
=
get_next_parent_of_type
(
node
,
ast
.
LoopOverCoordinate
)
if
node
:
...
...
pystencils/utils.py
View file @
0d6780b8
import
os
import
itertools
from
collections
import
Counter
from
contextlib
import
contextmanager
from
tempfile
import
NamedTemporaryFile
...
...
@@ -96,16 +97,21 @@ def fully_contains(l1, l2):
def
boolean_array_bounding_box
(
boolean_array
):
"""Returns bounding box around "true" area of boolean array"""
dim
=
len
(
boolean_array
.
shape
)
"""Returns bounding box around "true" area of boolean array
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a)
[(1, 3), (1, 3)]
"""
dim
=
boolean_array
.
ndim
shape
=
boolean_array
.
shape
assert
0
not
in
shape
,
"Shape must not contain zero"
bounds
=
[]
for
i
in
range
(
dim
):
for
j
in
range
(
dim
):
if
i
!=
j
:
arr_1d
=
np
.
any
(
boolean_array
,
axis
=
j
)
begin
=
np
.
argmax
(
arr_1d
)
end
=
begin
+
np
.
argmin
(
arr_1d
[
begin
:])
bounds
.
append
((
begin
,
end
))
for
ax
in
itertools
.
combinations
(
reversed
(
range
(
dim
)),
dim
-
1
):
nonzero
=
np
.
any
(
boolean_array
,
axis
=
ax
)
t
=
np
.
where
(
nonzero
)[
0
][[
0
,
-
1
]]
bounds
.
append
((
t
[
0
],
t
[
1
]
+
1
))
return
bounds
...
...
@@ -217,7 +223,8 @@ class LinearEquationSystem:
return
'multiple'
def
solution
(
self
):
"""Solves the system if it has a single solution. Returns a dictionary mapping symbol to solution value."""
"""Solves the system. Under- and overdetermined systems are supported.
Returns a dictionary mapping symbol to solution value."""
return
sp
.
solve_linear_system
(
self
.
_matrix
,
*
self
.
unknowns
)
def
_resize_if_necessary
(
self
,
new_rows
=
1
):
...
...
@@ -233,8 +240,3 @@ class LinearEquationSystem:
break
result
-=
1
self
.
next_zero_row
=
result
def
find_unique_solutions_with_zeros
(
system
:
LinearEquationSystem
):
if
not
system
.
solution_structure
()
!=
'multiple'
:
raise
ValueError
(
"Function works only for underdetermined systems"
)
pystencils_tests/test_assignment_collection.py
View file @
0d6780b8
import
pytest
import
sympy
as
sp
import
pystencils
as
ps
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils.astnodes
import
Conditional
from
pystencils.simp.assignment_collection
import
SymbolGen
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
symbol_gen
=
SymbolGen
(
"a"
)
f
=
ps
.
fields
(
"f(2) : [2D]"
)
d
=
ps
.
fields
(
"d(2) : [2D]"
)
def
test_assignment_collection
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
symbol_gen
=
SymbolGen
(
"a"
)
def
test_assignment_collection
():
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
)],
[],
subexpression_symbol_generator
=
symbol_gen
)
...
...
@@ -32,10 +36,6 @@ def test_assignment_collection():
def
test_free_and_defined_symbols
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
a
,
b
=
sp
.
symbols
(
"a b"
)
symbol_gen
=
SymbolGen
(
"a"
)
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
),
Conditional
(
t
>
0
,
Assignment
(
a
,
b
+
1
),
Assignment
(
a
,
b
+
2
))],
[],
subexpression_symbol_generator
=
symbol_gen
)
...
...
@@ -45,35 +45,128 @@ def test_free_and_defined_symbols():
def
test_vector_assignments
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
assignments
=
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
]))
assignments
=
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
]))
print
(
assignments
)
def
test_wrong_vector_assignments
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
=
sp
.
symbols
(
"a b"
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
):
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
]),
sp
.
Matrix
([
1
,
2
,
3
]))
match
=
r
'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
):
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
]),
sp
.
Matrix
([
1
,
2
,
3
]))
def
test_vector_assignment_collection
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
y
,
x
=
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
])
assignments
=
ps
.
AssignmentCollection
({
y
:
x
})
y_m
,
x_m
=
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
])
assignments
=
ps
.
AssignmentCollection
({
y_m
:
x_m
})
print
(
assignments
)
assignments
=
ps
.
AssignmentCollection
([
ps
.
Assignment
(
y
,
x
)])
assignments
=
ps
.
AssignmentCollection
([
ps
.
Assignment
(
y
_m
,
x_m
)])
print
(
assignments
)
def
test_new_with_substitutions
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
subs_dict
=
{
f
[
0
,
0
](
0
):
d
[
0
,
0
](
0
),
f
[
0
,
0
](
1
):
d
[
0
,
0
](
1
)}
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
True
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
lhs
==
d
[
0
,
0
](
0
)
assert
subs_ac
.
main_assignments
[
1
].
lhs
==
d
[
0
,
0
](
1
)
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
False
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
lhs
==
f
[
0
,
0
](
0
)
assert
subs_ac
.
main_assignments
[
1
].
lhs
==
f
[
0
,
0
](
1
)
subs_dict
=
{
a
*
b
:
sp
.
symbols
(
'xi'
)}
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
False
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
rhs
==
sp
.
symbols
(
'xi'
)
assert
len
(
subs_ac
.
subexpressions
)
==
0
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
True
,
substitute_on_lhs
=
False
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
rhs
==
sp
.
symbols
(
'xi'
)
assert
len
(
subs_ac
.
subexpressions
)
==
1
assert
subs_ac
.
subexpressions
[
0
].
lhs
==
sp
.
symbols
(
'xi'
)
def
test_copy
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
ac2
=
ac
.
copy
()
assert
ac2
==
ac
def
test_set_expressions
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
ac
.
set_main_assignments_from_dict
({
d
[
0
,
0
](
0
):
b
*
c
})
assert
len
(
ac
.
main_assignments
)
==
1
assert
ac
.
main_assignments
[
0
]
==
ps
.
Assignment
(
d
[
0
,
0
](
0
),
b
*
c
)
ac
.
set_sub_expressions_from_dict
({
sp
.
symbols
(
'xi'
):
a
*
b
})
assert
len
(
ac
.
subexpressions
)
==
1
assert
ac
.
subexpressions
[
0
]
==
ps
.
Assignment
(
sp
.
symbols
(
'xi'
),
a
*
b
)
ac
=
ac
.
new_without_subexpressions
(
subexpressions_to_keep
=
{
sp
.
symbols
(
'xi'
)})
assert
ac
.
subexpressions
[
0
]
==
ps
.
Assignment
(
sp
.
symbols
(
'xi'
),
a
*
b
)
ac
=
ac
.
new_without_unused_subexpressions
()
assert
len
(
ac
.
subexpressions
)
==
0
ac2
=
ac
.
new_without_subexpressions
()
assert
ac
==
ac2
def
test_free_and_bound_symbols
():
a1
=
ps
.
Assignment
(
a
,
d
[
0
,
0
](
0
))
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a2
],
subexpressions
=
[
a1
])
assert
f
[
0
,
0
](
1
)
in
ac
.
bound_symbols
assert
d
[
0
,
0
](
0
)
in
ac
.
free_symbols
def
test_new_merged
():
a1
=
ps
.
Assignment
(
a
,
b
*
c
)
a2
=
ps
.
Assignment
(
a
,
x
*
y
)
a3
=
ps
.
Assignment
(
t
,
x
**
2
)
# main assignments
a4
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
)
a5
=
ps
.
Assignment
(
d
[
0
,
0
](
0
),
a
)
ac
=
ps
.
AssignmentCollection
([
a4
],
subexpressions
=
[
a1
])
ac2
=
ps
.
AssignmentCollection
([
a5
],
subexpressions
=
[
a2
,
a3
])
merged_ac
=
ac
.
new_merged
(
ac2
)
assert
len
(
merged_ac
.
subexpressions
)
==
3
assert
len
(
merged_ac
.
main_assignments
)
==
2
assert
ps
.
Assignment
(
sp
.
symbols
(
'xi_0'
),
x
*
y
)
in
merged_ac
.
subexpressions
assert
ps
.
Assignment
(
d
[
0
,
0
](
0
),
sp
.
symbols
(
'xi_0'
))
in
merged_ac
.
main_assignments
assert
a1
in
merged_ac
.
subexpressions
assert
a3
in
merged_ac
.
subexpressions
pystencils_tests/test_astnodes.py
0 → 100644
View file @
0d6780b8
import
sympy
as
sp
import
pystencils
as
ps
from
pystencils
import
Assignment
from
pystencils.astnodes
import
Block
,
SkipIteration
,
LoopOverCoordinate
,
SympyAssignment
from
sympy.codegen.rewriting
import
optims_c99
dst
=
ps
.
fields
(
'dst(8): double[2D]'
)
s
=
sp
.
symbols
(
's_:8'
)
x
=
sp
.
symbols
(
'x'
)
y
=
sp
.
symbols
(
'y'
)
def
test_kernel_function
():
assignments
=
[
Assignment
(
dst
[
0
,
0
](
0
),
s
[
0
]),
Assignment
(
x
,
dst
[
0
,
0
](
2
))
]
ast_node
=
ps
.
create_kernel
(
assignments
)
assert
ast_node
.
target
==
'cpu'
assert
ast_node
.
backend
==
'c'
# symbols_defined and undefined_symbols will always return an emtpy set
assert
ast_node
.
symbols_defined
==
set
()
assert
ast_node
.
undefined_symbols
==
set
()
assert
ast_node
.
fields_written
==
{
dst
}
assert
ast_node
.
fields_read
==
{
dst
}
def
test_skip_iteration
():
# skip iteration is an object which should give back empty data structures.
skipped
=
SkipIteration
()
assert
skipped
.
args
==
[]
assert
skipped
.
symbols_defined
==
set
()
assert
skipped
.
undefined_symbols
==
set
()
def
test_block
():
assignments
=
[
Assignment
(
dst
[
0
,
0
](
0
),
s
[
0
]),
Assignment
(
x
,
dst
[
0
,
0
](
2
))
]
bl
=
Block
(
assignments
)
assert
bl
.
symbols_defined
==
{
dst
[
0
,
0
](
0
),
dst
[
0
,
0
](
2
),
s
[
0
],
x
}
bl
.
append
([
Assignment
(
y
,
10
)])
assert
bl
.
symbols_defined
==
{
dst
[
0
,
0
](
0
),
dst
[
0
,
0
](
2
),
s
[
0
],
x
,
y
}
assert
len
(
bl
.
args
)
==
3
list_iterator
=
iter
([
Assignment
(
s
[
1
],
11
)])
bl
.
insert_front
(
list_iterator
)
assert
bl
.
args
[
0
]
==
Assignment
(
s
[
1
],
11
)
def
test_loop_over_coordinate
():
assignments
=
[
Assignment
(
dst
[
0
,
0
](
0
),
s
[
0
]),
Assignment
(
x
,
dst
[
0
,
0
](
2
))
]
body
=
Block
(
assignments
)
loop
=
LoopOverCoordinate
(
body
,
coordinate_to_loop_over
=
0
,
start
=
0
,
stop
=
10
,
step
=
1
)
assert
loop
.
body
==
body
new_body
=
Block
([
assignments
[
0
]])
loop
=
loop
.
new_loop_with_different_body
(
new_body
)
assert
loop
.
body
==
new_body
assert
loop
.
start
==
0
assert
loop
.
stop
==
10
assert
loop
.
step
==
1
loop
.
replace
(
loop
.
start
,
2
)
loop
.
replace
(
loop
.
stop
,
20
)
loop
.
replace
(
loop
.
step
,
2
)
assert
loop
.
start
==
2
assert
loop
.
stop
==
20
assert
loop
.
step
==
2
def
test_sympy_assignment
():