Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
bced838a
Commit
bced838a
authored
Jun 19, 2020
by
Markus Holzer
Browse files
Replaced all format strings with f-strings
parent
c83f4f5d
Changes
35
Hide whitespace changes
Inline
Side-by-side
pystencils/assignment.py
View file @
bced838a
...
@@ -53,7 +53,7 @@ else:
...
@@ -53,7 +53,7 @@ else:
# Tuple of things that can be on the lhs of an assignment
# Tuple of things that can be on the lhs of an assignment
assignable
=
(
sp
.
Symbol
,
MatrixSymbol
,
MatrixElement
,
sp
.
Indexed
)
assignable
=
(
sp
.
Symbol
,
MatrixSymbol
,
MatrixElement
,
sp
.
Indexed
)
if
not
isinstance
(
lhs
,
assignable
):
if
not
isinstance
(
lhs
,
assignable
):
raise
TypeError
(
"Cannot assign to lhs of type
%s."
%
type
(
lhs
))
raise
TypeError
(
f
"Cannot assign to lhs of type
{
type
(
lhs
)
}
."
)
return
sp
.
Rel
.
__new__
(
cls
,
lhs
,
rhs
,
**
assumptions
)
return
sp
.
Rel
.
__new__
(
cls
,
lhs
,
rhs
,
**
assumptions
)
__str__
=
assignment_str
__str__
=
assignment_str
...
...
pystencils/astnodes.py
View file @
bced838a
...
@@ -113,12 +113,12 @@ class Conditional(Node):
...
@@ -113,12 +113,12 @@ class Conditional(Node):
return
self
.
__repr__
()
return
self
.
__repr__
()
def
__repr__
(
self
):
def
__repr__
(
self
):
repr
=
'if:({
!r}) '
.
format
(
self
.
condition_expr
)
repr
=
f
'if:(
{
self
.
condition_expr
!r}
) '
if
self
.
true_block
:
if
self
.
true_block
:
repr
+=
'
\n\t
{
}) '
.
format
(
self
.
true_block
)
repr
+=
f
'
\n\t
{
self
.
true_block
}
) '
if
self
.
false_block
:
if
self
.
false_block
:
repr
=
'else: '
repr
=
'else: '
repr
+=
'
\n\t
{
} '
.
format
(
self
.
false_block
)
repr
+=
f
'
\n\t
{
self
.
false_block
}
'
return
repr
return
repr
...
@@ -264,7 +264,7 @@ class KernelFunction(Node):
...
@@ -264,7 +264,7 @@ class KernelFunction(Node):
def
__repr__
(
self
):
def
__repr__
(
self
):
params
=
[
p
.
symbol
for
p
in
self
.
get_parameters
()]
params
=
[
p
.
symbol
for
p
in
self
.
get_parameters
()]
return
'{
0} {1}({2})'
.
format
(
type
(
self
).
__name__
,
self
.
function_name
,
params
)
return
f
'
{
type
(
self
).
__name__
}
{
self
.
function_name
}
(
{
params
}
)'
def
compile
(
self
,
*
args
,
**
kwargs
):
def
compile
(
self
,
*
args
,
**
kwargs
):
if
self
.
_compile_function
is
None
:
if
self
.
_compile_function
is
None
:
...
@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node):
...
@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node):
@
staticmethod
@
staticmethod
def
get_loop_counter_name
(
coordinate_to_loop_over
):
def
get_loop_counter_name
(
coordinate_to_loop_over
):
return
"%s_%s"
%
(
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
,
coordinate_to_loop_over
)
return
f
"
{
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
}
_
{
coordinate_to_loop_over
}
"
@
staticmethod
@
staticmethod
def
get_block_loop_counter_name
(
coordinate_to_loop_over
):
def
get_block_loop_counter_name
(
coordinate_to_loop_over
):
return
"%s_%s"
%
(
LoopOverCoordinate
.
BlOCK_LOOP_COUNTER_NAME_PREFIX
,
coordinate_to_loop_over
)
return
f
"
{
LoopOverCoordinate
.
BlOCK_LOOP_COUNTER_NAME_PREFIX
}
_
{
coordinate_to_loop_over
}
"
@
property
@
property
def
loop_counter_name
(
self
):
def
loop_counter_name
(
self
):
...
@@ -612,7 +612,7 @@ class SympyAssignment(Node):
...
@@ -612,7 +612,7 @@ class SympyAssignment(Node):
replacement
.
parent
=
self
replacement
.
parent
=
self
self
.
rhs
=
replacement
self
.
rhs
=
replacement
else
:
else
:
raise
ValueError
(
'%s
is not in args of
%s'
%
(
replacement
,
self
.
__class__
)
)
raise
ValueError
(
f
'
{
replacement
}
is not in args of
{
self
.
__class__
}
'
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
repr
(
self
.
lhs
)
+
" ← "
+
repr
(
self
.
rhs
)
return
repr
(
self
.
lhs
)
+
" ← "
+
repr
(
self
.
rhs
)
...
@@ -620,7 +620,7 @@ class SympyAssignment(Node):
...
@@ -620,7 +620,7 @@ class SympyAssignment(Node):
def
_repr_html_
(
self
):
def
_repr_html_
(
self
):
printed_lhs
=
sp
.
latex
(
self
.
lhs
)
printed_lhs
=
sp
.
latex
(
self
.
lhs
)
printed_rhs
=
sp
.
latex
(
self
.
rhs
)
printed_rhs
=
sp
.
latex
(
self
.
rhs
)
return
"${printed_lhs}
\\
leftarrow {printed_rhs}$"
.
format
(
printed_lhs
=
printed_lhs
,
printed_rhs
=
printed_rhs
)
return
f
"$
{
printed_lhs
}
\\
leftarrow
{
printed_rhs
}
$"
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
((
self
.
lhs
,
self
.
rhs
))
return
hash
((
self
.
lhs
,
self
.
rhs
))
...
@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed):
...
@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed):
def
__str__
(
self
):
def
__str__
(
self
):
top
=
super
(
ResolvedFieldAccess
,
self
).
__str__
()
top
=
super
(
ResolvedFieldAccess
,
self
).
__str__
()
return
"%s (%s)"
%
(
top
,
self
.
typed_symbol
.
dtype
)
return
f
"
{
top
}
(
{
self
.
typed_symbol
.
dtype
}
)"
def
__getnewargs__
(
self
):
def
__getnewargs__
(
self
):
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx_coordinate_values
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx_coordinate_values
...
@@ -740,7 +740,7 @@ def early_out(condition):
...
@@ -740,7 +740,7 @@ def early_out(condition):
def
get_dummy_symbol
(
dtype
=
'bool'
):
def
get_dummy_symbol
(
dtype
=
'bool'
):
return
TypedSymbol
(
'dummy
%s'
%
uuid
.
uuid4
().
hex
,
create_type
(
dtype
))
return
TypedSymbol
(
f
'dummy
{
uuid
.
uuid4
().
hex
}
'
,
create_type
(
dtype
))
class
SourceCodeComment
(
Node
):
class
SourceCodeComment
(
Node
):
...
...
pystencils/backends/cbackend.py
View file @
bced838a
...
@@ -158,7 +158,7 @@ class CustomCodeNode(Node):
...
@@ -158,7 +158,7 @@ class CustomCodeNode(Node):
class
PrintNode
(
CustomCodeNode
):
class
PrintNode
(
CustomCodeNode
):
# noinspection SpellCheckingInspection
# noinspection SpellCheckingInspection
def
__init__
(
self
,
symbol_to_print
):
def
__init__
(
self
,
symbol_to_print
):
code
=
'
\n
std::cout << "
%s = " << %s << std::endl;
\n
'
%
(
symbol_to_print
.
name
,
symbol_to_print
.
name
)
code
=
f
'
\n
std::cout << "
{
symbol_to_print
.
name
}
= " <<
{
symbol_to_print
.
name
}
<< std::endl;
\n
'
super
(
PrintNode
,
self
).
__init__
(
code
,
symbols_read
=
[
symbol_to_print
],
symbols_defined
=
set
())
super
(
PrintNode
,
self
).
__init__
(
code
,
symbols_read
=
[
symbol_to_print
],
symbols_defined
=
set
())
self
.
headers
.
append
(
"<iostream>"
)
self
.
headers
.
append
(
"<iostream>"
)
...
@@ -203,12 +203,12 @@ class CBackend:
...
@@ -203,12 +203,12 @@ class CBackend:
return
str
(
node
)
return
str
(
node
)
def
_print_KernelFunction
(
self
,
node
):
def
_print_KernelFunction
(
self
,
node
):
function_arguments
=
[
"%s %s"
%
(
self
.
_print
(
s
.
symbol
.
dtype
)
,
s
.
symbol
.
name
)
for
s
in
node
.
get_parameters
()]
function_arguments
=
[
f
"
{
self
.
_print
(
s
.
symbol
.
dtype
)
}
{
s
.
symbol
.
name
}
"
for
s
in
node
.
get_parameters
()]
launch_bounds
=
""
launch_bounds
=
""
if
self
.
_dialect
==
'cuda'
:
if
self
.
_dialect
==
'cuda'
:
max_threads
=
node
.
indexing
.
max_threads_per_block
()
max_threads
=
node
.
indexing
.
max_threads_per_block
()
if
max_threads
:
if
max_threads
:
launch_bounds
=
"__launch_bounds__({
}) "
.
format
(
max_threads
)
launch_bounds
=
f
"__launch_bounds__(
{
max_threads
}
) "
func_declaration
=
"FUNC_PREFIX %svoid %s(%s)"
%
(
launch_bounds
,
node
.
function_name
,
func_declaration
=
"FUNC_PREFIX %svoid %s(%s)"
%
(
launch_bounds
,
node
.
function_name
,
", "
.
join
(
function_arguments
))
", "
.
join
(
function_arguments
))
if
self
.
_signatureOnly
:
if
self
.
_signatureOnly
:
...
@@ -222,19 +222,19 @@ class CBackend:
...
@@ -222,19 +222,19 @@ class CBackend:
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block_contents
.
splitlines
(
True
)))
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block_contents
.
splitlines
(
True
)))
def
_print_PragmaBlock
(
self
,
node
):
def
_print_PragmaBlock
(
self
,
node
):
return
"%s
\n
%s"
%
(
node
.
pragma_line
,
self
.
_print_Block
(
node
)
)
return
f
"
{
node
.
pragma_line
}
\n
{
self
.
_print_Block
(
node
)
}
"
def
_print_LoopOverCoordinate
(
self
,
node
):
def
_print_LoopOverCoordinate
(
self
,
node
):
counter_symbol
=
node
.
loop_counter_name
counter_symbol
=
node
.
loop_counter_name
start
=
"int
%s = %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
start
)
)
start
=
f
"int
{
counter_symbol
}
=
{
self
.
sympy_printer
.
doprint
(
node
.
start
)
}
"
condition
=
"%s < %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
stop
)
)
condition
=
f
"
{
counter_symbol
}
<
{
self
.
sympy_printer
.
doprint
(
node
.
stop
)
}
"
update
=
"%s += %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
step
)
,)
update
=
f
"
{
counter_symbol
}
+=
{
self
.
sympy_printer
.
doprint
(
node
.
step
)
}
"
loop_str
=
"for (
%s; %s; %s)"
%
(
start
,
condition
,
update
)
loop_str
=
f
"for (
{
start
}
;
{
condition
}
;
{
update
}
)"
prefix
=
"
\n
"
.
join
(
node
.
prefix_lines
)
prefix
=
"
\n
"
.
join
(
node
.
prefix_lines
)
if
prefix
:
if
prefix
:
prefix
+=
"
\n
"
prefix
+=
"
\n
"
return
"%s%s
\n
%s"
%
(
prefix
,
loop_str
,
self
.
_print
(
node
.
body
)
)
return
f
"
{
prefix
}{
loop_str
}
\n
{
self
.
_print
(
node
.
body
)
}
"
def
_print_SympyAssignment
(
self
,
node
):
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
if
node
.
is_declaration
:
...
@@ -262,7 +262,7 @@ class CBackend:
...
@@ -262,7 +262,7 @@ class CBackend:
instr
=
'maskStore'
if
aligned
else
'maskStoreU'
instr
=
'maskStore'
if
aligned
else
'maskStoreU'
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
if
self
.
_vector_instruction_set
[
'dataTypePrefix'
][
'double'
]
==
'__mm256d'
:
if
self
.
_vector_instruction_set
[
'dataTypePrefix'
][
'double'
]
==
'__mm256d'
:
printed_mask
=
"_mm256_castpd_si256({
})"
.
format
(
printed_mask
)
printed_mask
=
f
"_mm256_castpd_si256(
{
printed_mask
}
)"
rhs_type
=
get_type_of_expression
(
node
.
rhs
)
rhs_type
=
get_type_of_expression
(
node
.
rhs
)
if
type
(
rhs_type
)
is
not
VectorType
:
if
type
(
rhs_type
)
is
not
VectorType
:
...
@@ -274,7 +274,7 @@ class CBackend:
...
@@ -274,7 +274,7 @@ class CBackend:
self
.
sympy_printer
.
doprint
(
rhs
),
self
.
sympy_printer
.
doprint
(
rhs
),
printed_mask
)
+
';'
printed_mask
)
+
';'
else
:
else
:
return
"%s = %s;"
%
(
self
.
sympy_printer
.
doprint
(
node
.
lhs
)
,
self
.
sympy_printer
.
doprint
(
node
.
rhs
)
)
return
f
"
{
self
.
sympy_printer
.
doprint
(
node
.
lhs
)
}
=
{
self
.
sympy_printer
.
doprint
(
node
.
rhs
)
}
;"
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
align
=
64
align
=
64
...
@@ -314,7 +314,7 @@ class CBackend:
...
@@ -314,7 +314,7 @@ class CBackend:
raise
ValueError
(
"Problem with Conditional inside vectorized loop - use vec_any or vec_all"
)
raise
ValueError
(
"Problem with Conditional inside vectorized loop - use vec_any or vec_all"
)
condition_expr
=
self
.
sympy_printer
.
doprint
(
node
.
condition_expr
)
condition_expr
=
self
.
sympy_printer
.
doprint
(
node
.
condition_expr
)
true_block
=
self
.
_print_Block
(
node
.
true_block
)
true_block
=
self
.
_print_Block
(
node
.
true_block
)
result
=
"if (
%s)
\n
%s "
%
(
condition_expr
,
true_block
)
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
if
node
.
false_block
:
if
node
.
false_block
:
false_block
=
self
.
_print_Block
(
node
.
false_block
)
false_block
=
self
.
_print_Block
(
node
.
false_block
)
result
+=
"else "
+
false_block
result
+=
"else "
+
false_block
...
@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter):
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
"1 / ({
})"
.
format
(
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
(
-
expr
.
exp
),
evaluate
=
False
))
)
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
(
[
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
else
:
else
:
return
super
(
CustomSympyPrinter
,
self
).
_print_Pow
(
expr
)
return
super
(
CustomSympyPrinter
,
self
).
_print_Pow
(
expr
)
...
@@ -363,9 +363,9 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -363,9 +363,9 @@ class CustomSympyPrinter(CCodePrinter):
def
_print_Abs
(
self
,
expr
):
def
_print_Abs
(
self
,
expr
):
if
expr
.
args
[
0
].
is_integer
:
if
expr
.
args
[
0
].
is_integer
:
return
'abs({
0})'
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
'abs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)'
else
:
else
:
return
'fabs({
0})'
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
'fabs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)'
def
_print_Type
(
self
,
node
):
def
_print_Type
(
self
,
node
):
return
str
(
node
)
return
str
(
node
)
...
@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter):
...
@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter):
return
expr
.
to_c
(
self
.
_print
)
return
expr
.
to_c
(
self
.
_print
)
if
isinstance
(
expr
,
reinterpret_cast_func
):
if
isinstance
(
expr
,
reinterpret_cast_func
):
arg
,
data_type
=
expr
.
args
arg
,
data_type
=
expr
.
args
return
"*((
%s)(& %s))"
%
(
self
.
_print
(
PointerType
(
data_type
,
restrict
=
False
))
,
self
.
_print
(
arg
)
)
return
f
"*((
{
self
.
_print
(
PointerType
(
data_type
,
restrict
=
False
))
}
)(&
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
address_of
):
elif
isinstance
(
expr
,
address_of
):
assert
len
(
expr
.
args
)
==
1
,
"address_of must only have one argument"
assert
len
(
expr
.
args
)
==
1
,
"address_of must only have one argument"
return
"&(
%s)"
%
self
.
_print
(
expr
.
args
[
0
])
return
f
"&(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
cast_func
):
elif
isinstance
(
expr
,
cast_func
):
arg
,
data_type
=
expr
.
args
arg
,
data_type
=
expr
.
args
if
isinstance
(
arg
,
sp
.
Number
)
and
arg
.
is_finite
:
if
isinstance
(
arg
,
sp
.
Number
)
and
arg
.
is_finite
:
return
self
.
_typed_number
(
arg
,
data_type
)
return
self
.
_typed_number
(
arg
,
data_type
)
else
:
else
:
return
"((
%s)(%s))"
%
(
data_type
,
self
.
_print
(
arg
)
)
return
f
"((
{
data_type
}
)(
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
fast_division
):
elif
isinstance
(
expr
,
fast_division
):
return
"({
})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
}
)"
elif
isinstance
(
expr
,
fast_sqrt
):
elif
isinstance
(
expr
,
fast_sqrt
):
return
"({
})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
return
self
.
_print
(
expr
.
args
[
0
])
return
self
.
_print
(
expr
.
args
[
0
])
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"({
})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
sp
.
Abs
):
elif
isinstance
(
expr
,
sp
.
Abs
):
return
"abs({
})"
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
"abs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
sp
.
Mod
):
elif
isinstance
(
expr
,
sp
.
Mod
):
if
expr
.
args
[
0
].
is_integer
and
expr
.
args
[
1
].
is_integer
:
if
expr
.
args
[
0
].
is_integer
and
expr
.
args
[
1
].
is_integer
:
return
"({
} % {})"
.
format
(
self
.
_print
(
expr
.
args
[
0
])
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
])
}
%
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
else
:
else
:
return
"fmod({
}, {})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"fmod(
{
self
.
_print
(
expr
.
args
[
0
])
}
,
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
elif
expr
.
func
in
infix_functions
:
elif
expr
.
func
in
infix_functions
:
return
"(
%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
])
,
infix_functions
[
expr
.
func
]
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
])
}
{
infix_functions
[
expr
.
func
]
}
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
elif
expr
.
func
==
int_power_of_2
:
elif
expr
.
func
==
int_power_of_2
:
return
"(1 << (
%s))"
%
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
"(1 << (
{
self
.
_print
(
expr
.
args
[
0
])
}
))"
elif
expr
.
func
==
int_div
:
elif
expr
.
func
==
int_div
:
return
"((
%s) / (%s))"
%
(
self
.
_print
(
expr
.
args
[
0
])
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"((
{
self
.
_print
(
expr
.
args
[
0
])
}
) / (
{
self
.
_print
(
expr
.
args
[
1
])
}
))"
else
:
else
:
name
=
expr
.
name
if
hasattr
(
expr
,
'name'
)
else
expr
.
__class__
.
__name__
name
=
expr
.
name
if
hasattr
(
expr
,
'name'
)
else
expr
.
__class__
.
__name__
arg_str
=
', '
.
join
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
arg_str
=
', '
.
join
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
...
@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
...
@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
return
result
return
result
elif
expr
.
func
==
fast_sqrt
:
elif
expr
.
func
==
fast_sqrt
:
return
"({
})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
expr
.
func
==
fast_inv_sqrt
:
elif
expr
.
func
==
fast_inv_sqrt
:
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
if
not
result
:
if
self
.
instruction_set
[
'rsqrt'
]:
if
self
.
instruction_set
[
'rsqrt'
]:
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
else
:
else
:
return
"({
})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
):
elif
isinstance
(
expr
,
vec_any
):
expr_type
=
get_type_of_expression
(
expr
.
args
[
0
])
expr_type
=
get_type_of_expression
(
expr
.
args
[
0
])
if
type
(
expr_type
)
is
not
VectorType
:
if
type
(
expr_type
)
is
not
VectorType
:
...
...
pystencils/backends/cuda_backend.py
View file @
bced838a
...
@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
...
@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if
isinstance
(
expr
,
fast_division
):
if
isinstance
(
expr
,
fast_division
):
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
elif
isinstance
(
expr
,
fast_sqrt
):
return
"__fsqrt_rn(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"__fsqrt_rn(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"__frsqrt_rn(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"__frsqrt_rn(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
return
super
().
_print_Function
(
expr
)
return
super
().
_print_Function
(
expr
)
pystencils/backends/dot.py
View file @
bced838a
...
@@ -57,7 +57,7 @@ def __shortened(node):
...
@@ -57,7 +57,7 @@ def __shortened(node):
params
=
node
.
get_parameters
()
params
=
node
.
get_parameters
()
param_names
=
[
p
.
field_name
for
p
in
params
if
p
.
is_field_pointer
]
param_names
=
[
p
.
field_name
for
p
in
params
if
p
.
is_field_pointer
]
param_names
+=
[
p
.
symbol
.
name
for
p
in
params
if
not
p
.
is_field_parameter
]
param_names
+=
[
p
.
symbol
.
name
for
p
in
params
if
not
p
.
is_field_parameter
]
return
"Func:
%s (%s)"
%
(
node
.
function_name
,
","
.
join
(
param_names
)
)
return
f
"Func:
{
node
.
function_name
}
(
{
','
.
join
(
param_names
)
}
)"
elif
isinstance
(
node
,
SympyAssignment
):
elif
isinstance
(
node
,
SympyAssignment
):
return
repr
(
node
.
lhs
)
return
repr
(
node
.
lhs
)
elif
isinstance
(
node
,
Block
):
elif
isinstance
(
node
,
Block
):
...
@@ -65,7 +65,7 @@ def __shortened(node):
...
@@ -65,7 +65,7 @@ def __shortened(node):
elif
isinstance
(
node
,
Conditional
):
elif
isinstance
(
node
,
Conditional
):
return
repr
(
node
)
return
repr
(
node
)
else
:
else
:
raise
NotImplementedError
(
"Cannot handle node type
%s"
%
(
type
(
node
)
,)
)
raise
NotImplementedError
(
f
"Cannot handle node type
{
type
(
node
)
}
"
)
def
print_dot
(
node
,
view
=
False
,
short
=
False
,
**
kwargs
):
def
print_dot
(
node
,
view
=
False
,
short
=
False
,
**
kwargs
):
...
...
pystencils/backends/opencl_backend.py
View file @
bced838a
...
@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
...
@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
if
isinstance
(
expr
,
fast_division
):
if
isinstance
(
expr
,
fast_division
):
return
"native_divide(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
"native_divide(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
elif
isinstance
(
expr
,
fast_sqrt
):
return
"native_sqrt(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"native_sqrt(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"native_rsqrt(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"native_rsqrt(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
return
CustomSympyPrinter
.
_print_Function
(
self
,
expr
)
return
CustomSympyPrinter
.
_print_Function
(
self
,
expr
)
pystencils/backends/simd_instruction_sets.py
View file @
bced838a
...
@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
...
@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
})
})
for
comparison_op
,
constant
in
comparisons
.
items
():
for
comparison_op
,
constant
in
comparisons
.
items
():
base_names
[
comparison_op
]
=
'cmp[0, 1,
%s]'
%
(
constant
,)
base_names
[
comparison_op
]
=
f
'cmp[0, 1,
{
constant
}
]'
headers
=
{
headers
=
{
'avx512'
:
[
'<immintrin.h>'
],
'avx512'
:
[
'<immintrin.h>'
],
...
@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
...
@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
name
=
function_shortcut
[:
function_shortcut
.
index
(
'['
)]
name
=
function_shortcut
[:
function_shortcut
.
index
(
'['
)]
if
intrinsic_id
==
'makeVecConst'
:
if
intrinsic_id
==
'makeVecConst'
:
arg_string
=
"({
})"
.
format
(
","
.
join
([
"
{0}
"
]
*
result
[
'width'
])
)
arg_string
=
f
"(
{
','
.
join
([
'
{
0
}
'
] * result['
width
'])
}
)"
elif
intrinsic_id
==
'makeVec'
:
elif
intrinsic_id
==
'makeVec'
:
params
=
[
"{"
+
str
(
i
)
+
"}"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
params
=
[
"{"
+
str
(
i
)
+
"}"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
elif
intrinsic_id
==
'makeVecBool'
:
elif
intrinsic_id
==
'makeVecBool'
:
params
=
[
"(({{{i}}} ? -1.0 : 0.0)"
.
format
(
i
=
i
)
for
i
in
reversed
(
range
(
result
[
'width'
]))]
params
=
[
f
"(({{
{
i
}
}} ? -1.0 : 0.0)"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
elif
intrinsic_id
==
'makeVecConstBool'
:
elif
intrinsic_id
==
'makeVecConstBool'
:
params
=
[
"(({0}) ? -1.0 : 0.0)"
for
_
in
range
(
result
[
'width'
])]
params
=
[
"(({0}) ? -1.0 : 0.0)"
for
_
in
range
(
result
[
'width'
])]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
else
:
else
:
args
=
function_shortcut
[
function_shortcut
.
index
(
'['
)
+
1
:
-
1
]
args
=
function_shortcut
[
function_shortcut
.
index
(
'['
)
+
1
:
-
1
]
arg_string
=
"("
arg_string
=
"("
...
@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
...
@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result
[
'bool'
]
=
"__mmask%d"
%
(
size
,)
result
[
'bool'
]
=
"__mmask%d"
%
(
size
,)
params
=
" | "
.
join
([
"({{{i}}} ? {power} : 0)"
.
format
(
i
=
i
,
power
=
2
**
i
)
for
i
in
range
(
8
)])
params
=
" | "
.
join
([
"({{{i}}} ? {power} : 0)"
.
format
(
i
=
i
,
power
=
2
**
i
)
for
i
in
range
(
8
)])
result
[
'makeVecBool'
]
=
"__mmask8(({
}) )"
.
format
(
params
)
result
[
'makeVecBool'
]
=
f
"__mmask8((
{
params
}
) )"
params
=
" | "
.
join
([
"({{0}} ? {power} : 0)"
.
format
(
power
=
2
**
i
)
for
i
in
range
(
8
)])
params
=
" | "
.
join
([
"({{0}} ? {power} : 0)"
.
format
(
power
=
2
**
i
)
for
i
in
range
(
8
)])
result
[
'makeVecConstBool'
]
=
"__mmask8(({
}) )"
.
format
(
params
)
result
[
'makeVecConstBool'
]
=
f
"__mmask8((
{
params
}
) )"
if
instruction_set
==
'avx'
and
data_type
==
'float'
:
if
instruction_set
==
'avx'
and
data_type
==
'float'
:
result
[
'rsqrt'
]
=
"_mm256_rsqrt_ps({0})"
result
[
'rsqrt'
]
=
"_mm256_rsqrt_ps({0})"
...
...
pystencils/boundaries/boundaryhandling.py
View file @
bced838a
...
@@ -66,13 +66,13 @@ class FlagInterface:
...
@@ -66,13 +66,13 @@ class FlagInterface:
self
.
_used_flags
.
add
(
flag
)
self
.
_used_flags
.
add
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
return
flag
return
flag
raise
ValueError
(
"All available {} flags are reserved"
.
format
(
self
.
max_bits
)
)
raise
ValueError
(
f
"All available
{
self
.
max_bits
}
flags are reserved"
)
def
reserve_flag
(
self
,
flag
):
def
reserve_flag
(
self
,
flag
):
assert
self
.
_is_power_of_2
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
flag
=
self
.
dtype
(
flag
)
flag
=
self
.
dtype
(
flag
)
if
flag
in
self
.
_used_flags
:
if
flag
in
self
.
_used_flags
:
raise
ValueError
(
"The flag {flag} is already reserved"
.
format
(
flag
=
flag
)
)
raise
ValueError
(
f
"The flag
{
flag
}
is already reserved"
)
self
.
_used_flags
.
add
(
flag
)
self
.
_used_flags
.
add
(
flag
)
return
flag
return
flag
...
@@ -392,12 +392,12 @@ class BoundaryDataSetter:
...
@@ -392,12 +392,12 @@ class BoundaryDataSetter:
def
__setitem__
(
self
,
key
,
value
):
def
__setitem__
(
self
,
key
,
value
):
if
key
not
in
self
.
boundary_data_names
:
if
key
not
in
self
.
boundary_data_names
:
raise
KeyError
(
"Invalid boundary data name
%s
. Allowed are
%s"
%
(
key
,
self
.
boundary_data_names
)
)
raise
KeyError
(
f
"Invalid boundary data name
{
key
}
. Allowed are
{
self
.
boundary_data_names
}
"
)
self
.
index_array
[
key
]
=
value
self
.
index_array
[
key
]
=
value
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
if
item
not
in
self
.
boundary_data_names
:
if
item
not
in
self
.
boundary_data_names
:
raise
KeyError
(
"Invalid boundary data name
%s
. Allowed are
%s"
%
(
item
,
self
.
boundary_data_names
)
)
raise
KeyError
(
f
"Invalid boundary data name
{
item
}
. Allowed are
{
self
.
boundary_data_names
}
"
)
return
self
.
index_array
[
item
]
return
self
.
index_array
[
item
]
...
@@ -437,7 +437,7 @@ class BoundaryOffsetInfo(CustomCodeNode):
...
@@ -437,7 +437,7 @@ class BoundaryOffsetInfo(CustomCodeNode):
@
staticmethod
@
staticmethod
def
_offset_symbols
(
dim
):
def
_offset_symbols
(
dim
):
return
[
TypedSymbol
(
"c
%s"
%
(
d
,)
,
create_type
(
np
.
int64
))
for
d
in
[
'x'
,
'y'
,
'z'
][:
dim
]]
return
[
TypedSymbol
(
f
"c
{
d
}
"
,
create_type
(
np
.
int64
))
for
d
in
[
'x'
,
'y'
,
'z'
][:
dim
]]
INV_DIR_SYMBOL
=
TypedSymbol
(
"invdir"
,
"int"
)
INV_DIR_SYMBOL
=
TypedSymbol
(
"invdir"
,
"int"
)
...
...
pystencils/cpu/cpujit.py
View file @
bced838a
...
@@ -362,7 +362,7 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
...
@@ -362,7 +362,7 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
field
=
param
.
fields
[
0
]
field
=
param
.
fields
[
0
]
pre_call_code
+=
template_extract_array
.
format
(
name
=
field
.
name
)
pre_call_code
+=
template_extract_array
.
format
(
name
=
field
.
name
)
post_call_code
+=
template_release_buffer
.
format
(
name
=
field
.
name
)
post_call_code
+=
template_release_buffer
.
format
(
name
=
field
.
name
)
parameters
.
append
(
"({
dtype} *)buffer_{name}.buf"
.
format
(
dtype
=
str
(
field
.
dtype
)
,
name
=
field
.
name
)
)
parameters
.
append
(
f
"(
{
str
(
field
.
dtype
)
}
*)buffer_
{
field
.
name
}
.buf"
)