Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Jonas Plewinski
pystencils
Commits
b8b92cdf
Commit
b8b92cdf
authored
Apr 26, 2019
by
Martin Bauer
Browse files
GPU liveness optimization to reduce registers
parent
7511f364
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
pystencils/backends/cbackend.py
View file @
b8b92cdf
...
...
@@ -164,6 +164,13 @@ class CBackend:
return
"%s%s
\n
%s"
%
(
prefix
,
loop_str
,
self
.
_print
(
node
.
body
))
def
_print_SympyAssignment
(
self
,
node
):
if
self
.
_dialect
==
'cuda'
and
isinstance
(
node
.
lhs
,
sp
.
Symbol
)
and
node
.
lhs
.
name
.
startswith
(
"shmemslot"
):
result
=
"__shared__ volatile double %s[512]; %s[threadIdx.z * "
\
"blockDim.x*blockDim.y + threadIdx.y * "
\
"blockDim.x + threadIdx.x] = %s;"
%
\
(
node
.
lhs
.
name
,
node
.
lhs
.
name
,
self
.
sympy_printer
.
doprint
(
node
.
rhs
))
return
result
if
node
.
is_declaration
:
data_type
=
"const "
+
str
(
node
.
lhs
.
dtype
)
+
" "
if
node
.
is_const
else
str
(
node
.
lhs
.
dtype
)
+
" "
return
"%s%s = %s;"
%
(
data_type
,
self
.
sympy_printer
.
doprint
(
node
.
lhs
),
...
...
@@ -254,6 +261,12 @@ class CustomSympyPrinter(CCodePrinter):
res
=
str
(
expr
.
evalf
().
num
)
return
res
def
_print_Symbol
(
self
,
expr
):
if
self
.
_dialect
==
'cuda'
and
expr
.
name
.
startswith
(
"shmemslot"
):
return
expr
.
name
+
"[threadIdx.z * blockDim.x*blockDim.y + threadIdx.y * blockDim.x + threadIdx.x]"
else
:
return
super
(
CustomSympyPrinter
,
self
).
_print_Symbol
(
expr
)
def
_print_Equality
(
self
,
expr
):
"""Equality operator is not printable in default printer"""
return
'(('
+
self
.
_print
(
expr
.
lhs
)
+
") == ("
+
self
.
_print
(
expr
.
rhs
)
+
'))'
...
...
pystencils/simp/liveness_opts.py
View file @
b8b92cdf
from
sympy
import
Symbol
,
Dummy
from
pystencils
import
Field
,
Assignment
import
sympy
as
sp
import
random
import
copy
from
typing
import
List
from
pystencils
import
Field
,
Assignment
def get_usage(atoms):
    """Count how often each defined symbol is read on a right-hand side.

    Every lhs of `atoms` starts with a count of 0; symbols that are read
    but never defined by any assignment in `atoms` are reported on stdout
    instead of being counted.
    """
    reg_usage = dict.fromkeys((atom.lhs for atom in atoms), 0)
    for atom in atoms:
        for arg in atom.rhs.atoms():
            # only plain symbols count as register reads, not field accesses
            if isinstance(arg, Symbol) and not isinstance(arg, Field.Access):
                if arg in reg_usage:
                    reg_usage[arg] += 1
                else:
                    print(str(arg) + " is unsatisfied")
    return reg_usage
def get_definitions(eqs):
    """Map each assignment's left-hand side to the assignment itself."""
    return {eq.lhs: eq for eq in eqs}
# Shared generator producing fa_0, fa_1, ... substitution symbols for
# field accesses (used by merge_field_accesses).
fa_symbol_iter = sp.numbered_symbols("fa_")
def get_roots(eqs):
    """Collect all store targets (Field.Access lhs values).

    If no assignment stores to a field, fall back to the lhs of the
    last assignment so there is always at least one root.
    """
    roots = [eq.lhs for eq in eqs if isinstance(eq.lhs, Field.Access)]
    if not roots:
        roots.append(eqs[-1].lhs)
    return roots
def merge_field_accesses(assignments):
    """Transformation that introduces symbols for all read field accesses
    for multiple read accesses only one symbol is introduced"""
    new_eqs = copy.copy(assignments)
    # assign one substitution symbol per distinct field read
    field_accesses = {}
    for assignment in new_eqs:
        for arg in assignment.rhs.atoms():
            if isinstance(arg, Field.Access) and arg not in field_accesses:
                field_accesses[arg] = next(fa_symbol_iter)
    # replace every occurrence of a read by its symbol
    for i in range(0, len(new_eqs)):
        for f, s in field_accesses.items():
            if f in new_eqs[i].atoms():
                new_eqs[i] = new_eqs[i].subs(f, s)
    # prepend the defining loads so each symbol is defined before use
    for f, s in field_accesses.items():
        new_eqs.insert(0, Assignment(s, f))
    return new_eqs
def
fuse_eqs
(
input_eqs
,
max_depth
=
1
,
max_usage
=
1
):
"""Inserts subexpressions that are used not more than `max_usage`
def
refuse_eqs
(
input_eqs
,
max_depth
=
0
,
max_usage
=
1
):
Args:
max_depth: complexity metric for the subexpression to insert
if max_depth is larger than the expression tree of the subexpression
the subexpressions is not inserted
Somewhat the inverse of common subexpression elimination.
"""
eqs
=
copy
.
copy
(
input_eqs
)
usages
=
get_usage
(
eqs
)
definitions
=
get_definitions
(
eqs
)
def
inline_trivially_schedulable
(
sym
,
depth
):
if
sym
not
in
usages
or
usages
[
sym
]
>
max_usage
or
depth
>
max_depth
:
if
sym
not
in
definitions
or
sym
not
in
usages
or
usages
[
sym
]
>
max_usage
or
depth
>
max_depth
:
return
sym
rhs
=
definitions
[
sym
].
rhs
...
...
@@ -74,13 +56,13 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
for
idx
,
eq
in
enumerate
(
eqs
):
if
usages
[
eq
.
lhs
]
>
1
or
isinstance
(
eq
.
lhs
,
Field
.
Access
):
if
not
isinstance
(
eq
.
rhs
,
Symbol
):
eqs
[
idx
]
=
Assignment
(
eq
.
lhs
,
eq
.
rhs
.
func
(
*
[
inline_trivially_schedulable
(
arg
,
0
)
for
arg
in
eq
.
rhs
.
args
]))
if
not
isinstance
(
eq
.
rhs
,
sp
.
Symbol
):
eqs
[
idx
]
=
Assignment
(
eq
.
lhs
,
eq
.
rhs
.
func
(
*
[
inline_trivially_schedulable
(
arg
,
0
)
for
arg
in
eq
.
rhs
.
args
]))
count
=
0
while
(
len
(
eqs
)
!=
count
)
:
while
len
(
eqs
)
!=
count
:
count
=
len
(
eqs
)
usages
=
get_usage
(
eqs
)
eqs
=
[
eq
for
eq
in
eqs
if
usages
[
eq
.
lhs
]
>
0
or
isinstance
(
eq
.
lhs
,
Field
.
Access
)]
...
...
@@ -88,16 +70,26 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
return
eqs
def
schedule_eqs
(
eqs
,
candidate_count
=
20
):
def
schedule_eqs
(
assignments
:
List
[
Assignment
],
candidate_count
=
20
):
"""Changes order of assignments to save registers.
Args:
assignments:
candidate_count: tuning parameter, small means fast, but bad scheduling quality
1 corresponds to full greedy search
Returns:
list of re-ordered assignments
"""
if
candidate_count
==
0
:
return
eq
s
return
assignment
s
definitions
=
get_definitions
(
eq
s
)
definitions
=
get_definitions
(
assignment
s
)
definition_atoms
=
{}
for
sym
,
definition
in
definitions
.
items
():
definition_atoms
[
sym
]
=
list
(
definition
.
rhs
.
atoms
(
Symbol
))
roots
=
get_roots
(
eq
s
)
initial_usages
=
get_usage
(
eq
s
)
definition_atoms
[
sym
]
=
list
(
definition
.
rhs
.
atoms
(
sp
.
Symbol
))
roots
=
get_roots
(
assignment
s
)
initial_usages
=
get_usage
(
assignment
s
)
level
=
0
current_level_set
=
set
([
frozenset
(
roots
)])
...
...
@@ -111,12 +103,18 @@ def schedule_eqs(eqs, candidate_count=20):
min_regs
=
min
([
len
(
current_usages
[
dec_set
])
for
dec_set
in
current_level_set
])
max_regs
=
max
(
max_regs
,
min_regs
)
candidates
=
[(
dec_set
,
len
(
current_usages
[
dec_set
]))
for
dec_set
in
current_level_set
]
def
score_dec_set
(
dec_set
):
score
=
len
(
current_usages
[
dec_set
])
# current_schedules[dec_set][0]
return
dec_set
,
score
candidates
=
[
score_dec_set
(
dec_set
)
for
dec_set
in
current_level_set
]
random
.
shuffle
(
candidates
)
candidates
.
sort
(
key
=
lambda
d
:
d
[
1
])
for
dec_set
,
regs
in
candidates
[:
candidate_count
]:
for
dec
in
dec_set
:
new_dec_set
=
set
(
dec_set
)
new_dec_set
.
remove
(
dec
)
...
...
@@ -126,7 +124,7 @@ def schedule_eqs(eqs, candidate_count=20):
for
arg
in
atoms
:
if
not
isinstance
(
arg
,
Field
.
Access
):
argu
=
usage
.
get
(
arg
,
initial_usages
[
arg
])
-
1
if
argu
==
0
:
if
argu
==
0
and
arg
in
definitions
:
new_dec_set
.
add
(
arg
)
usage
[
arg
]
=
argu
frozen_new_dec_set
=
frozenset
(
new_dec_set
)
...
...
@@ -134,7 +132,6 @@ def schedule_eqs(eqs, candidate_count=20):
max_reg_count
=
max
(
len
(
usage
),
schedule
[
0
])
if
frozen_new_dec_set
not
in
new_schedules
or
max_reg_count
<
new_schedules
[
frozen_new_dec_set
][
0
]:
new_schedule
=
list
(
schedule
[
1
])
new_schedule
.
append
(
definitions
[
dec
])
new_schedules
[
frozen_new_dec_set
]
=
(
max_reg_count
,
new_schedule
)
...
...
@@ -150,9 +147,77 @@ def schedule_eqs(eqs, candidate_count=20):
level
+=
1
schedule
=
current_schedules
[
frozenset
()]
schedule
[
1
].
reverse
()
return
(
schedule
[
1
]
)
return
schedule
[
1
]
def liveness_opt_transformation(eqs):
    """Standard liveness pipeline: schedule for registers, deduplicate
    field reads, then fuse cheap subexpressions back in."""
    scheduled = schedule_eqs(eqs, 30)
    merged = merge_field_accesses(scheduled)
    return fuse_eqs(merged, 1, 3)
# ---------- Utilities -----------------------------------------------------------------------------------------
def get_usage(assignments: List[Assignment]):
    """Count number of reads for all symbols in list of assignments

    Returns:
        dictionary mapping symbol to number of its reads
    """
    reg_usage = {}
    for assignment in assignments:
        for arg in assignment.rhs.atoms():
            # field accesses are loads, not register reads
            if isinstance(arg, sp.Symbol) and not isinstance(arg, Field.Access):
                reg_usage[arg] = reg_usage.get(arg, 0) + 1
    return reg_usage
def get_definitions(assignments: List[Assignment]):
    """Returns dictionary mapping symbol to its defining assignment"""
    return {assignment.lhs: assignment for assignment in assignments}
def get_roots(eqs):
    """Returns all field accesses that are used as lhs in assignment (stores)
    In case there are no independent assignments, the last one is returned (TODO try if necessary)
    """
    roots = [eq.lhs for eq in eqs if isinstance(eq.lhs, Field.Access)]
    if not roots:
        # no store found: treat the final assignment as the single root
        roots.append(eqs[-1].lhs)
    return roots
# ---------- Staggered kernels -----------------------------------------------------------------------------------------
def unpack_staggered_eqs(field, expressions, subexpressions):
    """Flatten per-dimension staggered expressions into one assignment list.

    Starts from a deep copy of `subexpressions`, then appends one store per
    (dimension, vector) entry writing into `field` at offset (0, 0, 0, dim, vec).
    """
    eqs = copy.deepcopy(subexpressions)
    for dim, dim_exprs in enumerate(expressions):
        for vec, expr in enumerate(dim_exprs):
            eqs.append(Assignment(Field.Access(field, (0, 0, 0, dim, vec)), expr))
    return eqs
def pack_staggered_eqs(eqs, field, expressions, subexpressions):
    """Regroup a flat assignment list back into per-dimension matrices.

    Stores (Field.Access lhs) are scattered into a dense value list by their
    last two offsets; everything else becomes a subexpression. Returns the
    (field, list-of-Matrix, subexpressions) triple expected by
    create_staggered_kernel. `expressions` is accepted for interface symmetry
    with unpack_staggered_eqs but not read here.
    """
    dims, vecs = field.shape[-2], field.shape[-1]
    new_matrix_list = [0] * (dims * vecs)
    subexpressions = []
    for eq in eqs:
        if isinstance(eq.lhs, Field.Access):
            new_matrix_list[eq.lhs.offsets[-2] * vecs + eq.lhs.offsets[-1]] = eq.rhs
        else:
            subexpressions.append(eq)
    matrices = [sp.Matrix(vecs, 1, new_matrix_list[dim * vecs:(dim + 1) * vecs])
                for dim in range(dims)]
    return (field, matrices, subexpressions)
pystencils/simp/liveness_opts_exp.py
0 → 100644
View file @
b8b92cdf
This diff is collapsed.
Click to expand it.
pystencils/simp/liveness_permutations.py
0 → 100644
View file @
b8b92cdf
from
pygrandchem.grandchem
import
StaggeredKernelParams
from
pystencils.simp.liveness_opts
import
*
from
pystencils.simp.liveness_opts_exp
import
*
import
random
import
pycuda.driver
as
drv
import
pystencils
as
ps
from
pystencils
import
show_code
from
timeit
import
default_timer
as
timer
import
copy
# Cache of already-rated optimization sequences: sequence -> [scores, sample count].
optSequenceCache = {}

# Every available liveness transformation paired with its default argument list.
all_opts = [
    [atomize_eqs, []],
    [schedule_eqs, [2]],
    [duplicate_trivial_ops, [3, 1]],
    [merge_field_accesses, []],
    [refuse_eqs, [1, 1]],
    [var_to_shmem, [4]],
    [var_to_shmem_lt, [4]],
]
def mutateOptSequence(seq):
    """Return a randomly mutated deep copy of an optimization sequence.

    One of five mutations is tried repeatedly until one succeeds:
    append a random optimization, swap two optimizations, drop one,
    perturb one optimization's integer argument, or halve/double one
    CUDA block-size dimension (kept within the 512-thread limit, and
    dimension 0 kept warp-sized unless it already was smaller).

    Args:
        seq: sequence object with ``opts`` (list of [opt, args] pairs) and
             ``blockSize`` (3-tuple) attributes; it is left unmodified.

    Returns:
        a new, mutated sequence object
    """
    new_seq = copy.deepcopy(seq)
    changed = False
    while not changed:
        choice = random.randint(0, 4)
        if choice == 0:
            # append a random optimization step
            # NOTE(review): this appends a reference to the shared all_opts
            # entry; a later argument mutation on it would edit the global
            # template — left as-is, but worth confirming upstream.
            new_seq.opts.append(random.choice(all_opts))
            changed = True
        elif choice == 1:
            # swap two (possibly identical) optimization steps
            if len(new_seq.opts) > 1:
                a = random.randint(0, len(new_seq.opts) - 1)
                b = random.randint(0, len(new_seq.opts) - 1)
                new_seq.opts[a], new_seq.opts[b] = new_seq.opts[b], new_seq.opts[a]
                changed = True
        elif choice == 2:
            # drop one optimization step
            if len(new_seq.opts) > 0:
                new_seq.opts.remove(random.choice(new_seq.opts))
                changed = True
        elif choice == 3:
            # randomly scale one integer argument of one optimization
            if len(new_seq.opts) > 0:
                opt = random.choice(new_seq.opts)
                change = random.choice([-1, 1])
                factor = 1
                if change < 0:
                    factor = random.uniform(0.3, 1.0)
                if change > 0:
                    factor = random.uniform(1.0, 3.0)
                if len(opt[1]) > 0:
                    arg = random.randint(0, len(opt[1]) - 1)
                    opt[1][arg] = int(max(0, opt[1][arg] * factor + change))
                    changed = True
        else:
            # halve or double one block-size dimension
            dim = random.randint(0, 2)
            change = random.randint(0, 1)
            newBlockSize = list(seq.blockSize)
            if change == 0:
                newBlockSize[dim] = min(512, newBlockSize[dim] * 2)
            else:
                newBlockSize[dim] = max(1, newBlockSize[dim] // 2)
            if newBlockSize[0] * newBlockSize[1] * newBlockSize[2] <= 512 and (
                    newBlockSize[0] >= 32 or newBlockSize[0] >= seq.blockSize[0]):
                # BUG FIX: the new block size must go to the mutated copy,
                # not the input sequence. Previously this wrote seq.blockSize,
                # which corrupted the caller's sequence AND dropped the
                # mutation from the returned copy.
                new_seq.blockSize = tuple(newBlockSize)
                changed = True
    return new_seq
def evolvePopulation(pop, eqs_set, dhs, staggered_params=None):
    """Run one generation of the genetic search over optimization sequences.

    Extends the population with a fresh default sequence plus one-, two- and
    three-step mutants of the fittest members, rates every candidate, keeps
    the best rating seen for sequences that dropped out, prints a score
    table, and returns the surviving population.
    """
    pop.append(livenessOptSequence())
    mutants = ([mutateOptSequence(s) for s in pop[0:6]]
               + [mutateOptSequence(mutateOptSequence(s)) for s in pop[0:4]]
               + [mutateOptSequence(mutateOptSequence(mutateOptSequence(s))) for s in pop[0:3]])
    new_pop = list(set(pop + mutants))

    scores = [(seq, *rateSequence(seq, eqs_set, dhs, staggered_params)) for seq in new_pop]

    # carry over the best previously-rated sequence that is not in this round
    old_scores = []
    for s in optSequenceCache:
        if s not in new_pop:
            # NOTE(review): this inner guard can never fire (s iterates the
            # keys of optSequenceCache); kept only for behavioral parity.
            if s not in optSequenceCache:
                print("Not in optSequenceCache: ")
                print(s)
                print(hash(s))
            old_scores.append((s, optSequenceCache[s][0], [0, 0]))
    old_scores.sort(key=lambda entry: sum(entry[1]))
    if len(old_scores) > 0:
        scores.append(old_scores[0])
    print()

    scores.sort(key=lambda entry: sum(entry[1]))

    # survivor selection: best few overall, plus a quota of well-sampled ones
    new_pop = []
    count_old_seqs = 0
    for score in scores:
        if score[0] not in optSequenceCache:
            print("Everything in scores: ")
            for s in scores:
                print(s[0])
            print("Not in optSequenceCache: ")
            print(score[0])
            print(hash(score[0]))
        survive = False
        if (len(new_pop) < 4 or count_old_seqs < 3) and len(new_pop) < 10:
            if optSequenceCache[score[0]][1] > 3:
                count_old_seqs += 1
            new_pop.append(score[0])
            survive = True
        print("".join(["{:6.2f} ".format(sc) for sc in score[1]])
              + "(" + "".join(["{:3d} ".format(sc) for sc in score[2]]) + "): "
              + "{:2d}".format(optSequenceCache[score[0]][1])
              + (" * " if survive else " ") + str(score[0]))
    print()
    return new_pop
def rateSequence(seq, eqs_set, dh, staggered_params=None):
    """Benchmark an optimization sequence on every equation set, with caching.

    Scores are averaged over repeated ratings; once a sequence has been
    rated more than 10 times the cached score is returned directly
    (with [0, 0] standing in for "no fresh register counts").
    """
    if seq not in optSequenceCache:
        optSequenceCache[seq] = [[], 0]
    cache_entry = optSequenceCache[seq]
    if cache_entry[1] > 10:
        return (cache_entry[0], [0, 0])
    print(cache_entry[1], end=" ")
    print(seq)

    start = timer()
    transformed_eqs_set = [seq.applyOpts(eqs) for eqs in eqs_set]
    end = timer()

    kernel_results = [bench_kernel(eqs, dh, seq.blockSize, staggered_params)
                      for eqs in transformed_eqs_set]
    kernel_registers = [regs for _, regs in kernel_results]
    times = [t for t, _ in kernel_results]
    # penalty term grows with sequence length beyond 3 optimizations
    penalty = max(0.0, (len(seq.opts) - 3) * 0.1)
    result = times + [t * penalty for t in times]

    if cache_entry[1] == 0:
        cache_entry[0] = result
    else:
        # fold the new measurement into the running average
        for i in range(0, len(result)):
            cache_entry[0][i] = (cache_entry[0][i] * cache_entry[1] + result[i]) / (cache_entry[1] + 1)
    cache_entry[1] += 1

    return cache_entry[0], kernel_registers
def bench_kernel(eqs, dh, blockSize=(64, 2, 1), staggered_params=None):
    """Compile `eqs` for the GPU and time two kernel runs with CUDA events.

    Returns:
        (average runtime in milliseconds, register count of the kernel)
    """
    indexing = {"block_size": blockSize}
    if staggered_params is None:
        kernel = ps.create_kernel(eqs, target="gpu", gpu_indexing_params=indexing).compile()
    else:
        kernel = ps.create_staggered_kernel(*pack_staggered_eqs(eqs, *staggered_params),
                                            target="gpu", gpu_indexing_params=indexing).compile()
    start = drv.Event()
    end = drv.Event()
    start.record()
    dh.run_kernel(kernel, timestep=1)
    dh.run_kernel(kernel, timestep=1)
    end.record()
    end.synchronize()
    msec = start.time_till(end) / 2
    return msec, kernel.num_regs
pystencils_tests/liveness_opts/compare_seqs.py
0 → 100644
View file @
b8b92cdf
# coding: utf-8
# In[32]:
import
pickle
import
warnings
import
pystencils
as
ps
from
pygrandchem.grandchem
import
GrandChemGenerator
from
pygrandchem.scenarios
import
system_4_2
,
system_3_1
from
pygrandchem.initialization
import
init_boxes
,
smooth_fields
from
pygrandchem.scenarios
import
benchmark_configs
from
sympy
import
Number
,
Symbol
,
Expr
,
preorder_traversal
,
postorder_traversal
,
Function
,
Piecewise
,
relational
from
pystencils.simp
import
sympy_cse_on_assignment_list
from
pystencils.simp.liveness_opts
import
*
from
pystencils.simp.liveness_opts_exp
import
*
from
pystencils.simp.liveness_permutations
import
*
import
pycuda
import
sys
from
subprocess
import
run
,
PIPE
from
pystencils
import
show_code
import
pycuda.driver
as
drv
import
importlib
# All benchmark scenario configurations, keyed by scenario name.
configs = benchmark_configs()
def get_config(name):
    """Look up one benchmark configuration by scenario name."""
    return configs[name]
# Benchmark geometry and default GPU launch configuration.
domain_size = (512, 512, 128)
periodicity = (True, True, False)
optimization = {'gpu_indexing_params': {"block_size": (32, 4, 2)}}

#bestSeqs = pickle.load(open('best_seq.pickle', 'rb'))

# Scenarios and kernel variants to compare.
scenarios = ["42_varT_freeEnergy", "31_varT_aniso_rot"]
kernel_types = ["phi_full", "phi_partial1", "phi_partial2",
                "mu_full", "mu_partial1", "mu_partial2"]

# Liveness transformation sequences are loaded from a generated module.
liveness_trans_seqs = importlib.import_module(
    "gpu_liveness_trans_sequences").gpu_liveness_trans_sequences
for
scenario
in
scenarios
:
config
=
get_config
(
scenario
)
phases
,
components
=
config
[
'Parameters'
][
'phases'
],
config
[
'Parameters'
][
'components'
]
format_args
=
{
'p'
:
phases
,
'c'
:
components
,
's'
:
','
.
join
(
str
(
e
)
for
e
in
domain_size
)}
# Adding fields
dh
=
ps
.
create_data_handling
(
domain_size
,
periodicity
=
periodicity
,
default_target
=
'gpu'
)
f
=
dh
.
fields
phi_src
=
dh
.
add_array
(
'phi_src'
,
values_per_cell
=
config
[
'Parameters'
][
'phases'
],
layout
=
'fzyx'
,
latex_name
=
'phi_s'
)
mu_src
=
dh
.
add_array
(
'mu_src'
,
values_per_cell
=
config
[
'Parameters'
][
'components'
],
layout
=
'fzyx'
,
latex_name
=
"mu_s"
)
mu_stag
=
dh
.
add_array
(
'mu_stag'
,
values_per_cell
=
(
dh
.
dim
,
config
[
'Parameters'
][
'components'
]),
layout
=
'f'
)
phi_stag
=
dh
.
add_array
(
'phi_stag'
,
values_per_cell
=
(
dh
.
dim
,
phases
),
layout
=
'f'
)
phi_dst
=
dh
.
add_array_like
(
'phi_dst'
,
'phi_src'
)
mu_dst
=
dh
.
add_array_like
(
'mu_dst'
,
'mu_src'
)
gc
=
GrandChemGenerator
(
phi_src
,
phi_dst
,
mu_src
,
mu_dst
,
config
[
'FreeEnergy'
],
config
[
'Parameters'
],
#conc=c,
mu_staggered
=
mu_stag
,
phi_staggered
=
phi_stag
,
use_block_offsets
=
False
,
compile_kernel
=
False
)
mu_full_eqs
=
gc
.
mu_full
()
phi_full_eqs
=
gc
.
phi_full
()
phi_kernel
=
ps
.
create_kernel
(
phi_full_eqs
,
target
=
'gpu'
,
**
optimization
).
compile
()
mu_kernel
=
ps
.
create_kernel
(
mu_full_eqs
,
target
=
'gpu'
,
**
optimization
).
compile
()
c
=
dh
.
add_array
(
'c'
,
values_per_cell
=
config
[
'Parameters'
][
'components'
],
layout
=
'fzyx'
,
gpu
=
False
)
init_boxes
(
dh
)
#initialize_concentration_field(dh, free_energy, config['Parameters']['initial_concentration'])
smooth_fields
(
dh
,
sigma
=
0.4
,
iterations
=
5
,
dim
=
dh
.
dim
)
dh
.
synchronization_function
([
'phi_src'
,
'phi_dst'
,
'mu_src'
,
'mu_dst'
])()
staggered_params
=
None
def
bench_kernels
(
mu_kernel
,
phi_kernel
):