Commit b8b92cdf authored by Martin Bauer's avatar Martin Bauer

GPU liveness optimization to reduce register usage

parent 7511f364
@@ -164,6 +164,13 @@ class CBackend:
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        if self._dialect == 'cuda' and isinstance(node.lhs, sp.Symbol) and node.lhs.name.startswith("shmemslot"):
            result = "__shared__ volatile double %s[512]; %s[threadIdx.z * " \
                     "blockDim.x*blockDim.y + threadIdx.y * " \
                     "blockDim.x + threadIdx.x] = %s;" % \
                     (node.lhs.name, node.lhs.name, self.sympy_printer.doprint(node.rhs))
            return result
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
@@ -254,6 +261,12 @@ class CustomSympyPrinter(CCodePrinter):
            res = str(expr.evalf().num)
        return res

    def _print_Symbol(self, expr):
        if self._dialect == 'cuda' and expr.name.startswith("shmemslot"):
            return expr.name + "[threadIdx.z * blockDim.x*blockDim.y + threadIdx.y * blockDim.x + threadIdx.x]"
        else:
            return super(CustomSympyPrinter, self)._print_Symbol(expr)

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
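# For illustration, a sketch of what these two hooks emit for a hypothetical
# symbol "shmemslot0" (the hard-coded 512 assumes blocks of at most 512
# threads, the same limit the tuning script below enforces):
#
#   __shared__ volatile double shmemslot0[512];
#   shmemslot0[threadIdx.z * blockDim.x*blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = <rhs>;
#
# and every subsequent read of the symbol is printed as the same indexed
# access, so the value lives in shared memory instead of a register.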
pystencils/simp/liveness_opts.py

import sympy as sp
import random
import copy
from typing import List

from pystencils import Field, Assignment

fa_symbol_iter = sp.numbered_symbols("fa_")

def merge_field_accesses(assignments):
    """Transformation that introduces symbols for all read field accesses;
    for multiple read accesses of the same location only one symbol is introduced."""
    field_accesses = {}
    new_eqs = copy.copy(assignments)
    for assignment in new_eqs:
        for arg in assignment.rhs.atoms():
            if isinstance(arg, Field.Access) and arg not in field_accesses:
                field_accesses[arg] = next(fa_symbol_iter)

    for i in range(0, len(new_eqs)):
        for f, s in field_accesses.items():
            if f in new_eqs[i].atoms():
                new_eqs[i] = new_eqs[i].subs(f, s)

    for f, s in field_accesses.items():
        new_eqs.insert(0, Assignment(s, f))

    return new_eqs
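# A minimal sketch of the transformation (hypothetical 2D field f, symbols a, b):
#
#   f = ps.fields("f: double[2D]")
#   merge_field_accesses([Assignment(a, f[0, 0] + f[1, 0]),
#                         Assignment(b, 2 * f[0, 0])])
#
# returns, up to the ordering of the inserted loads, roughly
#
#   [Assignment(fa_0, f[0, 0]), Assignment(fa_1, f[1, 0]),
#    Assignment(a, fa_0 + fa_1), Assignment(b, 2*fa_0)]
#
# so the duplicated read of f[0, 0] is loaded exactly once.
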
def fuse_eqs(input_eqs, max_depth=1, max_usage=1):
    """Inserts subexpressions that are used not more than `max_usage` times.

    Somewhat the inverse of common subexpression elimination.

    Args:
        max_depth: complexity limit for the inserted subexpressions; if the
                   expression tree of a subexpression is deeper than max_depth,
                   the subexpression is not inserted
    """
    eqs = copy.copy(input_eqs)
    usages = get_usage(eqs)
    definitions = get_definitions(eqs)

    def inline_trivially_schedulable(sym, depth):
        if sym not in definitions or sym not in usages or usages[sym] > max_usage or depth > max_depth:
            return sym

        rhs = definitions[sym].rhs
@@ -74,13 +56,13 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
    for idx, eq in enumerate(eqs):
        # usages.get: stores and never-read symbols do not appear in the usage dict
        if usages.get(eq.lhs, 0) > 1 or isinstance(eq.lhs, Field.Access):
            if not isinstance(eq.rhs, sp.Symbol):
                eqs[idx] = Assignment(
                    eq.lhs,
                    eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args]))

    count = 0
    while len(eqs) != count:
        count = len(eqs)
        usages = get_usage(eqs)
        eqs = [eq for eq in eqs if usages.get(eq.lhs, 0) > 0 or isinstance(eq.lhs, Field.Access)]
@@ -88,16 +70,26 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
    return eqs
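# A minimal sketch (hypothetical symbols x, y, t and a store to f[0, 0]):
#
#   fuse_eqs([Assignment(t, x + y), Assignment(f[0, 0], 2 * t)], 1, 1)
#
# inlines the single-use subexpression t into the store and then drops its
# now-dead definition, leaving roughly [Assignment(f[0, 0], 2*(x + y))].
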
def schedule_eqs(assignments: List[Assignment], candidate_count=20):
    """Changes the order of assignments to save registers.

    Args:
        assignments: list of assignments to reorder
        candidate_count: tuning parameter, small means fast, but bad scheduling quality;
                         1 corresponds to a pure greedy search

    Returns:
        list of re-ordered assignments
    """
    if candidate_count == 0:
        return assignments

    definitions = get_definitions(assignments)
    definition_atoms = {}
    for sym, definition in definitions.items():
        definition_atoms[sym] = list(definition.rhs.atoms(sp.Symbol))
    roots = get_roots(assignments)
    initial_usages = get_usage(assignments)

    level = 0
    current_level_set = set([frozenset(roots)])
@@ -111,12 +103,18 @@ def schedule_eqs(eqs, candidate_count=20):
        min_regs = min([len(current_usages[dec_set]) for dec_set in current_level_set])
        max_regs = max(max_regs, min_regs)

        def score_dec_set(dec_set):
            score = len(current_usages[dec_set])  # current_schedules[dec_set][0]
            return dec_set, score

        candidates = [score_dec_set(dec_set) for dec_set in current_level_set]

        random.shuffle(candidates)
        candidates.sort(key=lambda d: d[1])

        for dec_set, regs in candidates[:candidate_count]:
            for dec in dec_set:
                new_dec_set = set(dec_set)
                new_dec_set.remove(dec)
@@ -126,7 +124,7 @@ def schedule_eqs(eqs, candidate_count=20):
                for arg in atoms:
                    if not isinstance(arg, Field.Access):
                        argu = usage.get(arg, initial_usages[arg]) - 1
                        if argu == 0 and arg in definitions:
                            new_dec_set.add(arg)
                        usage[arg] = argu
                frozen_new_dec_set = frozenset(new_dec_set)
@@ -134,7 +132,6 @@ def schedule_eqs(eqs, candidate_count=20):
                max_reg_count = max(len(usage), schedule[0])

                if frozen_new_dec_set not in new_schedules or max_reg_count < new_schedules[frozen_new_dec_set][0]:
                    new_schedule = list(schedule[1])
                    new_schedule.append(definitions[dec])
                    new_schedules[frozen_new_dec_set] = (max_reg_count, new_schedule)
@@ -150,9 +147,77 @@ def schedule_eqs(eqs, candidate_count=20):
        level += 1

    schedule = current_schedules[frozenset()]
    schedule[1].reverse()
    return schedule[1]
def liveness_opt_transformation(eqs):
    return fuse_eqs(merge_field_accesses(schedule_eqs(eqs, 30)), 1, 3)
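# Typical use (a sketch): run on any list of pystencils assignments before
# kernel creation, e.g.
#
#   eqs = liveness_opt_transformation(eqs)
#   kernel = ps.create_kernel(eqs, target='gpu').compile()
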
# ---------- Utilities -----------------------------------------------------------------------------------------
def get_usage(assignments: List[Assignment]):
    """Counts the number of reads for all symbols in a list of assignments.

    Returns:
        dictionary mapping symbol to number of its reads
    """
    reg_usage = {}
    for assignment in assignments:
        for arg in assignment.rhs.atoms():
            if isinstance(arg, sp.Symbol) and not isinstance(arg, Field.Access):
                if arg in reg_usage:
                    reg_usage[arg] += 1
                else:
                    reg_usage[arg] = 1
    return reg_usage
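# Example with hypothetical symbols a, b, x, y:
#
#   get_usage([Assignment(a, x + y), Assignment(b, a + x)])
#
# returns {x: 2, y: 1, a: 1}; b does not appear because it is never read.
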
def get_definitions(assignments: List[Assignment]):
    """Returns dictionary mapping symbol to its defining assignment"""
    definitions = {}
    for assignment in assignments:
        definitions[assignment.lhs] = assignment
    return definitions
def get_roots(eqs):
    """Returns all field accesses that are used as lhs of an assignment (i.e. stores).

    In case there are no stores, the lhs of the last assignment is returned as
    the only root (TODO: check whether this fallback is necessary).
    """
    roots = []
    for eq in eqs:
        if isinstance(eq.lhs, Field.Access):
            roots.append(eq.lhs)
    if not roots:
        roots.append(eqs[-1].lhs)
    return roots
# ---------- Staggered kernels -----------------------------------------------------------------------------------------
def unpack_staggered_eqs(field, expressions, subexpressions):
    """Flattens staggered-kernel input (per-dimension expression vectors plus
    subexpressions) into one plain list of assignments."""
    eqs = copy.deepcopy(subexpressions)
    for dim in range(0, len(expressions)):
        for vec in range(0, len(expressions[dim])):
            eqs.append(Assignment(Field.Access(field, (0, 0, 0, dim, vec)), expressions[dim][vec]))
    return eqs
def pack_staggered_eqs(eqs, field, expressions, subexpressions):
    """Inverse of unpack_staggered_eqs: collects the stores back into
    per-dimension sympy matrices and separates the remaining subexpressions."""
    new_matrix_list = [0] * (field.shape[-1] * field.shape[-2])
    for eq in eqs:
        if isinstance(eq.lhs, Field.Access):
            new_matrix_list[eq.lhs.offsets[-2] * field.shape[-1] + eq.lhs.offsets[-1]] = eq.rhs
    subexpressions = [eq for eq in eqs if not isinstance(eq.lhs, Field.Access)]
    return (field, [
        sp.Matrix(field.shape[-1], 1,
                  new_matrix_list[dim * field.shape[-1]:(dim + 1) * field.shape[-1]])
        for dim in range(field.shape[-2])
    ], subexpressions)
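# These two helpers let the liveness transformations run on staggered kernels
# (a sketch): unpack to a flat assignment list, transform, then pack back:
#
#   eqs = unpack_staggered_eqs(field, expressions, subexpressions)
#   eqs = liveness_opt_transformation(eqs)
#   kernel = ps.create_staggered_kernel(
#       *pack_staggered_eqs(eqs, field, expressions, subexpressions),
#       target='gpu').compile()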
This diff is collapsed.
from pygrandchem.grandchem import StaggeredKernelParams
from pystencils.simp.liveness_opts import *
from pystencils.simp.liveness_opts_exp import *
import random
import pycuda.driver as drv
import pystencils as ps
from pystencils import show_code
from timeit import default_timer as timer
import copy
optSequenceCache = {}

# Search space for the evolutionary tuning: each entry is a pair of
# [transformation, default argument list]; the mutation step below appends,
# swaps, removes, or perturbs these entries.
all_opts = [[atomize_eqs, []], [schedule_eqs, [2]], [duplicate_trivial_ops, [3, 1]],
            [merge_field_accesses, []], [refuse_eqs, [1, 1]], [var_to_shmem, [4]],
            [var_to_shmem_lt, [4]]]
def mutateOptSequence(seq):
    """Returns a randomly mutated copy of an optimization sequence."""
    changed = False
    new_seq = copy.deepcopy(seq)
    while not changed:
        choice = random.randint(0, 4)
        if choice == 0:  # append a random transformation
            new_seq.opts.append(random.choice(all_opts))
            changed = True
        elif choice == 1:  # swap two transformations
            if len(new_seq.opts) > 1:
                a = random.randint(0, len(new_seq.opts) - 1)
                b = random.randint(0, len(new_seq.opts) - 1)
                new_seq.opts[a], new_seq.opts[b] = new_seq.opts[b], new_seq.opts[a]
                changed = True
        elif choice == 2:  # drop a transformation
            if len(new_seq.opts) > 0:
                new_seq.opts.remove(random.choice(new_seq.opts))
                changed = True
        elif choice == 3:  # perturb one argument of a transformation
            if len(new_seq.opts) > 0:
                opt = random.choice(new_seq.opts)
                change = random.choice([-1, 1])
                factor = 1
                if change < 0:
                    factor = random.uniform(0.3, 1.0)
                if change > 0:
                    factor = random.uniform(1.0, 3.0)
                if len(opt[1]) > 0:
                    arg = random.randint(0, len(opt[1]) - 1)
                    opt[1][arg] = int(max(0, opt[1][arg] * factor + change))
                    changed = True
        else:  # grow or shrink one block dimension
            dim = random.randint(0, 2)
            change = random.randint(0, 1)
            newBlockSize = list(new_seq.blockSize)
            if change == 0:
                newBlockSize[dim] = min(512, newBlockSize[dim] * 2)
            else:
                newBlockSize[dim] = max(1, newBlockSize[dim] // 2)
            if newBlockSize[0] * newBlockSize[1] * newBlockSize[2] <= 512 and (
                    newBlockSize[0] >= 32 or newBlockSize[0] >= new_seq.blockSize[0]):
                # mutate the copy, not the original sequence
                new_seq.blockSize = tuple(newBlockSize)
                changed = True
    return new_seq
def evolvePopulation(pop, eqs_set, dhs, staggered_params=None):
    """One generation of the evolutionary search: mutate, rate, select."""
    pop.append(livenessOptSequence())

    once_mutated = [mutateOptSequence(seq) for seq in pop[0:6]]
    twice_mutated = [mutateOptSequence(mutateOptSequence(seq)) for seq in pop[0:4]]
    thrice_mutated = [
        mutateOptSequence(mutateOptSequence(mutateOptSequence(seq))) for seq in pop[0:3]
    ]
    new_pop = list(set(pop + once_mutated + twice_mutated + thrice_mutated))

    scores = []
    for seq in new_pop:
        scores.append((seq, *rateSequence(seq, eqs_set, dhs, staggered_params)))

    old_scores = []
    for s in optSequenceCache:
        if s not in new_pop:
            if s not in optSequenceCache:
                print("Not in optSequenceCache: ")
                print(s)
                print(hash(s))
            old_scores.append((s, optSequenceCache[s][0], [0, 0]))
    old_scores.sort(key=lambda s: sum(s[1]))
    if len(old_scores) > 0:
        scores.append(old_scores[0])
    print()

    scores.sort(key=lambda s: sum(s[1]))

    new_pop = []
    count_old_seqs = 0
    for score in scores:
        if score[0] not in optSequenceCache:
            print("Everything in scores: ")
            for s in scores:
                print(s[0])
            print("Not in optSequenceCache: ")
            print(score[0])
            print(hash(score[0]))
        survive = False
        if (len(new_pop) < 4 or count_old_seqs < 3) and len(new_pop) < 10:
            if optSequenceCache[score[0]][1] > 3:
                count_old_seqs += 1
            new_pop.append(score[0])
            survive = True
        print("".join(["{:6.2f} ".format(sc) for sc in score[1]]) + "(" +
              "".join(["{:3d} ".format(sc) for sc in score[2]]) + "): " +
              "{:2d}".format(optSequenceCache[score[0]][1]) + (" * " if survive else " ") +
              str(score[0]))
    print()
    return new_pop
def rateSequence(seq, eqs_set, dh, staggered_params=None):
    """Benchmarks an optimization sequence on all equation sets and keeps a
    running average of the measured runtimes in optSequenceCache."""
    if seq not in optSequenceCache:
        optSequenceCache[seq] = [[], 0]
    cache_entry = optSequenceCache[seq]
    if cache_entry[1] > 10:
        return (cache_entry[0], [0, 0])
    print(cache_entry[1], end=" ")
    print(seq)

    start = timer()
    transformed_eqs_set = [seq.applyOpts(eqs) for eqs in eqs_set]
    end = timer()

    kernel_results = [
        bench_kernel(eqs, dh, seq.blockSize, staggered_params) for eqs in transformed_eqs_set
    ]
    kernel_registers = [k[1] for k in kernel_results]
    # runtimes plus a penalty that grows with the length of the sequence
    result = ([k[0] for k in kernel_results] +
              [k[0] * max(0.0, (len(seq.opts) - 3) * 0.1) for k in kernel_results])

    if cache_entry[1] == 0:
        cache_entry[0] = result
    else:
        for i in range(0, len(result)):
            cache_entry[0][i] = (cache_entry[0][i] * cache_entry[1] + result[i]) / (
                cache_entry[1] + 1)
    cache_entry[1] += 1
    return cache_entry[0], kernel_registers
def bench_kernel(eqs, dh, blockSize=(64, 2, 1), staggered_params=None):
    """Compiles eqs into a (staggered) GPU kernel, times two runs with CUDA
    events, and returns the average runtime in ms plus the register count."""
    if staggered_params is None:
        kernel = ps.create_kernel(
            eqs, target="gpu", gpu_indexing_params={
                "block_size": blockSize
            }).compile()
    else:
        kernel = ps.create_staggered_kernel(
            *pack_staggered_eqs(eqs, *staggered_params),
            target="gpu",
            gpu_indexing_params={
                "block_size": blockSize
            }).compile()

    start = drv.Event()
    end = drv.Event()
    start.record()
    dh.run_kernel(kernel, timestep=1)
    dh.run_kernel(kernel, timestep=1)
    end.record()
    end.synchronize()
    msec = start.time_till(end) / 2
    return msec, kernel.num_regs
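# Example use (a sketch; assumes a data handling `dh` whose arrays match the
# accesses in `update_eqs`):
#
#   msec, regs = bench_kernel(update_eqs, dh, blockSize=(32, 4, 2))
#   print("%.3f ms, %d registers" % (msec, regs))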
# coding: utf-8
import pickle
import warnings
import pystencils as ps
from pygrandchem.grandchem import GrandChemGenerator
from pygrandchem.scenarios import system_4_2, system_3_1
from pygrandchem.initialization import init_boxes, smooth_fields
from pygrandchem.scenarios import benchmark_configs
from sympy import Number, Symbol, Expr, preorder_traversal, postorder_traversal, Function, Piecewise, relational
from pystencils.simp import sympy_cse_on_assignment_list
from pystencils.simp.liveness_opts import *
from pystencils.simp.liveness_opts_exp import *
from pystencils.simp.liveness_permutations import *
import pycuda
import sys
from subprocess import run, PIPE
from pystencils import show_code
import pycuda.driver as drv
import importlib
configs = benchmark_configs()
def get_config(name):
    return configs[name]
domain_size = (512, 512, 128)
periodicity = (True, True, False)
optimization = {'gpu_indexing_params': {"block_size": (32, 4, 2)}}
#bestSeqs = pickle.load(open('best_seq.pickle', 'rb'))
scenarios = ["42_varT_freeEnergy", "31_varT_aniso_rot"]
kernel_types = ["phi_full", "phi_partial1", "phi_partial2", "mu_full", "mu_partial1", "mu_partial2"]
liveness_trans_seqs = importlib.import_module(
    "gpu_liveness_trans_sequences").gpu_liveness_trans_sequences
for scenario in scenarios:
    config = get_config(scenario)
    phases, components = config['Parameters']['phases'], config['Parameters']['components']
    format_args = {'p': phases, 'c': components, 's': ','.join(str(e) for e in domain_size)}

    # Adding fields
    dh = ps.create_data_handling(domain_size, periodicity=periodicity, default_target='gpu')
    f = dh.fields
    phi_src = dh.add_array(
        'phi_src',
        values_per_cell=config['Parameters']['phases'],
        layout='fzyx',
        latex_name='phi_s')
    mu_src = dh.add_array(
        'mu_src',
        values_per_cell=config['Parameters']['components'],
        layout='fzyx',
        latex_name="mu_s")
    mu_stag = dh.add_array(
        'mu_stag', values_per_cell=(dh.dim, config['Parameters']['components']), layout='f')
    phi_stag = dh.add_array('phi_stag', values_per_cell=(dh.dim, phases), layout='f')

    phi_dst = dh.add_array_like('phi_dst', 'phi_src')
    mu_dst = dh.add_array_like('mu_dst', 'mu_src')

    gc = GrandChemGenerator(
        phi_src,
        phi_dst,
        mu_src,
        mu_dst,
        config['FreeEnergy'],
        config['Parameters'],
        #conc=c,
        mu_staggered=mu_stag,
        phi_staggered=phi_stag,
        use_block_offsets=False,
        compile_kernel=False)

    mu_full_eqs = gc.mu_full()
    phi_full_eqs = gc.phi_full()

    phi_kernel = ps.create_kernel(phi_full_eqs, target='gpu', **optimization).compile()
    mu_kernel = ps.create_kernel(mu_full_eqs, target='gpu', **optimization).compile()

    c = dh.add_array(
        'c', values_per_cell=config['Parameters']['components'], layout='fzyx', gpu=False)

    init_boxes(dh)
    #initialize_concentration_field(dh, free_energy, config['Parameters']['initial_concentration'])
    smooth_fields(dh, sigma=0.4, iterations=5, dim=dh.dim)
    dh.synchronization_function(['phi_src', 'phi_dst', 'mu_src', 'mu_dst'])()

    staggered_params = None