Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Bindgen
pystencils
Commits
0d6780b8
Commit
0d6780b8
authored
Aug 10, 2020
by
Jan Hönig
Browse files
Merge branch 'Extend_testsuite' into 'master'
Extend testsuite See merge request
pycodegen/pystencils!168
parents
4f41d979
20118400
Changes
32
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
0d6780b8
...
@@ -298,7 +298,7 @@ class CBackend:
...
@@ -298,7 +298,7 @@ class CBackend:
return
node
.
get_code
(
self
.
_dialect
,
self
.
_vector_instruction_set
)
return
node
.
get_code
(
self
.
_dialect
,
self
.
_vector_instruction_set
)
def
_print_SourceCodeComment
(
self
,
node
):
def
_print_SourceCodeComment
(
self
,
node
):
return
"/*
"
+
node
.
text
+
"
*/"
return
f
"/*
{
node
.
text
}
*/"
def
_print_EmptyLine
(
self
,
node
):
def
_print_EmptyLine
(
self
,
node
):
return
""
return
""
...
@@ -316,7 +316,7 @@ class CBackend:
...
@@ -316,7 +316,7 @@ class CBackend:
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
if
node
.
false_block
:
if
node
.
false_block
:
false_block
=
self
.
_print_Block
(
node
.
false_block
)
false_block
=
self
.
_print_Block
(
node
.
false_block
)
result
+=
"else
"
+
false_block
result
+=
f
"else
{
false_block
}
"
return
result
return
result
...
@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
return
self
.
_typed_number
(
expr
.
evalf
(),
get_type_of_expression
(
expr
))
return
self
.
_typed_number
(
expr
.
evalf
(),
get_type_of_expression
(
expr
))
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"(
"
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
"
)"
return
f
"(
{
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
}
)"
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
([
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
([
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
else
:
else
:
...
@@ -589,9 +589,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
...
@@ -589,9 +589,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
instruction_set
[
'&'
].
format
(
result
,
item
)
result
=
self
.
instruction_set
[
'&'
].
format
(
result
,
item
)
return
result
return
result
def
_print_Max
(
self
,
expr
):
return
"test"
def
_print_Or
(
self
,
expr
):
def
_print_Or
(
self
,
expr
):
result
=
self
.
_scalarFallback
(
'_print_Or'
,
expr
)
result
=
self
.
_scalarFallback
(
'_print_Or'
,
expr
)
if
result
:
if
result
:
...
...
pystencils/backends/cuda_backend.py
View file @
0d6780b8
...
@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
...
@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
assert
len
(
expr
.
args
)
==
1
,
f
"__fsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
assert
len
(
expr
.
args
)
==
1
,
f
"__fsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
return
f
"__fsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
return
f
"__fsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
print
(
len
(
expr
.
args
)
==
1
)
assert
len
(
expr
.
args
)
==
1
,
f
"__frsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
assert
len
(
expr
.
args
)
==
1
,
f
"__frsqrt_rn has one argument, but
{
len
(
expr
.
args
)
}
where given"
return
f
"__frsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
return
f
"__frsqrt_rn(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
return
super
().
_print_Function
(
expr
)
return
super
().
_print_Function
(
expr
)
pystencils/datahandling/datahandling_interface.py
View file @
0d6780b8
...
@@ -86,6 +86,13 @@ class DataHandling(ABC):
...
@@ -86,6 +86,13 @@ class DataHandling(ABC):
Args:
Args:
description (str): String description of the fields to add
description (str): String description of the fields to add
dtype: data type of the array as numpy data type
dtype: data type of the array as numpy data type
ghost_layers: number of ghost layers - if not specified a default value specified in the constructor
is used
layout: memory layout of array, either structure of arrays 'SoA' or array of structures 'AoS'.
this is only important if values_per_cell > 1
cpu: allocate field on the CPU
gpu: allocate field on the GPU, if None, a GPU field is allocated if default_target is 'gpu'
alignment: either False for no alignment, or the number of bytes to align to
Returns:
Returns:
Fields representing the just created arrays
Fields representing the just created arrays
"""
"""
...
@@ -200,6 +207,10 @@ class DataHandling(ABC):
...
@@ -200,6 +207,10 @@ class DataHandling(ABC):
directly passed to the kernel function and override possible parameters from the DataHandling
directly passed to the kernel function and override possible parameters from the DataHandling
"""
"""
@
abstractmethod
def
get_kernel_kwargs
(
self
,
kernel_function
,
**
kwargs
):
"""Returns the input arguments of a kernel"""
@
abstractmethod
@
abstractmethod
def
swap
(
self
,
name1
,
name2
,
gpu
=
False
):
def
swap
(
self
,
name1
,
name2
,
gpu
=
False
):
"""Swaps data of two arrays"""
"""Swaps data of two arrays"""
...
...
pystencils/datahandling/serial_datahandling.py
View file @
0d6780b8
...
@@ -266,10 +266,10 @@ class SerialDataHandling(DataHandling):
...
@@ -266,10 +266,10 @@ class SerialDataHandling(DataHandling):
return
name
in
self
.
gpu_arrays
return
name
in
self
.
gpu_arrays
def
synchronization_function_cpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
def
synchronization_function_cpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
return
self
.
synchronization_function
(
names
,
stencil_name
,
'cpu'
)
return
self
.
synchronization_function
(
names
,
stencil_name
,
target
=
'cpu'
)
def
synchronization_function_gpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
def
synchronization_function_gpu
(
self
,
names
,
stencil_name
=
None
,
**
_
):
return
self
.
synchronization_function
(
names
,
stencil_name
,
'gpu'
)
return
self
.
synchronization_function
(
names
,
stencil_name
,
target
=
'gpu'
)
def
synchronization_function
(
self
,
names
,
stencil
=
None
,
target
=
None
,
**
_
):
def
synchronization_function
(
self
,
names
,
stencil
=
None
,
target
=
None
,
**
_
):
if
target
is
None
:
if
target
is
None
:
...
@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
...
@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
np
.
savez_compressed
(
file
,
**
self
.
cpu_arrays
)
np
.
savez_compressed
(
file
,
**
self
.
cpu_arrays
)
def
load_all
(
self
,
file
):
def
load_all
(
self
,
file
):
if
'.npz'
not
in
file
:
file
+=
'.npz'
file_contents
=
np
.
load
(
file
)
file_contents
=
np
.
load
(
file
)
for
arr_name
,
arr_contents
in
self
.
cpu_arrays
.
items
():
for
arr_name
,
arr_contents
in
self
.
cpu_arrays
.
items
():
if
arr_name
not
in
file_contents
:
if
arr_name
not
in
file_contents
:
print
(
f
"Skipping read data
{
arr_name
}
because there is no data with this name in data handling"
)
print
(
f
"Skipping read data
{
arr_name
}
because there is no data with this name in data handling"
)
continue
continue
if
file_contents
[
arr_name
].
shape
!=
arr_contents
.
shape
:
if
file_contents
[
arr_name
].
shape
!=
arr_contents
.
shape
:
print
(
"Skipping read data {} because shapes don't match. "
print
(
f
"Skipping read data
{
arr_name
}
because shapes don't match. "
"Read array shape {}, existing array shape {}"
.
format
(
arr_name
,
file_contents
[
arr_name
].
shape
,
f
"Read array shape
{
file_contents
[
arr_name
].
shape
}
, existing array shape
{
arr_contents
.
shape
}
"
)
arr_contents
.
shape
))
continue
continue
np
.
copyto
(
arr_contents
,
file_contents
[
arr_name
])
np
.
copyto
(
arr_contents
,
file_contents
[
arr_name
])
pystencils/fd/derivative.py
View file @
0d6780b8
...
@@ -228,7 +228,9 @@ def diff_terms(expr):
...
@@ -228,7 +228,9 @@ def diff_terms(expr):
Example:
Example:
>>> x, y = sp.symbols("x, y")
>>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) )
>>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
>>> diff_terms( diff(x, 0, 0) + y )
{Diff(Diff(x, 0, -1), 0, -1)}
{Diff(Diff(x, 0, -1), 0, -1)}
"""
"""
result
=
set
()
result
=
set
()
...
...
pystencils/simp/__init__.py
View file @
0d6780b8
from
.assignment_collection
import
AssignmentCollection
from
.assignment_collection
import
AssignmentCollection
from
.simplifications
import
(
from
.simplifications
import
(
add_subexpressions_for_divisions
,
add_subexpressions_for_field_reads
,
add_subexpressions_for_divisions
,
add_subexpressions_for_field_reads
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
add_subexpressions_for_sums
,
apply_on_all_subexpressions
,
apply_to_all_assignments
,
subexpression_substitution_in_existing_subexpressions
,
subexpression_substitution_in_existing_subexpressions
,
subexpression_substitution_in_main_assignments
,
sympy_cse
,
sympy_cse_on_assignment_list
)
subexpression_substitution_in_main_assignments
,
sympy_cse
,
sympy_cse_on_assignment_list
)
from
.simplificationstrategy
import
SimplificationStrategy
from
.simplificationstrategy
import
SimplificationStrategy
...
@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
...
@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse'
,
'sympy_cse_on_assignment_list'
,
'apply_to_all_assignments'
,
'sympy_cse'
,
'sympy_cse_on_assignment_list'
,
'apply_to_all_assignments'
,
'apply_on_all_subexpressions'
,
'subexpression_substitution_in_existing_subexpressions'
,
'apply_on_all_subexpressions'
,
'subexpression_substitution_in_existing_subexpressions'
,
'subexpression_substitution_in_main_assignments'
,
'add_subexpressions_for_divisions'
,
'subexpression_substitution_in_main_assignments'
,
'add_subexpressions_for_divisions'
,
'add_subexpressions_for_field_reads'
]
'add_subexpressions_for_sums'
,
'add_subexpressions_for_field_reads'
]
pystencils/simp/assignment_collection.py
View file @
0d6780b8
...
@@ -6,8 +6,7 @@ import sympy as sp
...
@@ -6,8 +6,7 @@ import sympy as sp
import
pystencils
import
pystencils
from
pystencils.assignment
import
Assignment
from
pystencils.assignment
import
Assignment
from
pystencils.simp.simplifications
import
(
from
pystencils.simp.simplifications
import
(
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
sort_assignments_topologically
,
transform_lhs_and_rhs
,
transform_rhs
)
from
pystencils.sympyextensions
import
count_operations
,
fast_subs
from
pystencils.sympyextensions
import
count_operations
,
fast_subs
...
@@ -263,7 +262,7 @@ class AssignmentCollection:
...
@@ -263,7 +262,7 @@ class AssignmentCollection:
own_definitions
=
set
([
e
.
lhs
for
e
in
self
.
main_assignments
])
own_definitions
=
set
([
e
.
lhs
for
e
in
self
.
main_assignments
])
other_definitions
=
set
([
e
.
lhs
for
e
in
other
.
main_assignments
])
other_definitions
=
set
([
e
.
lhs
for
e
in
other
.
main_assignments
])
assert
len
(
own_definitions
.
intersection
(
other_definitions
))
==
0
,
\
assert
len
(
own_definitions
.
intersection
(
other_definitions
))
==
0
,
\
"Cannot
new_
merge
d
, since both
collection
define the same symbols"
"Cannot merge
collections
, since both define the same symbols"
own_subexpression_symbols
=
{
e
.
lhs
:
e
.
rhs
for
e
in
self
.
subexpressions
}
own_subexpression_symbols
=
{
e
.
lhs
:
e
.
rhs
for
e
in
self
.
subexpressions
}
substitution_dict
=
{}
substitution_dict
=
{}
...
@@ -334,7 +333,7 @@ class AssignmentCollection:
...
@@ -334,7 +333,7 @@ class AssignmentCollection:
kept_subexpressions
=
[]
kept_subexpressions
=
[]
if
self
.
subexpressions
[
0
].
lhs
in
subexpressions_to_keep
:
if
self
.
subexpressions
[
0
].
lhs
in
subexpressions_to_keep
:
substitution_dict
=
{}
substitution_dict
=
{}
kept_subexpressions
=
self
.
subexpressions
[
0
]
kept_subexpressions
.
append
(
self
.
subexpressions
[
0
]
)
else
:
else
:
substitution_dict
=
{
self
.
subexpressions
[
0
].
lhs
:
self
.
subexpressions
[
0
].
rhs
}
substitution_dict
=
{
self
.
subexpressions
[
0
].
lhs
:
self
.
subexpressions
[
0
].
rhs
}
...
...
pystencils/simp/simplifications.py
View file @
0d6780b8
...
@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
...
@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif
isinstance
(
e1
,
Node
):
elif
isinstance
(
e1
,
Node
):
symbols
=
e1
.
symbols_defined
symbols
=
e1
.
symbols_defined
else
:
else
:
raise
NotImplementedError
(
"Cannot sort topologically. Object of type
"
+
type
(
e1
)
+
"
cannot be handled."
)
raise
NotImplementedError
(
f
"Cannot sort topologically. Object of type
{
type
(
e1
)
}
cannot be handled."
)
for
lhs
in
symbols
:
for
lhs
in
symbols
:
for
c2
,
e2
in
enumerate
(
assignments
):
for
c2
,
e2
in
enumerate
(
assignments
):
...
@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
...
@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
addends
=
[]
addends
=
[]
def
contains_sum
(
term
):
def
contains_sum
(
term
):
if
term
.
func
==
sp
.
add
.
Add
:
if
term
.
func
==
sp
.
Add
:
return
True
return
True
if
term
.
is_Atom
:
if
term
.
is_Atom
:
return
False
return
False
return
any
([
contains_sum
(
a
)
for
a
in
term
.
args
])
return
any
([
contains_sum
(
a
)
for
a
in
term
.
args
])
def
search_addends
(
term
):
def
search_addends
(
term
):
if
term
.
func
==
sp
.
add
.
Add
:
if
term
.
func
==
sp
.
Add
:
if
all
([
not
contains_sum
(
a
)
for
a
in
term
.
args
]):
if
all
([
not
contains_sum
(
a
)
for
a
in
term
.
args
]):
addends
.
extend
(
term
.
args
)
addends
.
extend
(
term
.
args
)
for
a
in
term
.
args
:
for
a
in
term
.
args
:
...
...
pystencils/stencil.py
View file @
0d6780b8
...
@@ -34,6 +34,8 @@ def is_valid(stencil, max_neighborhood=None):
...
@@ -34,6 +34,8 @@ def is_valid(stencil, max_neighborhood=None):
True
True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False
False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
"""
"""
expected_dim
=
len
(
stencil
[
0
])
expected_dim
=
len
(
stencil
[
0
])
for
d
in
stencil
:
for
d
in
stencil
:
...
@@ -67,8 +69,11 @@ def have_same_entries(s1, s2):
...
@@ -67,8 +69,11 @@ def have_same_entries(s1, s2):
Examples:
Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2)
>>> have_same_entries(stencil1, stencil2)
True
True
>>> have_same_entries(stencil1, stencil3)
False
"""
"""
if
len
(
s1
)
!=
len
(
s2
):
if
len
(
s1
)
!=
len
(
s2
):
return
False
return
False
...
...
pystencils/sympyextensions.py
View file @
0d6780b8
...
@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
...
@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def
replace_second_order_products
(
expr
:
sp
.
Expr
,
search_symbols
:
Iterable
[
sp
.
Symbol
],
def
replace_second_order_products
(
expr
:
sp
.
Expr
,
search_symbols
:
Iterable
[
sp
.
Symbol
],
positive
:
Optional
[
bool
]
=
None
,
positive
:
Optional
[
bool
]
=
None
,
replace_mixed
:
Optional
[
List
[
Assignment
]]
=
None
)
->
sp
.
Expr
:
replace_mixed
:
Optional
[
List
[
Assignment
]]
=
None
)
->
sp
.
Expr
:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ).
"""Replaces second order mixed terms like
4*
x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this
This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions
transformation can be done to find more common sub-expressions
...
@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
...
@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if
expr
.
is_Mul
:
if
expr
.
is_Mul
:
distinct_search_symbols
=
set
()
distinct_search_symbols
=
set
()
nr_of_search_terms
=
0
nr_of_search_terms
=
0
other_factors
=
1
other_factors
=
sp
.
Integer
(
1
)
for
t
in
expr
.
args
:
for
t
in
expr
.
args
:
if
t
in
search_symbols
:
if
t
in
search_symbols
:
nr_of_search_terms
+=
1
nr_of_search_terms
+=
1
...
@@ -509,13 +509,14 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
...
@@ -509,13 +509,14 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
if
t
.
exp
>=
0
:
if
t
.
exp
>=
0
:
result
[
'muls'
]
+=
int
(
t
.
exp
)
-
1
result
[
'muls'
]
+=
int
(
t
.
exp
)
-
1
else
:
else
:
result
[
'muls'
]
-=
1
if
result
[
'muls'
]
>
0
:
result
[
'muls'
]
-=
1
result
[
'divs'
]
+=
1
result
[
'divs'
]
+=
1
result
[
'muls'
]
+=
(
-
int
(
t
.
exp
))
-
1
result
[
'muls'
]
+=
(
-
int
(
t
.
exp
))
-
1
elif
sp
.
nsimplify
(
t
.
exp
)
==
sp
.
Rational
(
1
,
2
):
elif
sp
.
nsimplify
(
t
.
exp
)
==
sp
.
Rational
(
1
,
2
):
result
[
'sqrts'
]
+=
1
result
[
'sqrts'
]
+=
1
else
:
else
:
warnings
.
warn
(
"Cannot handle exponent
"
,
t
.
exp
,
"
of sp.Pow node"
)
warnings
.
warn
(
f
"Cannot handle exponent
{
t
.
exp
}
of sp.Pow node"
)
else
:
else
:
warnings
.
warn
(
"Counting operations: only integer exponents are supported in Pow, "
warnings
.
warn
(
"Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate"
)
"counting will be inaccurate"
)
...
@@ -526,7 +527,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
...
@@ -526,7 +527,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif
isinstance
(
t
,
sp
.
Rel
):
elif
isinstance
(
t
,
sp
.
Rel
):
pass
pass
else
:
else
:
warnings
.
warn
(
"Unknown sympy node of type
"
+
str
(
t
.
func
)
+
"
counting will be inaccurate"
)
warnings
.
warn
(
f
"Unknown sympy node of type
{
str
(
t
.
func
)
}
counting will be inaccurate"
)
if
visit_children
:
if
visit_children
:
for
a
in
t
.
args
:
for
a
in
t
.
args
:
...
...
pystencils/transformations.py
View file @
0d6780b8
...
@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node):
...
@@ -1206,13 +1206,13 @@ def get_loop_hierarchy(ast_node):
return
reversed
(
result
)
return
reversed
(
result
)
def
get_loop_counter_symbol_hierarchy
(
ast
N
ode
):
def
get_loop_counter_symbol_hierarchy
(
ast
_n
ode
):
"""Determines the loop counter symbols around a given AST node.
"""Determines the loop counter symbols around a given AST node.
:param ast
N
ode: the AST node
:param ast
_n
ode: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
"""
"""
result
=
[]
result
=
[]
node
=
ast
N
ode
node
=
ast
_n
ode
while
node
is
not
None
:
while
node
is
not
None
:
node
=
get_next_parent_of_type
(
node
,
ast
.
LoopOverCoordinate
)
node
=
get_next_parent_of_type
(
node
,
ast
.
LoopOverCoordinate
)
if
node
:
if
node
:
...
...
pystencils/utils.py
View file @
0d6780b8
import
os
import
os
import
itertools
from
collections
import
Counter
from
collections
import
Counter
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
tempfile
import
NamedTemporaryFile
from
tempfile
import
NamedTemporaryFile
...
@@ -96,16 +97,21 @@ def fully_contains(l1, l2):
...
@@ -96,16 +97,21 @@ def fully_contains(l1, l2):
def
boolean_array_bounding_box
(
boolean_array
):
def
boolean_array_bounding_box
(
boolean_array
):
"""Returns bounding box around "true" area of boolean array"""
"""Returns bounding box around "true" area of boolean array
dim
=
len
(
boolean_array
.
shape
)
>>> a = np.zeros((4, 4), dtype=bool)
>>> a[1:-1, 1:-1] = True
>>> boolean_array_bounding_box(a)
[(1, 3), (1, 3)]
"""
dim
=
boolean_array
.
ndim
shape
=
boolean_array
.
shape
assert
0
not
in
shape
,
"Shape must not contain zero"
bounds
=
[]
bounds
=
[]
for
i
in
range
(
dim
):
for
ax
in
itertools
.
combinations
(
reversed
(
range
(
dim
)),
dim
-
1
):
for
j
in
range
(
dim
):
nonzero
=
np
.
any
(
boolean_array
,
axis
=
ax
)
if
i
!=
j
:
t
=
np
.
where
(
nonzero
)[
0
][[
0
,
-
1
]]
arr_1d
=
np
.
any
(
boolean_array
,
axis
=
j
)
bounds
.
append
((
t
[
0
],
t
[
1
]
+
1
))
begin
=
np
.
argmax
(
arr_1d
)
end
=
begin
+
np
.
argmin
(
arr_1d
[
begin
:])
bounds
.
append
((
begin
,
end
))
return
bounds
return
bounds
...
@@ -217,7 +223,8 @@ class LinearEquationSystem:
...
@@ -217,7 +223,8 @@ class LinearEquationSystem:
return
'multiple'
return
'multiple'
def
solution
(
self
):
def
solution
(
self
):
"""Solves the system if it has a single solution. Returns a dictionary mapping symbol to solution value."""
"""Solves the system. Under- and overdetermined systems are supported.
Returns a dictionary mapping symbol to solution value."""
return
sp
.
solve_linear_system
(
self
.
_matrix
,
*
self
.
unknowns
)
return
sp
.
solve_linear_system
(
self
.
_matrix
,
*
self
.
unknowns
)
def
_resize_if_necessary
(
self
,
new_rows
=
1
):
def
_resize_if_necessary
(
self
,
new_rows
=
1
):
...
@@ -233,8 +240,3 @@ class LinearEquationSystem:
...
@@ -233,8 +240,3 @@ class LinearEquationSystem:
break
break
result
-=
1
result
-=
1
self
.
next_zero_row
=
result
self
.
next_zero_row
=
result
def
find_unique_solutions_with_zeros
(
system
:
LinearEquationSystem
):
if
not
system
.
solution_structure
()
!=
'multiple'
:
raise
ValueError
(
"Function works only for underdetermined systems"
)
pystencils_tests/test_assignment_collection.py
View file @
0d6780b8
import
pytest
import
pytest
import
sympy
as
sp
import
sympy
as
sp
import
pystencils
as
ps
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils
import
Assignment
,
AssignmentCollection
from
pystencils.astnodes
import
Conditional
from
pystencils.astnodes
import
Conditional
from
pystencils.simp.assignment_collection
import
SymbolGen
from
pystencils.simp.assignment_collection
import
SymbolGen
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
symbol_gen
=
SymbolGen
(
"a"
)
f
=
ps
.
fields
(
"f(2) : [2D]"
)
d
=
ps
.
fields
(
"d(2) : [2D]"
)
def
test_assignment_collection
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
symbol_gen
=
SymbolGen
(
"a"
)
def
test_assignment_collection
():
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
)],
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
)],
[],
subexpression_symbol_generator
=
symbol_gen
)
[],
subexpression_symbol_generator
=
symbol_gen
)
...
@@ -32,10 +36,6 @@ def test_assignment_collection():
...
@@ -32,10 +36,6 @@ def test_assignment_collection():
def
test_free_and_defined_symbols
():
def
test_free_and_defined_symbols
():
x
,
y
,
z
,
t
=
sp
.
symbols
(
"x y z t"
)
a
,
b
=
sp
.
symbols
(
"a b"
)
symbol_gen
=
SymbolGen
(
"a"
)
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
),
Conditional
(
t
>
0
,
Assignment
(
a
,
b
+
1
),
Assignment
(
a
,
b
+
2
))],
ac
=
AssignmentCollection
([
Assignment
(
z
,
x
+
y
),
Conditional
(
t
>
0
,
Assignment
(
a
,
b
+
1
),
Assignment
(
a
,
b
+
2
))],
[],
subexpression_symbol_generator
=
symbol_gen
)
[],
subexpression_symbol_generator
=
symbol_gen
)
...
@@ -45,35 +45,128 @@ def test_free_and_defined_symbols():
...
@@ -45,35 +45,128 @@ def test_free_and_defined_symbols():
def
test_vector_assignments
():
def
test_vector_assignments
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
assignments
=
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
]))
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
assignments
=
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
]))
print
(
assignments
)
print
(
assignments
)
def
test_wrong_vector_assignments
():
def
test_wrong_vector_assignments
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
import
sympy
as
sp
a
,
b
=
sp
.
symbols
(
"a b"
)
with
pytest
.
raises
(
AssertionError
,
with
pytest
.
raises
(
AssertionError
,
match
=
r
'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
):
match
=
r
'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'
):
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
]),
sp
.
Matrix
([
1
,
2
,
3
]))
ps
.
Assignment
(
sp
.
Matrix
([
a
,
b
]),
sp
.
Matrix
([
1
,
2
,
3
]))
def
test_vector_assignment_collection
():
def
test_vector_assignment_collection
():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import
pystencils
as
ps
y_m
,
x_m
=
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
])
import
sympy
as
sp
assignments
=
ps
.
AssignmentCollection
({
y_m
:
x_m
})
a
,
b
,
c
=
sp
.
symbols
(
"a b c"
)
y
,
x
=
sp
.
Matrix
([
a
,
b
,
c
]),
sp
.
Matrix
([
1
,
2
,
3
])
assignments
=
ps
.
AssignmentCollection
({
y
:
x
})
print
(
assignments
)
print
(
assignments
)
assignments
=
ps
.
AssignmentCollection
([
ps
.
Assignment
(
y
,
x
)])
assignments
=
ps
.
AssignmentCollection
([
ps
.
Assignment
(
y
_m
,
x_m
)])
print
(
assignments
)
print
(
assignments
)
def
test_new_with_substitutions
():
a1
=
ps
.
Assignment
(
f
[
0
,
0
](
0
),
a
*
b
)
a2
=
ps
.
Assignment
(
f
[
0
,
0
](
1
),
b
*
c
)
ac
=
ps
.
AssignmentCollection
([
a1
,
a2
],
subexpressions
=
[])
subs_dict
=
{
f
[
0
,
0
](
0
):
d
[
0
,
0
](
0
),
f
[
0
,
0
](
1
):
d
[
0
,
0
](
1
)}
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
True
,
sort_topologically
=
True
)
assert
subs_ac
.
main_assignments
[
0
].
lhs
==
d
[
0
,
0
](
0
)
assert
subs_ac
.
main_assignments
[
1
].
lhs
==
d
[
0
,
0
](
1
)
subs_ac
=
ac
.
new_with_substitutions
(
subs_dict
,
add_substitutions_as_subexpressions
=
False
,