Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencils
Commits
bced838a
Commit
bced838a
authored
Jun 19, 2020
by
Markus Holzer
Browse files
Replaced all format strings with f-strings
parent
c83f4f5d
Changes
35
Show whitespace changes
Inline
Side-by-side
pystencils/assignment.py
View file @
bced838a
...
...
@@ -53,7 +53,7 @@ else:
# Tuple of things that can be on the lhs of an assignment
assignable
=
(
sp
.
Symbol
,
MatrixSymbol
,
MatrixElement
,
sp
.
Indexed
)
if
not
isinstance
(
lhs
,
assignable
):
raise
TypeError
(
"Cannot assign to lhs of type
%s."
%
type
(
lhs
))
raise
TypeError
(
f
"Cannot assign to lhs of type
{
type
(
lhs
)
}
."
)
return
sp
.
Rel
.
__new__
(
cls
,
lhs
,
rhs
,
**
assumptions
)
__str__
=
assignment_str
...
...
pystencils/astnodes.py
View file @
bced838a
...
...
@@ -113,12 +113,12 @@ class Conditional(Node):
return
self
.
__repr__
()
def
__repr__
(
self
):
repr
=
'if:({
!r}) '
.
format
(
self
.
condition_expr
)
repr
=
f
'if:(
{
self
.
condition_expr
!r}
) '
if
self
.
true_block
:
repr
+=
'
\n\t
{
}) '
.
format
(
self
.
true_block
)
repr
+=
f
'
\n\t
{
self
.
true_block
}
) '
if
self
.
false_block
:
repr
=
'else: '
repr
+=
'
\n\t
{
} '
.
format
(
self
.
false_block
)
repr
+=
f
'
\n\t
{
self
.
false_block
}
'
return
repr
...
...
@@ -264,7 +264,7 @@ class KernelFunction(Node):
def
__repr__
(
self
):
params
=
[
p
.
symbol
for
p
in
self
.
get_parameters
()]
return
'{
0} {1}({2})'
.
format
(
type
(
self
).
__name__
,
self
.
function_name
,
params
)
return
f
'
{
type
(
self
).
__name__
}
{
self
.
function_name
}
(
{
params
}
)'
def
compile
(
self
,
*
args
,
**
kwargs
):
if
self
.
_compile_function
is
None
:
...
...
@@ -475,11 +475,11 @@ class LoopOverCoordinate(Node):
@
staticmethod
def
get_loop_counter_name
(
coordinate_to_loop_over
):
return
"%s_%s"
%
(
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
,
coordinate_to_loop_over
)
return
f
"
{
LoopOverCoordinate
.
LOOP_COUNTER_NAME_PREFIX
}
_
{
coordinate_to_loop_over
}
"
@
staticmethod
def
get_block_loop_counter_name
(
coordinate_to_loop_over
):
return
"%s_%s"
%
(
LoopOverCoordinate
.
BlOCK_LOOP_COUNTER_NAME_PREFIX
,
coordinate_to_loop_over
)
return
f
"
{
LoopOverCoordinate
.
BlOCK_LOOP_COUNTER_NAME_PREFIX
}
_
{
coordinate_to_loop_over
}
"
@
property
def
loop_counter_name
(
self
):
...
...
@@ -612,7 +612,7 @@ class SympyAssignment(Node):
replacement
.
parent
=
self
self
.
rhs
=
replacement
else
:
raise
ValueError
(
'%s
is not in args of
%s'
%
(
replacement
,
self
.
__class__
)
)
raise
ValueError
(
f
'
{
replacement
}
is not in args of
{
self
.
__class__
}
'
)
def
__repr__
(
self
):
return
repr
(
self
.
lhs
)
+
" ← "
+
repr
(
self
.
rhs
)
...
...
@@ -620,7 +620,7 @@ class SympyAssignment(Node):
def
_repr_html_
(
self
):
printed_lhs
=
sp
.
latex
(
self
.
lhs
)
printed_rhs
=
sp
.
latex
(
self
.
rhs
)
return
"${printed_lhs}
\\
leftarrow {printed_rhs}$"
.
format
(
printed_lhs
=
printed_lhs
,
printed_rhs
=
printed_rhs
)
return
f
"$
{
printed_lhs
}
\\
leftarrow
{
printed_rhs
}
$"
def
__hash__
(
self
):
return
hash
((
self
.
lhs
,
self
.
rhs
))
...
...
@@ -663,7 +663,7 @@ class ResolvedFieldAccess(sp.Indexed):
def
__str__
(
self
):
top
=
super
(
ResolvedFieldAccess
,
self
).
__str__
()
return
"%s (%s)"
%
(
top
,
self
.
typed_symbol
.
dtype
)
return
f
"
{
top
}
(
{
self
.
typed_symbol
.
dtype
}
)"
def
__getnewargs__
(
self
):
return
self
.
base
,
self
.
indices
[
0
],
self
.
field
,
self
.
offsets
,
self
.
idx_coordinate_values
...
...
@@ -740,7 +740,7 @@ def early_out(condition):
def
get_dummy_symbol
(
dtype
=
'bool'
):
return
TypedSymbol
(
'dummy
%s'
%
uuid
.
uuid4
().
hex
,
create_type
(
dtype
))
return
TypedSymbol
(
f
'dummy
{
uuid
.
uuid4
().
hex
}
'
,
create_type
(
dtype
))
class
SourceCodeComment
(
Node
):
...
...
pystencils/backends/cbackend.py
View file @
bced838a
...
...
@@ -158,7 +158,7 @@ class CustomCodeNode(Node):
class
PrintNode
(
CustomCodeNode
):
# noinspection SpellCheckingInspection
def
__init__
(
self
,
symbol_to_print
):
code
=
'
\n
std::cout << "
%s = " << %s << std::endl;
\n
'
%
(
symbol_to_print
.
name
,
symbol_to_print
.
name
)
code
=
f
'
\n
std::cout << "
{
symbol_to_print
.
name
}
= " <<
{
symbol_to_print
.
name
}
<< std::endl;
\n
'
super
(
PrintNode
,
self
).
__init__
(
code
,
symbols_read
=
[
symbol_to_print
],
symbols_defined
=
set
())
self
.
headers
.
append
(
"<iostream>"
)
...
...
@@ -203,12 +203,12 @@ class CBackend:
return
str
(
node
)
def
_print_KernelFunction
(
self
,
node
):
function_arguments
=
[
"%s %s"
%
(
self
.
_print
(
s
.
symbol
.
dtype
)
,
s
.
symbol
.
name
)
for
s
in
node
.
get_parameters
()]
function_arguments
=
[
f
"
{
self
.
_print
(
s
.
symbol
.
dtype
)
}
{
s
.
symbol
.
name
}
"
for
s
in
node
.
get_parameters
()]
launch_bounds
=
""
if
self
.
_dialect
==
'cuda'
:
max_threads
=
node
.
indexing
.
max_threads_per_block
()
if
max_threads
:
launch_bounds
=
"__launch_bounds__({
}) "
.
format
(
max_threads
)
launch_bounds
=
f
"__launch_bounds__(
{
max_threads
}
) "
func_declaration
=
"FUNC_PREFIX %svoid %s(%s)"
%
(
launch_bounds
,
node
.
function_name
,
", "
.
join
(
function_arguments
))
if
self
.
_signatureOnly
:
...
...
@@ -222,19 +222,19 @@ class CBackend:
return
"{
\n
%s
\n
}"
%
(
self
.
_indent
+
self
.
_indent
.
join
(
block_contents
.
splitlines
(
True
)))
def
_print_PragmaBlock
(
self
,
node
):
return
"%s
\n
%s"
%
(
node
.
pragma_line
,
self
.
_print_Block
(
node
)
)
return
f
"
{
node
.
pragma_line
}
\n
{
self
.
_print_Block
(
node
)
}
"
def
_print_LoopOverCoordinate
(
self
,
node
):
counter_symbol
=
node
.
loop_counter_name
start
=
"int
%s = %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
start
)
)
condition
=
"%s < %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
stop
)
)
update
=
"%s += %s"
%
(
counter_symbol
,
self
.
sympy_printer
.
doprint
(
node
.
step
)
,)
loop_str
=
"for (
%s; %s; %s)"
%
(
start
,
condition
,
update
)
start
=
f
"int
{
counter_symbol
}
=
{
self
.
sympy_printer
.
doprint
(
node
.
start
)
}
"
condition
=
f
"
{
counter_symbol
}
<
{
self
.
sympy_printer
.
doprint
(
node
.
stop
)
}
"
update
=
f
"
{
counter_symbol
}
+=
{
self
.
sympy_printer
.
doprint
(
node
.
step
)
}
"
loop_str
=
f
"for (
{
start
}
;
{
condition
}
;
{
update
}
)"
prefix
=
"
\n
"
.
join
(
node
.
prefix_lines
)
if
prefix
:
prefix
+=
"
\n
"
return
"%s%s
\n
%s"
%
(
prefix
,
loop_str
,
self
.
_print
(
node
.
body
)
)
return
f
"
{
prefix
}{
loop_str
}
\n
{
self
.
_print
(
node
.
body
)
}
"
def
_print_SympyAssignment
(
self
,
node
):
if
node
.
is_declaration
:
...
...
@@ -262,7 +262,7 @@ class CBackend:
instr
=
'maskStore'
if
aligned
else
'maskStoreU'
printed_mask
=
self
.
sympy_printer
.
doprint
(
mask
)
if
self
.
_vector_instruction_set
[
'dataTypePrefix'
][
'double'
]
==
'__mm256d'
:
printed_mask
=
"_mm256_castpd_si256({
})"
.
format
(
printed_mask
)
printed_mask
=
f
"_mm256_castpd_si256(
{
printed_mask
}
)"
rhs_type
=
get_type_of_expression
(
node
.
rhs
)
if
type
(
rhs_type
)
is
not
VectorType
:
...
...
@@ -274,7 +274,7 @@ class CBackend:
self
.
sympy_printer
.
doprint
(
rhs
),
printed_mask
)
+
';'
else
:
return
"%s = %s;"
%
(
self
.
sympy_printer
.
doprint
(
node
.
lhs
)
,
self
.
sympy_printer
.
doprint
(
node
.
rhs
)
)
return
f
"
{
self
.
sympy_printer
.
doprint
(
node
.
lhs
)
}
=
{
self
.
sympy_printer
.
doprint
(
node
.
rhs
)
}
;"
def
_print_TemporaryMemoryAllocation
(
self
,
node
):
align
=
64
...
...
@@ -314,7 +314,7 @@ class CBackend:
raise
ValueError
(
"Problem with Conditional inside vectorized loop - use vec_any or vec_all"
)
condition_expr
=
self
.
sympy_printer
.
doprint
(
node
.
condition_expr
)
true_block
=
self
.
_print_Block
(
node
.
true_block
)
result
=
"if (
%s)
\n
%s "
%
(
condition_expr
,
true_block
)
result
=
f
"if (
{
condition_expr
}
)
\n
{
true_block
}
"
if
node
.
false_block
:
false_block
=
self
.
_print_Block
(
node
.
false_block
)
result
+=
"else "
+
false_block
...
...
@@ -343,7 +343,7 @@ class CustomSympyPrinter(CCodePrinter):
if
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
0
<
expr
.
exp
<
8
:
return
"("
+
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
expr
.
exp
,
evaluate
=
False
))
+
")"
elif
expr
.
exp
.
is_integer
and
expr
.
exp
.
is_number
and
-
8
<
expr
.
exp
<
0
:
return
"1 / ({
})"
.
format
(
self
.
_print
(
sp
.
Mul
(
*
[
expr
.
base
]
*
(
-
expr
.
exp
),
evaluate
=
False
))
)
return
f
"1 / (
{
self
.
_print
(
sp
.
Mul
(
*
(
[
expr
.
base
]
*
-
expr
.
exp
),
evaluate
=
False
))
}
)"
else
:
return
super
(
CustomSympyPrinter
,
self
).
_print_Pow
(
expr
)
...
...
@@ -363,9 +363,9 @@ class CustomSympyPrinter(CCodePrinter):
def
_print_Abs
(
self
,
expr
):
if
expr
.
args
[
0
].
is_integer
:
return
'abs({
0})'
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
'abs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)'
else
:
return
'fabs({
0})'
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
'fabs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)'
def
_print_Type
(
self
,
node
):
return
str
(
node
)
...
...
@@ -382,37 +382,37 @@ class CustomSympyPrinter(CCodePrinter):
return
expr
.
to_c
(
self
.
_print
)
if
isinstance
(
expr
,
reinterpret_cast_func
):
arg
,
data_type
=
expr
.
args
return
"*((
%s)(& %s))"
%
(
self
.
_print
(
PointerType
(
data_type
,
restrict
=
False
))
,
self
.
_print
(
arg
)
)
return
f
"*((
{
self
.
_print
(
PointerType
(
data_type
,
restrict
=
False
))
}
)(&
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
address_of
):
assert
len
(
expr
.
args
)
==
1
,
"address_of must only have one argument"
return
"&(
%s)"
%
self
.
_print
(
expr
.
args
[
0
])
return
f
"&(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
cast_func
):
arg
,
data_type
=
expr
.
args
if
isinstance
(
arg
,
sp
.
Number
)
and
arg
.
is_finite
:
return
self
.
_typed_number
(
arg
,
data_type
)
else
:
return
"((
%s)(%s))"
%
(
data_type
,
self
.
_print
(
arg
)
)
return
f
"((
{
data_type
}
)(
{
self
.
_print
(
arg
)
}
))"
elif
isinstance
(
expr
,
fast_division
):
return
"({
})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
]
/
expr
.
args
[
1
])
}
)"
elif
isinstance
(
expr
,
fast_sqrt
):
return
"({
})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
)
or
isinstance
(
expr
,
vec_all
):
return
self
.
_print
(
expr
.
args
[
0
])
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"({
})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
sp
.
Abs
):
return
"abs({
})"
.
format
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
"abs(
{
self
.
_print
(
expr
.
args
[
0
])
}
)"
elif
isinstance
(
expr
,
sp
.
Mod
):
if
expr
.
args
[
0
].
is_integer
and
expr
.
args
[
1
].
is_integer
:
return
"({
} % {})"
.
format
(
self
.
_print
(
expr
.
args
[
0
])
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
])
}
%
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
else
:
return
"fmod({
}, {})"
.
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"fmod(
{
self
.
_print
(
expr
.
args
[
0
])
}
,
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
elif
expr
.
func
in
infix_functions
:
return
"(
%s %s %s)"
%
(
self
.
_print
(
expr
.
args
[
0
])
,
infix_functions
[
expr
.
func
]
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"(
{
self
.
_print
(
expr
.
args
[
0
])
}
{
infix_functions
[
expr
.
func
]
}
{
self
.
_print
(
expr
.
args
[
1
])
}
)"
elif
expr
.
func
==
int_power_of_2
:
return
"(1 << (
%s))"
%
(
self
.
_print
(
expr
.
args
[
0
])
)
return
f
"(1 << (
{
self
.
_print
(
expr
.
args
[
0
])
}
))"
elif
expr
.
func
==
int_div
:
return
"((
%s) / (%s))"
%
(
self
.
_print
(
expr
.
args
[
0
])
,
self
.
_print
(
expr
.
args
[
1
])
)
return
f
"((
{
self
.
_print
(
expr
.
args
[
0
])
}
) / (
{
self
.
_print
(
expr
.
args
[
1
])
}
))"
else
:
name
=
expr
.
name
if
hasattr
(
expr
,
'name'
)
else
expr
.
__class__
.
__name__
arg_str
=
', '
.
join
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
...
...
@@ -540,14 +540,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result
=
self
.
instruction_set
[
'/'
].
format
(
self
.
_print
(
expr
.
args
[
0
]),
self
.
_print
(
expr
.
args
[
1
]))
return
result
elif
expr
.
func
==
fast_sqrt
:
return
"({
})"
.
format
(
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
expr
.
func
==
fast_inv_sqrt
:
result
=
self
.
_scalarFallback
(
'_print_Function'
,
expr
)
if
not
result
:
if
self
.
instruction_set
[
'rsqrt'
]:
return
self
.
instruction_set
[
'rsqrt'
].
format
(
self
.
_print
(
expr
.
args
[
0
]))
else
:
return
"({
})"
.
format
(
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
)
return
f
"(
{
self
.
_print
(
1
/
sp
.
sqrt
(
expr
.
args
[
0
]))
}
)"
elif
isinstance
(
expr
,
vec_any
):
expr_type
=
get_type_of_expression
(
expr
.
args
[
0
])
if
type
(
expr_type
)
is
not
VectorType
:
...
...
pystencils/backends/cuda_backend.py
View file @
bced838a
...
...
@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if
isinstance
(
expr
,
fast_division
):
return
"__fdividef(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
return
"__fsqrt_rn(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"__fsqrt_rn(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"__frsqrt_rn(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"__frsqrt_rn(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
return
super
().
_print_Function
(
expr
)
pystencils/backends/dot.py
View file @
bced838a
...
...
@@ -57,7 +57,7 @@ def __shortened(node):
params
=
node
.
get_parameters
()
param_names
=
[
p
.
field_name
for
p
in
params
if
p
.
is_field_pointer
]
param_names
+=
[
p
.
symbol
.
name
for
p
in
params
if
not
p
.
is_field_parameter
]
return
"Func:
%s (%s)"
%
(
node
.
function_name
,
","
.
join
(
param_names
)
)
return
f
"Func:
{
node
.
function_name
}
(
{
','
.
join
(
param_names
)
}
)"
elif
isinstance
(
node
,
SympyAssignment
):
return
repr
(
node
.
lhs
)
elif
isinstance
(
node
,
Block
):
...
...
@@ -65,7 +65,7 @@ def __shortened(node):
elif
isinstance
(
node
,
Conditional
):
return
repr
(
node
)
else
:
raise
NotImplementedError
(
"Cannot handle node type
%s"
%
(
type
(
node
)
,)
)
raise
NotImplementedError
(
f
"Cannot handle node type
{
type
(
node
)
}
"
)
def
print_dot
(
node
,
view
=
False
,
short
=
False
,
**
kwargs
):
...
...
pystencils/backends/opencl_backend.py
View file @
bced838a
...
...
@@ -98,7 +98,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
if
isinstance
(
expr
,
fast_division
):
return
"native_divide(%s, %s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
elif
isinstance
(
expr
,
fast_sqrt
):
return
"native_sqrt(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"native_sqrt(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
elif
isinstance
(
expr
,
fast_inv_sqrt
):
return
"native_rsqrt(
%s)"
%
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
return
f
"native_rsqrt(
{
tuple
(
self
.
_print
(
a
)
for
a
in
expr
.
args
)
}
)"
return
CustomSympyPrinter
.
_print_Function
(
self
,
expr
)
pystencils/backends/simd_instruction_sets.py
View file @
bced838a
...
...
@@ -51,7 +51,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
})
for
comparison_op
,
constant
in
comparisons
.
items
():
base_names
[
comparison_op
]
=
'cmp[0, 1,
%s]'
%
(
constant
,)
base_names
[
comparison_op
]
=
f
'cmp[0, 1,
{
constant
}
]'
headers
=
{
'avx512'
:
[
'<immintrin.h>'
],
...
...
@@ -89,16 +89,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
name
=
function_shortcut
[:
function_shortcut
.
index
(
'['
)]
if
intrinsic_id
==
'makeVecConst'
:
arg_string
=
"({
})"
.
format
(
","
.
join
([
"
{0}
"
]
*
result
[
'width'
])
)
arg_string
=
f
"(
{
','
.
join
([
'
{
0
}
'
] * result['
width
'])
}
)"
elif
intrinsic_id
==
'makeVec'
:
params
=
[
"{"
+
str
(
i
)
+
"}"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
elif
intrinsic_id
==
'makeVecBool'
:
params
=
[
"(({{{i}}} ? -1.0 : 0.0)"
.
format
(
i
=
i
)
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
params
=
[
f
"(({{
{
i
}
}} ? -1.0 : 0.0)"
for
i
in
reversed
(
range
(
result
[
'width'
]))]
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
elif
intrinsic_id
==
'makeVecConstBool'
:
params
=
[
"(({0}) ? -1.0 : 0.0)"
for
_
in
range
(
result
[
'width'
])]
arg_string
=
"({
})"
.
format
(
","
.
join
(
params
)
)
arg_string
=
f
"(
{
','
.
join
(
params
)
}
)"
else
:
args
=
function_shortcut
[
function_shortcut
.
index
(
'['
)
+
1
:
-
1
]
arg_string
=
"("
...
...
@@ -141,9 +141,9 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result
[
'bool'
]
=
"__mmask%d"
%
(
size
,)
params
=
" | "
.
join
([
"({{{i}}} ? {power} : 0)"
.
format
(
i
=
i
,
power
=
2
**
i
)
for
i
in
range
(
8
)])
result
[
'makeVecBool'
]
=
"__mmask8(({
}) )"
.
format
(
params
)
result
[
'makeVecBool'
]
=
f
"__mmask8((
{
params
}
) )"
params
=
" | "
.
join
([
"({{0}} ? {power} : 0)"
.
format
(
power
=
2
**
i
)
for
i
in
range
(
8
)])
result
[
'makeVecConstBool'
]
=
"__mmask8(({
}) )"
.
format
(
params
)
result
[
'makeVecConstBool'
]
=
f
"__mmask8((
{
params
}
) )"
if
instruction_set
==
'avx'
and
data_type
==
'float'
:
result
[
'rsqrt'
]
=
"_mm256_rsqrt_ps({0})"
...
...
pystencils/boundaries/boundaryhandling.py
View file @
bced838a
...
...
@@ -66,13 +66,13 @@ class FlagInterface:
self
.
_used_flags
.
add
(
flag
)
assert
self
.
_is_power_of_2
(
flag
)
return
flag
raise
ValueError
(
"All available {} flags are reserved"
.
format
(
self
.
max_bits
)
)
raise
ValueError
(
f
"All available
{
self
.
max_bits
}
flags are reserved"
)
def
reserve_flag
(
self
,
flag
):
assert
self
.
_is_power_of_2
(
flag
)
flag
=
self
.
dtype
(
flag
)
if
flag
in
self
.
_used_flags
:
raise
ValueError
(
"The flag {flag} is already reserved"
.
format
(
flag
=
flag
)
)
raise
ValueError
(
f
"The flag
{
flag
}
is already reserved"
)
self
.
_used_flags
.
add
(
flag
)
return
flag
...
...
@@ -392,12 +392,12 @@ class BoundaryDataSetter:
def
__setitem__
(
self
,
key
,
value
):
if
key
not
in
self
.
boundary_data_names
:
raise
KeyError
(
"Invalid boundary data name
%s
. Allowed are
%s"
%
(
key
,
self
.
boundary_data_names
)
)
raise
KeyError
(
f
"Invalid boundary data name
{
key
}
. Allowed are
{
self
.
boundary_data_names
}
"
)
self
.
index_array
[
key
]
=
value
def
__getitem__
(
self
,
item
):
if
item
not
in
self
.
boundary_data_names
:
raise
KeyError
(
"Invalid boundary data name
%s
. Allowed are
%s"
%
(
item
,
self
.
boundary_data_names
)
)
raise
KeyError
(
f
"Invalid boundary data name
{
item
}
. Allowed are
{
self
.
boundary_data_names
}
"
)
return
self
.
index_array
[
item
]
...
...
@@ -437,7 +437,7 @@ class BoundaryOffsetInfo(CustomCodeNode):
@
staticmethod
def
_offset_symbols
(
dim
):
return
[
TypedSymbol
(
"c
%s"
%
(
d
,)
,
create_type
(
np
.
int64
))
for
d
in
[
'x'
,
'y'
,
'z'
][:
dim
]]
return
[
TypedSymbol
(
f
"c
{
d
}
"
,
create_type
(
np
.
int64
))
for
d
in
[
'x'
,
'y'
,
'z'
][:
dim
]]
INV_DIR_SYMBOL
=
TypedSymbol
(
"invdir"
,
"int"
)
...
...
pystencils/cpu/cpujit.py
View file @
bced838a
...
...
@@ -362,7 +362,7 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
field
=
param
.
fields
[
0
]
pre_call_code
+=
template_extract_array
.
format
(
name
=
field
.
name
)
post_call_code
+=
template_release_buffer
.
format
(
name
=
field
.
name
)
parameters
.
append
(
"({
dtype} *)buffer_{name}.buf"
.
format
(
dtype
=
str
(
field
.
dtype
)
,
name
=
field
.
name
)
)
parameters
.
append
(
f
"(
{
str
(
field
.
dtype
)
}
*)buffer_
{
field
.
name
}
.buf"
)
if
insert_checks
:
np_dtype
=
field
.
dtype
.
numpy_dtype
...
...
@@ -375,12 +375,12 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
pre_call_code
+=
template_check_array
.
format
(
cond
=
dtype_cond
,
what
=
"data type"
,
name
=
field
.
name
,
expected
=
str
(
field
.
dtype
.
numpy_dtype
))
item_size_cond
=
"buffer_{name}.itemsize == {
size}"
.
format
(
name
=
field
.
name
,
size
=
item_size
)
item_size_cond
=
f
"buffer_
{
field
.
name
}
.itemsize ==
{
item_size
}
"
pre_call_code
+=
template_check_array
.
format
(
cond
=
item_size_cond
,
what
=
"itemsize"
,
name
=
field
.
name
,
expected
=
item_size
)
if
field
.
has_fixed_shape
:
shape_cond
=
[
"buffer_{name}.shape[{i}] == {s}"
.
format
(
s
=
s
,
name
=
field
.
name
,
i
=
i
)
shape_cond
=
[
f
"buffer_
{
field
.
name
}
.shape[
{
i
}
] ==
{
s
}
"
for
i
,
s
in
enumerate
(
field
.
spatial_shape
)]
shape_cond
=
" && "
.
join
(
shape_cond
)
pre_call_code
+=
template_check_array
.
format
(
cond
=
shape_cond
,
what
=
"shape"
,
name
=
field
.
name
,
...
...
@@ -403,7 +403,7 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
parameters
.
append
(
"buffer_{name}.strides[{i}] / {bytes}"
.
format
(
bytes
=
item_size
,
i
=
param
.
symbol
.
coordinate
,
name
=
field
.
name
))
elif
param
.
is_field_shape
:
parameters
.
append
(
"buffer_{name}.shape[{
i}]"
.
format
(
i
=
param
.
symbol
.
coordinate
,
name
=
param
.
field_name
)
)
parameters
.
append
(
f
"buffer_
{
param
.
field_
name
}
.shape[
{
param
.
symbol
.
coordinate
}
]"
)
else
:
extract_function
,
target_type
=
type_mapping
[
param
.
symbol
.
dtype
.
numpy_dtype
.
type
]
if
np
.
issubdtype
(
param
.
symbol
.
dtype
.
numpy_dtype
,
np
.
complexfloating
):
...
...
@@ -490,8 +490,8 @@ class ExtensionModuleCode:
includes
=
"
\n
"
.
join
([
"#include %s"
%
(
include_file
,)
for
include_file
in
header_list
])
print
(
includes
,
file
=
file
)
print
(
"
\n
"
,
file
=
file
)
print
(
"#define RESTRICT
%s"
%
(
restrict_qualifier
,)
,
file
=
file
)
print
(
"#define FUNC_PREFIX
%s"
%
(
function_prefix
,)
,
file
=
file
)
print
(
f
"#define RESTRICT
{
restrict_qualifier
}
"
,
file
=
file
)
print
(
f
"#define FUNC_PREFIX
{
function_prefix
}
"
,
file
=
file
)
print
(
"
\n
"
,
file
=
file
)
for
ast
,
name
in
zip
(
self
.
_ast_nodes
,
self
.
_function_names
):
...
...
@@ -541,7 +541,7 @@ def compile_module(code, code_hash, base_dir):
import
sysconfig
config_vars
=
sysconfig
.
get_config_vars
()
py_lib
=
os
.
path
.
join
(
config_vars
[
"installed_base"
],
"libs"
,
"python{
}.lib"
.
format
(
config_vars
[
"
py_version_nodot
"
])
)
f
"python
{
config_vars
[
'
py_version_nodot
'
]
}
.lib"
)
run_compile_step
([
'link.exe'
,
py_lib
,
'/DLL'
,
'/out:'
+
lib_file
,
object_file
])
elif
platform
.
system
().
lower
()
==
'darwin'
:
with
atomic_file_write
(
lib_file
)
as
file_name
:
...
...
pystencils/cpu/kernelcreation.py
View file @
bced838a
...
...
@@ -129,7 +129,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
rhs
=
idx_field
[
0
](
name
)
lhs
=
TypedSymbol
(
name
,
BasicType
(
data_type
.
get_element_type
(
name
)))
return
SympyAssignment
(
lhs
,
rhs
)
raise
ValueError
(
"Index
%s
not found in any of the passed index fields"
%
(
name
,)
)
raise
ValueError
(
f
"Index
{
name
}
not found in any of the passed index fields"
)
coordinate_symbol_assignments
=
[
get_coordinate_symbol_assignment
(
n
)
for
n
in
coordinate_names
[:
spatial_coordinates
]]
...
...
@@ -173,7 +173,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass
assert
type
(
ast_node
)
is
ast
.
KernelFunction
body
=
ast_node
.
body
threads_clause
=
""
if
num_threads
and
isinstance
(
num_threads
,
bool
)
else
" num_threads(
%s)"
%
(
num_threads
,)
threads_clause
=
""
if
num_threads
and
isinstance
(
num_threads
,
bool
)
else
f
" num_threads(
{
num_threads
}
)"
wrapper_block
=
ast
.
PragmaBlock
(
'#pragma omp parallel'
+
threads_clause
,
body
.
take_child_nodes
())
body
.
append
(
wrapper_block
)
...
...
@@ -204,7 +204,7 @@ def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, ass
except
TypeError
:
pass
prefix
=
"#pragma omp for schedule(
%s)"
%
(
schedule
,)
prefix
=
f
"#pragma omp for schedule(
{
schedule
}
)"
if
collapse
:
prefix
+=
" collapse(%d)"
%
(
collapse
,
)
loop_to_parallelize
.
prefix_lines
.
append
(
prefix
)
pystencils/cpu/msvc_detection.py
View file @
bced838a
...
...
@@ -71,7 +71,7 @@ def normalize_msvc_version(version):
def
get_environment_from_vc_vars_file
(
vc_vars_file
,
arch
):
out
=
subprocess
.
check_output
(
'cmd /u /c "{
}" {} && set'
.
format
(
vc_vars_file
,
arch
)
,
f
'cmd /u /c "
{
vc_vars_file
}
"
{
arch
}
&& set'
,
stderr
=
subprocess
.
STDOUT
,
).
decode
(
'utf-16le'
,
errors
=
'replace'
)
...
...
pystencils/cpu/vectorization.py
View file @
bced838a
...
...
@@ -115,7 +115,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
break
typed_symbol
=
base
.
label
assert
type
(
typed_symbol
.
dtype
)
is
PointerType
,
\
"Type of access is {
}, {}"
.
format
(
typed_symbol
.
dtype
,
indexed
)
f
"Type of access is
{
typed_symbol
.
dtype
}
,
{
indexed
}
"
vec_type
=
VectorType
(
typed_symbol
.
dtype
.
base_type
,
vector_width
)
use_aligned_access
=
aligned_access
and
assume_aligned
...
...
pystencils/data_types.py
View file @
bced838a
...
...
@@ -396,7 +396,7 @@ def ctypes_from_llvm(data_type):
elif
isinstance
(
data_type
,
ir
.
VoidType
):
return
None
# Void type is not supported by ctypes
else
:
raise
NotImplementedError
(
'Data type
%s of %s is not supported yet'
%
(
type
(
data_type
),
data_type
)
)
raise
NotImplementedError
(
f
'Data type
{
type
(
data_type
)
}
of
{
data_type
}
is not supported yet'
)
def
to_llvm_type
(
data_type
,
nvvm_target
=
False
):
...
...
@@ -603,7 +603,7 @@ class BasicType(Type):
elif
name
==
'bool'
:
return
'bool'
else
:
raise
NotImplementedError
(
"Can map numpy to C name for
%s"
%
(
name
,)
)
raise
NotImplementedError
(
f
"Can map numpy to C name for
{
name
}
"
)
def
__init__
(
self
,
dtype
,
const
=
False
):
self
.
const
=
const
...
...
pystencils/datahandling/parallel_datahandling.py
View file @
bced838a
...
...
@@ -383,7 +383,7 @@ class ParallelDataHandling(DataHandling):
if
not
os
.
path
.
exists
(
directory
):
os
.
mkdir
(
directory
)
if
os
.
path
.
isfile
(
directory
):
raise
RuntimeError
(
"Trying to save to {}, but file exists already"
.
format
(
directory
)
)
raise
RuntimeError
(
f
"Trying to save to
{
directory
}
, but file exists already"
)
for
field_name
,
data_name
in
self
.
_field_name_to_cpu_data_name
.
items
():
self
.
blocks
.
writeBlockData
(
data_name
,
os
.
path
.
join
(
directory
,
field_name
+
".dat"
))
...
...
pystencils/datahandling/serial_datahandling.py
View file @
bced838a
...
...
@@ -407,7 +407,7 @@ class SerialDataHandling(DataHandling):
time_running
=
time
.
perf_counter
()
-
self
.
_start_time
spacing
=
7
-
len
(
str
(
int
(
time_running
)))
message
=
"[{: <8}]{
}({:.3f} sec) {} "
.
format
(
level
,
spacing
*
'-'
,
time_running
,
message
)
message
=
f
"[
{
level
:
<
8
}
]
{
spacing
*
'-'
}
(
{
time_running
:
.
3
f
}
sec)
{
message
}
"
print
(
message
,
flush
=
True
)
def
log_on_root
(
self
,
*
args
,
level
=
'INFO'
):
...
...
@@ -428,7 +428,7 @@ class SerialDataHandling(DataHandling):
file_contents
=
np
.
load
(
file
)
for
arr_name
,
arr_contents
in
self
.
cpu_arrays
.
items
():
if
arr_name
not
in
file_contents
:
print
(
"Skipping read data {} because there is no data with this name in data handling"
.
format
(
arr_name
)
)
print
(
f
"Skipping read data
{
arr_name
}
because there is no data with this name in data handling"
)
continue
if
file_contents
[
arr_name
].
shape
!=
arr_contents
.
shape
:
print
(
"Skipping read data {} because shapes don't match. "
...
...
pystencils/display_utils.py
View file @
bced838a
...
...
@@ -30,7 +30,7 @@ def highlight_cpp(code: str):
from
pygments.lexers
import
CppLexer
css
=
HtmlFormatter
().
get_style_defs
(
'.highlight'
)
css_tag
=
"<style>{css}</style>"
.
format
(
css
=
css
)
css_tag
=
f
"<style>
{
css
}
</style>"
display
(
HTML
(
css_tag
))
return
HTML
(
highlight
(
code
,
CppLexer
(),
HtmlFormatter
()))
...
...
pystencils/fd/derivation.py
View file @
bced838a
...
...
@@ -107,7 +107,7 @@ class FiniteDifferenceStencilDerivation:
@
staticmethod
def
symbolic_weight
(
*
args
):
str_args
=
[
str
(
e
)
for
e
in
args
]
return
sp
.
Symbol
(
"w_({
})"
.
format
(
","
.
join
(
str_args
)
)
)
return
sp
.
Symbol
(
f
"w_(
{
','
.
join
(
str_args
)
}
)"
)
def
error_term_dict
(
self
,
order
):
error_terms
=
defaultdict
(
lambda
:
0
)
...
...
pystencils/fd/derivative.py
View file @
bced838a
...
...
@@ -109,7 +109,7 @@ class Diff(sp.Expr):
return
result
def
__str__
(
self
):
return
"D(
%s)"
%
self
.
arg
return
f
"D(
{
self
.
arg
}
)"
def
interpolated_access
(
self
,
offset
,
**
kwargs
):
"""Represents an interpolated access on a spatially differentiated field
...
...
pystencils/fd/finitedifferences.py
View file @
bced838a
...
...
@@ -193,7 +193,7 @@ class Advection(sp.Function):
return
self
.
scalar
.
spatial_dimensions
def
_latex
(
self
,
printer
):
name_suffix
=
"_
%s"
%
self
.
scalar_index
if
self
.
scalar_index
is
not
None
else
""