Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
82af488a
Commit
82af488a
authored
Jun 19, 2020
by
Jan Hönig
Browse files
Merge branch 'Skip_interpolation_tests' into 'master'
Adapted test cases to Sympy Version 1.6 See merge request
pycodegen/pystencils!158
parents
09de00cf
76c3727b
Pipeline
#24583
passed with stage
in 6 minutes and 14 seconds
Changes
46
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
.flake8
View file @
82af488a
...
@@ -3,4 +3,4 @@ max-line-length=120
...
@@ -3,4 +3,4 @@ max-line-length=120
exclude=pystencils/jupyter.py,
exclude=pystencils/jupyter.py,
pystencils/plot.py
pystencils/plot.py
pystencils/session.py
pystencils/session.py
ignore = W293 W503 W291 C901
ignore = W293 W503 W291 C901
E741
.gitlab-ci.yml
View file @
82af488a
...
@@ -13,6 +13,7 @@ tests-and-coverage:
...
@@ -13,6 +13,7 @@ tests-and-coverage:
-
$ENABLE_NIGHTLY_BUILDS
-
$ENABLE_NIGHTLY_BUILDS
image
:
i10git.cs.fau.de:5005/pycodegen/pycodegen/full
image
:
i10git.cs.fau.de:5005/pycodegen/pycodegen/full
script
:
script
:
-
env
-
pip list
-
pip list
-
export NUM_CORES=$(nproc --all)
-
export NUM_CORES=$(nproc --all)
-
mkdir -p ~/.config/matplotlib
-
mkdir -p ~/.config/matplotlib
...
...
conftest.py
View file @
82af488a
...
@@ -152,6 +152,7 @@ class IPyNbFile(pytest.File):
...
@@ -152,6 +152,7 @@ class IPyNbFile(pytest.File):
notebook
=
nbformat
.
read
(
notebook_contents
,
4
)
notebook
=
nbformat
.
read
(
notebook_contents
,
4
)
code
,
_
=
exporter
.
from_notebook_node
(
notebook
)
code
,
_
=
exporter
.
from_notebook_node
(
notebook
)
yield
IPyNbTest
(
self
.
name
,
self
,
code
)
yield
IPyNbTest
(
self
.
name
,
self
,
code
)
# pytest v 2.4>: yield IPyNbTest.from_parent(name=self.name, parent=self, code=code)
def
teardown
(
self
):
def
teardown
(
self
):
pass
pass
...
@@ -161,3 +162,4 @@ def pytest_collect_file(path, parent):
...
@@ -161,3 +162,4 @@ def pytest_collect_file(path, parent):
glob_exprs
=
[
"*demo*.ipynb"
,
"*tutorial*.ipynb"
,
"test_*.ipynb"
]
glob_exprs
=
[
"*demo*.ipynb"
,
"*tutorial*.ipynb"
,
"test_*.ipynb"
]
if
any
(
path
.
fnmatch
(
g
)
for
g
in
glob_exprs
):
if
any
(
path
.
fnmatch
(
g
)
for
g
in
glob_exprs
):
return
IPyNbFile
(
path
,
parent
)
return
IPyNbFile
(
path
,
parent
)
# pytest v 2.4 >: return IPyNbFile.from_parent(fspath=path, parent=parent)
pystencils/assignment.py
View file @
82af488a
...
@@ -53,7 +53,7 @@ else:
...
@@ -53,7 +53,7 @@ else:
# Tuple of things that can be on the lhs of an assignment
# Tuple of things that can be on the lhs of an assignment
assignable
=
(
sp
.
Symbol
,
MatrixSymbol
,
MatrixElement
,
sp
.
Indexed
)
assignable
=
(
sp
.
Symbol
,
MatrixSymbol
,
MatrixElement
,
sp
.
Indexed
)
if
not
isinstance
(
lhs
,
assignable
):
if
not
isinstance
(
lhs
,
assignable
):
raise
TypeError
(
"Cannot assign to lhs of type
%s."
%
type
(
lhs
))
raise
TypeError
(
f
"Cannot assign to lhs of type
{
type
(
lhs
)
}
."
)
return
sp
.
Rel
.
__new__
(
cls
,
lhs
,
rhs
,
**
assumptions
)
return
sp
.
Rel
.
__new__
(
cls
,
lhs
,
rhs
,
**
assumptions
)
__str__
=
assignment_str
__str__
=
assignment_str
...
...
pystencils/astnodes.py
View file @
82af488a
...
@@ -113,14 +113,14 @@ class Conditional(Node):
...
@@ -113,14 +113,14 @@ class Conditional(Node):
return
self
.
__repr__
()
return
self
.
__repr__
()
def
__repr__
(
self
):
def
__repr__
(
self
):
re
pr
=
'if:({
!r}) '
.
format
(
self
.
condition_expr
)
re
sult
=
f
'if:(
{
self
.
condition_expr
!
r
}
) '
if
self
.
true_block
:
if
self
.
true_block
:
re
pr
+=
'
\n\t
{
}) '
.
format
(
self
.
true_block
)
re
sult
+=
f
'
\n\t
{
self
.
true_block
}
) '
if
self
.
false_block
:
if
self
.
false_block
:
re
pr
=
'else: '
.
format
(
self
.
false_block
)
re
sult
=
'else: '
re
pr
+=
'
\n\t
{
} '
.
format
(
self
.
false_block
)
re
sult
+=
f
'
\n\t
{
self
.
false_block
}
'
return
re
pr
return
re
sult
def
replace_by_true_block
(
self
):
def
replace_by_true_block
(
self
):
"""Replaces the conditional by its True block"""
"""Replaces the conditional by its True block"""
...
@@ -264,7 +264,7 @@ class KernelFunction(Node):
...
@@ -264,7 +264,7 @@ class KernelFunction(Node):
def
__repr__
(
self
):
def
__repr__
(
self
):
params
=
[
p
.
symbol
for
p
in
self
.
get_parameters
()]
params
=
[
p
.
symbol
for
p
in
self
.
get_parameters
()]
return
'{
0} {1}({2})'
.
format
(
type
(
self
).
__name__
,
self
.
function_name
,
params
)
return
f
'
{
type
(
self
).
__name__
}
{
self
.
function_name
}
(
{
params
}
)'
def
compile
(
self
,
*
args
,
**
kwargs
):
def
compile
(
self
,
*
args
,
**
kwargs
):
if
self
.
_compile_function
is
None
:
if
self
.
_compile_function
is
None
:
...
@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node):
...
@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node):
@
staticmethod
@
staticmethod
def
get_loop_counter_name
(
coordinate_to_loop_over
):
def
get_loop_counter_name
(
coordinate_to_loop_over
):
return
"%s_%s"
%
(
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
,
coordinate_to_loop_over
)
return
f
"
{
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
}
_
{
coordinate_to_loop_over
}
"
@
staticmethod
@
staticmethod
def
get_block_loop_counter_name
(
coordinate_to_loop_over
):
def
get_block_loop_counter_name
(
coordinate_to_loop_over
):
return
"%s_%s"
%
(
LoopOverCoordinate
.
BlOCK_LOOP_COUNTER_NAME_PREFIX
,
coordinate_to_loop_over
)
return
f
"
{
LoopOverCoordinate
.
BlOCK_LOOP_COUNTER_NAME_PREFIX
}
_
{
coordinate_to_loop_over
}
"
@
property
@
property
def
loop_counter_name
(
self
):
def
loop_counter_name
(
self
):
...
@@ -612,7 +612,7 @@ class SympyAssignment(Node):
...
@@ -612,7 +612,7 @@ class SympyAssignment(Node):
replacement
.
parent
=
self
replacement
.
parent
=
self
self
.
rhs
=
replacement
self
.
rhs
=
replacement
else
:
else
:
raise
ValueError
(
'%s
is not in args of
%s'
%
(
replacement
,
self
.
__class__
)
)
raise
ValueError
(
f
'
{
replacement
}
is not in args of
{
self
.
__class__
}
'
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
repr
(
self
.
lhs
)
+
" ← "
+
repr
(
self
.
rhs
)
return
repr
(
self
.
lhs
)
+
" ← "
+
repr
(
self
.
rhs
)
...
@@ -620,7 +620,7 @@ class SympyAssignment(Node):
...
@@ -620,7 +620,7 @@ class SympyAssignment(Node):
def
_repr_html_
(
self
):
def
_repr_html_
(
self
):
printed_lhs
=
sp
.
latex
(
self
.
lhs
)
printed_lhs
=
sp
.
latex
(
self
.
lhs
)
printed_rhs
=
sp
.
latex
(
self
.
rhs
)
printed_rhs
=
sp
.
latex
(
self
.
rhs
)
return
"${printed_lhs}
\\
leftarrow {printed_rhs}$"
.
format
(
printed_lhs
=
printed_lhs
,
printed_rhs
=
printed_rhs
)
return
f
"$
{
printed_lhs
}
\\
leftarrow
{
printed_rhs
}
$"
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
((
self
.
lhs
,
self
.
rhs
))
return
hash
((
self
.
lhs
,
self
.
rhs
))
...
@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed):
...
@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed):
def
__str__
(
self
):
def
__str__
(
self
):
top
=
super
(
ResolvedFieldAccess
,
self
).
__str__
()
top
=
super
(
ResolvedFieldAccess
,
self
).
__str__
()
return
"%s (%s)"
%
(
top
,
self
.
typed_symbol
.
dtype
)
return
f
"
{
top
}
(
{
self
.
typed_symbol
.
dtype
}
)"
def
__getnewargs__
(
self
):
def
__getnewargs__
(
self
):
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx_coordinate_values
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx_coordinate_values
...
@@ -740,7 +740,7 @@ def early_out(condition):
...
@@ -740,7 +740,7 @@ def early_out(condition):
def
get_dummy_symbol
(
dtype
=
'bool'
):
def
get_dummy_symbol
(
dtype
=
'bool'
):
return
TypedSymbol
(
'dummy
%s'
%
uuid
.
uuid4
().
hex
,
create_type
(
dtype
))
return
TypedSymbol
(
f
'dummy
{
uuid
.
uuid4
().
hex
}
'
,
create_type
(
dtype
))
class
SourceCodeComment
(
Node
):
class
SourceCodeComment
(
Node
):
...
...
pystencils/backends/cbackend.py
View file @
82af488a
...
@@ -158,7 +158,7 @@ class CustomCodeNode(Node):
...
@@ -158,7 +158,7 @@ class CustomCodeNode(Node):
class
PrintNode
(
CustomCodeNode
):
class
PrintNode
(
CustomCodeNode
):
# noinspection SpellCheckingInspection
# noinspection SpellCheckingInspection
def
__init__
(
self
,
symbol_to_print
):
def
__init__
(
self
,
symbol_to_print
):
code
=
'
\n
std::cout << "
%s = " << %s << std::endl;
\n
'
%
(
symbol_to_print
.
name
,
symbol_to_print
.
name
)
code
=
f
'
\n
std::cout << "
{
symbol_to_print
.
name
}
= " <<
{
symbol_to_print
.
name
}
<< std::endl;
\n
'
super
(
PrintNode
,
self
).
__init__
(
code
,
symbols_read
=
[
symbol_to_print
],
symbols_defined
=
set
())
super
(
PrintNode
,
self
).
__init__
(
code
,
symbols_read
=
[
symbol_to_print
],
symbols_defined
=
set
())
self
.
headers
.
append
(
"<iostream>"
)
self
.
headers
.
append
(
"<iostream>"
)
...
@@ -203,12 +203,12 @@ class CBackend:
...
@@ -203,12 +203,12 @@ class CBackend:
return
str
(
node
)
return
str
(
node
)
def
_print_KernelFunction
(
self
,
node
):
def
_print_KernelFunction
(
self
,
node
):
function_arguments
=
[
"%s %s"
%
(
self
.
_print
(
s
.
symbol
.
dtype
)
,
s
.
symbol
.
name
)
for
s
in
node
.
get_parameters
()]
function_arguments
=
[
f
"
{
self
.
_print
(
s
.
symbol
.
dtype
)
}
{
s
.
symbol
.
name
}
"
for
s
in
node
.
get_parameters
()]
launch_bounds
=
""
launch_bounds
=
""
if
self
.
_dialect
==
'cuda'
:
if
self
.
_dialect
==
'cuda'
:
max_threads
=
node
.
indexing
.
max_threads_per_block
()
max_threads
=
node
.
indexing
.
max_threads_per_block
()
if
max_threads
:
if
max_threads
:
launch_bounds
=
"__launch_bounds__({
}) "
.
format
(
max_threads
)
launch_bounds
=
f
"__launch_bounds__(
{
max_threads
}
) "
func_declaration
=
"FUNC_PREFIX %svoid %s(%s)"
%
(
launch_bounds
,
node
.
function_name
,
func_declaration
=
"FUNC_PREFIX %svoid %s(%s)"
%
(
launch_bounds
,
node
.
function_name
,
", "
.
join
(
function_arguments
))
", "
.
join
(
function_arguments
))
if
self
.
_signatureOnly
:
if
self
.
_signatureOnly
:
...
@@ -222,19 +222,19 @@ class CBackend:
...
@@ -222,19 +222,19 @@ class CBackend:
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block_contents
.
splitlines
(
True
)))
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block_contents
.
splitlines
(
True
)))
def
_print_PragmaBlock
(
self
,
node
):
def
_print_PragmaBlock
(
self
,
node
):
return
"%s
\n
%s"
%
(
node
.
pragma_line
,
self
.
_print_Block
(
node
)
)
return
f
"
{
node
.
pragma_line
}
\n
{
self
.
_print_Block
(
node
)
}
"
def
_print_LoopOverCoordinate
(
self
,
node
):
def
_print_LoopOverCoordinate
(
self
,
node
):
counter_symbol
=
node
.
loop_counter_name
counter_symbol
=
node
.
loop_counter_name
start
=
"int
%s = %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
start
)
)
start
=
f
"int
{
counter_symbol
}
=
{
self
.
sympy_printer
.
doprint
(
node
.
start
)
}
"
condition
=
"%s < %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
stop
)
)
condition
=
f
"
{
counter_symbol
}
<
{
self
.
sympy_printer
.
doprint
(
node
.
stop
)
}
"
update
=
"%s += %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
step
)
,)
update
=
f
"
{
counter_symbol
}
+=
{
self
.
sympy_printer
.
doprint
(
node
.
step
)
}
"
loop_str
=
"for (
%s; %s; %s)"
%
(
start
,
condition
,
update
)
loop_str
=
f
"for (
{
start
}
;
{
condition
}
;
{
update
}
)"
prefix
=
"
\n
"
.
join
(
node
.
prefix_lines
)
prefix
=
"
\n
"
.
join
(
node
.
prefix_lines
)
if
prefix
:
if
prefix
:
prefix
+=
"
\n
"
prefix
+=
"
\n
"
return
"%s%s
\n
%s"
%
(
prefix
,
loop_str
,
self
.
_print
(
node
.
body
)
)
return
f
"
{
prefix
}{
loop_str
}
\n
{
self
.
_print
(
node
.
body
)
}
"
def
_print_SympyAssignment
(
self
,
node
):
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
if
node
.
is_declaration
:
...
@@ -262,7 +262,7 @@ class CBackend:
...
@@ -262,7 +262,7 @@ class CBackend:
instr
=
'maskStore'
if
aligned
else
'maskStoreU'
instr
=
'maskStore'
if
aligned
else
'maskStoreU'
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
if
self
.
_vector_instruction_set
[
'dataTypePrefix'
][
'double'
]
==
'__mm256d'
:
if
self
.
_vector_instruction_set
[
'dataTypePrefix'
][
'double'
]
==
'__mm256d'
:
printed_mask
=
"_mm256_castpd_si256({
})"
.
format
(
printed_mask
)
printed_mask
=
f
"_mm256_castpd_si256(
{
printed_mask
}
)"
rhs_type
=
get_type_of_expression
(
node
.
rhs
)
rhs_type
=
get_type_of_expression
(
node
.
rhs
)
if
type
(
rhs_type
)
is
not
VectorType
:
if
type
(
rhs_type
)
is
not
VectorType
:
...
@@ -274,7 +274,7 @@ class CBackend:
...
@@ -274,7 +274,7 @@ class CBackend:
self
.
sympy_printer
.
doprint
(
rhs
),
self
.
sympy_printer
.
doprint
(
rhs
),
printed_mask
)
+
';'
printed_mask
)
+
';'
else
:
else
:
return
"%s = %s;"
%
(
self
.
sympy_printer
.
doprint
(
node
.
lhs
)
,
self
.
sympy_printer
.
doprint
(
node
.
rhs
)
)
return
f
"
{
self
.
sympy_printer
.
doprint
(
node
.
lhs
)
}
=
{
self
.
sympy_printer
.
doprint
(
node
.
rhs
)
}
;"
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
align
=
64
align
=
64
...
@@ -314,7 +314,7 @@ class CBackend:
...
@@ -314,7 +314,7 @@ class CBackend:
raise
ValueError
(
"Problem with Conditional inside vectorized loop - use vec_any or vec_all"
)
raise
ValueError
(
"Problem with Conditional inside vectorized loop - use vec_any or vec_all"
)
condition_expr
=
self
.
sympy_printer
.
doprint
(
node
.
condition_expr
)
condition_expr
=
self
.
sympy_printer
.
doprint
(
node
.
condition_expr
)
true_block
=
self
.
_print_Block
(
node
.
true_block
)
true_block
=
self
.
_print_Block
(
node
.
true_block
)
result
=
"if (
%s)
\n
%s "
%
(
condition_expr
,
true_block
)
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
if
node
.
false_block
:
if
node
.
false_block
:
false_block
=
self
.
_print_Block
(
node
.
false_block
)
false_block
=
self
.
_print_Block
(
node
.
false_block
)
result
+=
"else "
+
false_block
result
+=
"else "
+
false_block
...
@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter):
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
"1 / ({
})"
.
format
(
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
(
-
expr
.
exp
),
evaluate
=
False
))
)
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
(
[
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
else
:
else
:
return
super
(
CustomSympyPrinter
,
self
).
_print_Pow
(
expr
)
return
super
(
CustomSympyPrinter
,
self
).
_print_Pow
(
expr
)
...
@@ -362,10 +362,10 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -362,10 +362,10 @@ class CustomSympyPrinter(CCodePrinter):
return
result
.
replace
(
"
\n
"
,
""
)
return
result
.
replace
(
"
\n
"
,
""
)
def
_print_Abs
(
self
,
expr
):
def
_print_Abs
(
self
,
expr
):
if
expr
.
is_integer
:
if
expr
.
args
[
0
].
is_integer
:
return
'abs({
0})'
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
'abs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)'
else
:
else
:
return
'fabs({
0})'
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
'fabs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)'
def
_print_Type
(
self
,
node
):
def
_print_Type
(
self
,
node
):
return
str
(
node
)
return
str
(
node
)
...
@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter):
return
expr
.
to_c
(
self
.
_print
)
return
expr
.
to_c
(
self
.
_print
)
if
isinstance
(
expr
,
reinterpret_cast_func
):
if
isinstance
(
expr
,
reinterpret_cast_func
):
arg
,
data_type
=
expr
.
args
arg
,
data_type
=
expr
.
args
return
"*((
%s)(& %s))"
%
(
self
.
_print
(
PointerType
(
data_type
,
restrict
=
False
))
,
self
.
_print
(
arg
)
)
return
f
"*((
{
self
.
_print
(
PointerType
(
data_type
,
restrict
=
False
))
}
)(&
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
address_of
):
elif
isinstance
(
expr
,
address_of
):
assert
len
(
expr
.
args
)
==
1
,
"address_of must only have one argument"
assert
len
(
expr
.
args
)
==
1
,
"address_of must only have one argument"
return
"&(
%s)"
%
self
.
_print
(
expr
.
args
[
0
])
return
f
"&(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
cast_func
):
elif
isinstance
(
expr
,
cast_func
):
arg
,
data_type
=
expr
.
args
arg
,
data_type
=
expr
.
args
if
isinstance
(
arg
,
sp
.
Number
)
and
arg
.
is_finite
:
if
isinstance
(
arg
,
sp
.
Number
)
and
arg
.
is_finite
:
return
self
.
_typed_number
(
arg
,
data_type
)
return
self
.
_typed_number
(
arg
,
data_type
)
else
:
else
:
return
"((
%s)(%s))"
%
(
data_type
,
self
.
_print
(
arg
)
)
return
f
"((
{
data_type
}
)(
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
fast_division
):
elif
isinstance
(
expr
,
fast_division
):
return
"({
})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
}
)"
elif
isinstance
(
expr
,
fast_sqrt
):
elif
isinstance
(
expr
,
fast_sqrt
):
return
"({
})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
return
self
.
_print
(
expr
.
args
[
0
])
return
self
.
_print
(
expr
.
args
[
0
])
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"({
})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
sp
.
Abs
):
elif
isinstance
(
expr
,
sp
.
Abs
):
return
"abs({
})"
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
"abs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
sp
.
Mod
):
elif
isinstance
(
expr
,
sp
.
Mod
):
if
expr
.
args
[
0
].
is_integer
and
expr
.
args
[
1
].
is_integer
:
if
expr
.
args
[
0
].
is_integer
and
expr
.
args
[
1
].
is_integer
:
return
"({
} % {})"
.
format
(
self
.
_print
(
expr
.
args
[
0
])
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
])
}
%
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
else
:
else
:
return
"fmod({
}, {})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"fmod(
{
self
.
_print
(
expr
.
args
[
0
])
}
,
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
elif
expr
.
func
in
infix_functions
:
elif
expr
.
func
in
infix_functions
:
return
"(
%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
])
,
infix_functions
[
expr
.
func
]
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
])
}
{
infix_functions
[
expr
.
func
]
}
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
elif
expr
.
func
==
int_power_of_2
:
elif
expr
.
func
==
int_power_of_2
:
return
"(1 << (
%s))"
%
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
"(1 << (
{
self
.
_print
(
expr
.
args
[
0
])
}
))"
elif
expr
.
func
==
int_div
:
elif
expr
.
func
==
int_div
:
return
"((
%s) / (%s))"
%
(
self
.
_print
(
expr
.
args
[
0
])
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"((
{
self
.
_print
(
expr
.
args
[
0
])
}
) / (
{
self
.
_print
(
expr
.
args
[
1
])
}
))"
else
:
else
:
name
=
expr
.
name
if
hasattr
(
expr
,
'name'
)
else
expr
.
__class__
.
__name__
name
=
expr
.
name
if
hasattr
(
expr
,
'name'
)
else
expr
.
__class__
.
__name__
arg_str
=
', '
.
join
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
arg_str
=
', '
.
join
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
...
@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
...
@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
return
result
return
result
elif
expr
.
func
==
fast_sqrt
:
elif
expr
.
func
==
fast_sqrt
:
return
"({
})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
expr
.
func
==
fast_inv_sqrt
:
elif
expr
.
func
==
fast_inv_sqrt
:
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
if
not
result
:
if
self
.
instruction_set
[
'rsqrt'
]:
if
self
.
instruction_set
[
'rsqrt'
]:
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
else
:
else
:
return
"({
})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
):
elif
isinstance
(
expr
,
vec_any
):
expr_type
=
get_type_of_expression
(
expr
.
args
[
0
])
expr_type
=
get_type_of_expression
(
expr
.
args
[
0
])
if
type
(
expr_type
)
is
not
VectorType
:
if
type
(
expr_type
)
is
not
VectorType
:
...
...
pystencils/backends/cuda_backend.py
View file @
82af488a
...
@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
...
@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if
isinstance
(
expr
,
fast_division
):
if
isinstance
(
expr
,
fast_division
):
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
elif
isinstance
(
expr
,
fast_sqrt
):
return
"__fsqrt_rn(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"__fsqrt_rn(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"__frsqrt_rn(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"__frsqrt_rn(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
return
super
().
_print_Function
(
expr
)
return
super
().
_print_Function
(
expr
)
pystencils/backends/dot.py
View file @
82af488a
...
@@ -57,7 +57,7 @@ def __shortened(node):
...
@@ -57,7 +57,7 @@ def __shortened(node):
params
=
node
.
get_parameters
()
params
=
node
.
get_parameters
()
param_names
=
[
p
.
field_name
for
p
in
params
if
p
.
is_field_pointer
]
param_names
=
[
p
.
field_name
for
p
in
params
if
p
.
is_field_pointer
]
param_names
+=
[
p
.
symbol
.
name
for
p
in
params
if
not
p
.
is_field_parameter
]
param_names
+=
[
p
.
symbol
.
name
for
p
in
params
if
not
p
.
is_field_parameter
]
return
"Func:
%s (%s)"
%
(
node
.
function_name
,
","
.
join
(
param_names
)
)
return
f
"Func:
{
node
.
function_name
}
(
{
','
.
join
(
param_names
)
}
)"
elif
isinstance
(
node
,
SympyAssignment
):
elif
isinstance
(
node
,
SympyAssignment
):
return
repr
(
node
.
lhs
)
return
repr
(
node
.
lhs
)
elif
isinstance
(
node
,
Block
):
elif
isinstance
(
node
,
Block
):
...
@@ -65,7 +65,7 @@ def __shortened(node):
...
@@ -65,7 +65,7 @@ def __shortened(node):
elif
isinstance
(
node
,
Conditional
):
elif
isinstance
(
node
,
Conditional
):
return
repr
(
node
)
return
repr
(
node
)
else
:
else
:
raise
NotImplementedError
(
"Cannot handle node type
%s"
%
(
type
(
node
)
,)
)
raise
NotImplementedError
(
f
"Cannot handle node type
{
type
(
node
)
}
"
)
def
print_dot
(
node
,
view
=
False
,
short
=
False
,
**
kwargs
):
def
print_dot
(
node
,
view
=
False
,
short
=
False
,
**
kwargs
):
...
...
pystencils/backends/opencl_backend.py
View file @
82af488a
...
@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
...
@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
if
isinstance
(
expr
,
fast_division
):
if
isinstance
(
expr
,
fast_division
):
return
"native_divide(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
"native_divide(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
elif
isinstance
(
expr
,
fast_sqrt
):
return
"native_sqrt(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"native_sqrt(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"native_rsqrt(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"native_rsqrt(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
return
CustomSympyPrinter
.
_print_Function
(
self
,
expr
)
return
CustomSympyPrinter
.
_print_Function
(
self
,
expr
)
pystencils/backends/simd_instruction_sets.py
View file @
82af488a
...
@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
...
@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
})
})
for
comparison_op
,
constant
in
comparisons
.
items
():
for
comparison_op
,
constant
in
comparisons
.
items
():
base_names
[
comparison_op
]
=
'cmp[0, 1,
%s]'
%
(
constant
,)
base_names
[
comparison_op
]
=
f
'cmp[0, 1,
{
constant
}
]'
headers
=
{
headers
=
{
'avx512'
:
[
'<immintrin.h>'
],
'avx512'
:
[
'<immintrin.h>'
],
...
@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
...
@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
name
=
function_shortcut
[:
function_shortcut
.
index
(
'['
)]
name
=
function_shortcut
[:
function_shortcut
.
index
(
'['
)]
if
intrinsic_id
==
'makeVecConst'
:
if
intrinsic_id
==
'makeVecConst'
:
arg_string
=
"({
})"
.
format
(
","
.
join
([
"
{0}
"
]
*
result
[
'width'
])
)
arg_string
=
f
"(
{
','
.
join
([
'
{
0
}
'
] * result['
width
'])
}
)"
elif
intrinsic_id
==
'makeVec'
:
elif
intrinsic_id
==
'makeVec'
:
params
=
[
"{"
+
str
(
i
)
+
"}"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
params
=
[
"{"
+
str
(
i
)
+
"}"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
elif
intrinsic_id
==
'makeVecBool'
:
elif
intrinsic_id
==
'makeVecBool'
:
params
=
[
"(({{{i}}} ? -1.0 : 0.0)"
.
format
(
i
=
i
)
for
i
in
reversed
(
range
(
result
[
'width'
]))]
params
=
[
f
"(({{
{
i
}
}} ? -1.0 : 0.0)"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
elif
intrinsic_id
==
'makeVecConstBool'
:
elif
intrinsic_id
==
'makeVecConstBool'
:
params
=
[
"(({0}) ? -1.0 : 0.0)"
for
_
in
range
(
result
[
'width'
])]
params
=
[
"(({0}) ? -1.0 : 0.0)"
for
_
in
range
(
result
[
'width'
])]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
else
:
else
:
args
=
function_shortcut
[
function_shortcut
.
index
(
'['
)
+
1
:
-
1
]
args
=
function_shortcut
[
function_shortcut
.
index
(
'['
)
+
1
:
-
1
]
arg_string
=
"("
arg_string
=
"("
...
@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
...
@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result
[
'bool'
]
=
"__mmask%d"
%
(
size
,)
result
[
'bool'
]
=
"__mmask%d"
%
(
size
,)
params
=
" | "
.
join
([
"({{{i}}} ? {power} : 0)"
.
format
(
i
=
i
,
power
=
2
**
i
)
for
i
in
range
(
8
)])
params
=
" | "
.
join
([
"({{{i}}} ? {power} : 0)"
.
format
(
i
=
i
,
power
=
2
**
i
)
for
i
in
range
(
8
)])
result
[
'makeVecBool'
]
=
"__mmask8(({
}) )"
.
format
(
params
)
result
[
'makeVecBool'
]
=
f
"__mmask8((
{
params
}
) )"
params
=
" | "
.
join
([
"({{0}} ? {power} : 0)"
.
format
(
power
=
2
**
i
)
for
i
in
range
(
8
)])
params
=
" | "
.
join
([
"({{0}} ? {power} : 0)"
.
format
(
power
=
2
**
i
)
for
i
in
range
(
8
)])
result
[
'makeVecConstBool'
]
=
"__mmask8(({
}) )"
.
format
(
params
)
result
[
'makeVecConstBool'
]
=
f
"__mmask8((
{
params
}
) )"
if
instruction_set
==
'avx'
and
data_type
==
'float'
:
if
instruction_set
==
'avx'
and
data_type
==
'float'
:
result
[
'rsqrt'
]
=
"_mm256_rsqrt_ps({0})"
result
[
'rsqrt'
]
=
"_mm256_rsqrt_ps({0})"
...
...
pystencils/boundaries/boundaryhandling.py
View file @
82af488a
...
@@ -66,13 +66,13 @@ class FlagInterface:
...
@@ -66,13 +66,13 @@ class FlagInterface:
self
.
_used_flags
.
add
(
flag
)
self
.
_used_flags
.
add
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
return
flag
return
flag
raise
ValueError
(
"All available {} flags are reserved"
.
format
(
self
.
max_bits
)
)
raise
ValueError
(
f
"All available
{
self
.
max_bits
}
flags are reserved"
)
def
reserve_flag
(
self
,
flag
):
def
reserve_flag
(
self
,
flag
):
assert
self
.
_is_power_of_2
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
flag
=
self
.
dtype
(
flag
)
flag
=
self
.
dtype
(
flag
)
if
flag
in
self
.
_used_flags
:
if
flag
in
self
.
_used_flags
:
raise
ValueError
(
"The flag {flag} is already reserved"
.
format
(
flag
=
flag
)
)
raise
ValueError
(
f
"The flag
{
flag
}
is already reserved"
)
self
.
_used_flags
.
add
(
flag
)
self
.
_used_flags
.
add
(
flag
)
return
flag
return
flag