Commit b8b92cdf, authored by Martin Bauer
Browse files

GPU liveness optimization to reduce registers

parent 7511f364
...@@ -164,6 +164,13 @@ class CBackend: ...@@ -164,6 +164,13 @@ class CBackend:
return "%s%s\n%s" % (prefix, loop_str, self._print(node.body)) return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if self._dialect == 'cuda' and isinstance(node.lhs, sp.Symbol) and node.lhs.name.startswith("shmemslot"):
result = "__shared__ volatile double %s[512]; %s[threadIdx.z * " \
"blockDim.x*blockDim.y + threadIdx.y * " \
"blockDim.x + threadIdx.x] = %s;" % \
(node.lhs.name, node.lhs.name, self.sympy_printer.doprint(node.rhs))
return result
if node.is_declaration: if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " " data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
...@@ -254,6 +261,12 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -254,6 +261,12 @@ class CustomSympyPrinter(CCodePrinter):
res = str(expr.evalf().num) res = str(expr.evalf().num)
return res return res
def _print_Symbol(self, expr):
    # Special case for the CUDA dialect: symbols whose name starts with
    # "shmemslot" are printed as an access into a per-block shared-memory
    # array, indexed by the thread's linearized id within the block.
    # NOTE(review): the array itself is declared elsewhere (see the
    # corresponding shmemslot branch in the assignment printer) — this
    # printer only emits the indexed read/write expression.
    if self._dialect == 'cuda' and expr.name.startswith("shmemslot"):
        return expr.name + "[threadIdx.z * blockDim.x*blockDim.y + threadIdx.y * blockDim.x + threadIdx.x]"
    else:
        # All other symbols fall back to the default sympy C printer.
        return super(CustomSympyPrinter, self)._print_Symbol(expr)
def _print_Equality(self, expr): def _print_Equality(self, expr):
"""Equality operator is not printable in default printer""" """Equality operator is not printable in default printer"""
return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))' return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
......
from sympy import Symbol, Dummy import sympy as sp
from pystencils import Field, Assignment
import random import random
import copy import copy
from typing import List
from pystencils import Field, Assignment
fa_symbol_iter = sp.numbered_symbols("fa_")
def get_usage(atoms):
    """Count how often each assigned symbol is read on a right-hand side.

    Every lhs of the input assignments gets an entry (starting at 0).
    Symbols that are read but never assigned are reported on stdout
    instead of being counted.
    """
    reg_usage = {atom.lhs: 0 for atom in atoms}
    for atom in atoms:
        for arg in atom.rhs.atoms():
            if not isinstance(arg, Symbol) or isinstance(arg, Field.Access):
                continue
            if arg not in reg_usage:
                print(str(arg) + " is unsatisfied")
                continue
            reg_usage[arg] += 1
    return reg_usage
def get_definitions(eqs):
    """Map each assignment's left-hand side to the assignment itself.

    If several assignments share an lhs, the last one wins.
    """
    return {eq.lhs: eq for eq in eqs}
def get_roots(eqs): def merge_field_accesses(assignments):
roots = [] """Transformation that introduces symbols for all read field accesses
for eq in eqs: for multiple read accesses only one symbol is introduced"""
if isinstance(eq.lhs, Field.Access):
roots.append(eq.lhs)
if not roots:
roots.append(eqs[-1].lhs)
return roots
def merge_field_accesses(eqs):
field_accesses = {} field_accesses = {}
for eq in eqs: new_eqs = copy.copy(assignments)
for arg in eq.rhs.atoms(): for assignment in new_eqs:
for arg in assignment.rhs.atoms():
if isinstance(arg, Field.Access) and arg not in field_accesses: if isinstance(arg, Field.Access) and arg not in field_accesses:
field_accesses[arg] = Dummy() field_accesses[arg] = next(fa_symbol_iter)
for i in range(0, len(eqs)): for i in range(0, len(new_eqs)):
for f, s in field_accesses.items(): for f, s in field_accesses.items():
if f in eqs[i].atoms(): if f in new_eqs[i].atoms():
eqs[i] = eqs[i].subs(f, s) new_eqs[i] = new_eqs[i].subs(f, s)
for f, s in field_accesses.items(): for f, s in field_accesses.items():
eqs.insert(0, Assignment(s, f)) new_eqs.insert(0, Assignment(s, f))
return new_eqs
return eqs
def fuse_eqs(input_eqs, max_depth=1, max_usage=1):
"""Inserts subexpressions that are used not more than `max_usage`
def refuse_eqs(input_eqs, max_depth=0, max_usage=1): Args:
max_depth: complexity metric for the subexpression to insert
if max_depth is larger than the expression tree of the subexpression
the subexpressions is not inserted
Somewhat the inverse of common subexpression elimination.
"""
eqs = copy.copy(input_eqs) eqs = copy.copy(input_eqs)
usages = get_usage(eqs) usages = get_usage(eqs)
definitions = get_definitions(eqs) definitions = get_definitions(eqs)
def inline_trivially_schedulable(sym, depth): def inline_trivially_schedulable(sym, depth):
if sym not in usages or usages[sym] > max_usage or depth > max_depth: if sym not in definitions or sym not in usages or usages[sym] > max_usage or depth > max_depth:
return sym return sym
rhs = definitions[sym].rhs rhs = definitions[sym].rhs
...@@ -74,13 +56,13 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1): ...@@ -74,13 +56,13 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
for idx, eq in enumerate(eqs): for idx, eq in enumerate(eqs):
if usages[eq.lhs] > 1 or isinstance(eq.lhs, Field.Access): if usages[eq.lhs] > 1 or isinstance(eq.lhs, Field.Access):
if not isinstance(eq.rhs, Symbol): if not isinstance(eq.rhs, sp.Symbol):
eqs[idx] = Assignment(
eqs[idx] = Assignment(eq.lhs, eq.lhs,
eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args])) eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args]))
count = 0 count = 0
while (len(eqs) != count): while len(eqs) != count:
count = len(eqs) count = len(eqs)
usages = get_usage(eqs) usages = get_usage(eqs)
eqs = [eq for eq in eqs if usages[eq.lhs] > 0 or isinstance(eq.lhs, Field.Access)] eqs = [eq for eq in eqs if usages[eq.lhs] > 0 or isinstance(eq.lhs, Field.Access)]
...@@ -88,16 +70,26 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1): ...@@ -88,16 +70,26 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
return eqs return eqs
def schedule_eqs(eqs, candidate_count=20): def schedule_eqs(assignments: List[Assignment], candidate_count=20):
"""Changes order of assignments to save registers.
Args:
assignments:
candidate_count: tuning parameter, small means fast, but bad scheduling quality
1 corresponds to full greedy search
Returns:
list of re-ordered assignments
"""
if candidate_count == 0: if candidate_count == 0:
return eqs return assignments
definitions = get_definitions(eqs) definitions = get_definitions(assignments)
definition_atoms = {} definition_atoms = {}
for sym, definition in definitions.items(): for sym, definition in definitions.items():
definition_atoms[sym] = list(definition.rhs.atoms(Symbol)) definition_atoms[sym] = list(definition.rhs.atoms(sp.Symbol))
roots = get_roots(eqs) roots = get_roots(assignments)
initial_usages = get_usage(eqs) initial_usages = get_usage(assignments)
level = 0 level = 0
current_level_set = set([frozenset(roots)]) current_level_set = set([frozenset(roots)])
...@@ -111,12 +103,18 @@ def schedule_eqs(eqs, candidate_count=20): ...@@ -111,12 +103,18 @@ def schedule_eqs(eqs, candidate_count=20):
min_regs = min([len(current_usages[dec_set]) for dec_set in current_level_set]) min_regs = min([len(current_usages[dec_set]) for dec_set in current_level_set])
max_regs = max(max_regs, min_regs) max_regs = max(max_regs, min_regs)
candidates = [(dec_set, len(current_usages[dec_set])) for dec_set in current_level_set]
def score_dec_set(dec_set):
score = len(current_usages[dec_set]) # current_schedules[dec_set][0]
return dec_set, score
candidates = [score_dec_set(dec_set) for dec_set in current_level_set]
random.shuffle(candidates) random.shuffle(candidates)
candidates.sort(key=lambda d: d[1]) candidates.sort(key=lambda d: d[1])
for dec_set, regs in candidates[:candidate_count]: for dec_set, regs in candidates[:candidate_count]:
for dec in dec_set: for dec in dec_set:
new_dec_set = set(dec_set) new_dec_set = set(dec_set)
new_dec_set.remove(dec) new_dec_set.remove(dec)
...@@ -126,7 +124,7 @@ def schedule_eqs(eqs, candidate_count=20): ...@@ -126,7 +124,7 @@ def schedule_eqs(eqs, candidate_count=20):
for arg in atoms: for arg in atoms:
if not isinstance(arg, Field.Access): if not isinstance(arg, Field.Access):
argu = usage.get(arg, initial_usages[arg]) - 1 argu = usage.get(arg, initial_usages[arg]) - 1
if argu == 0: if argu == 0 and arg in definitions:
new_dec_set.add(arg) new_dec_set.add(arg)
usage[arg] = argu usage[arg] = argu
frozen_new_dec_set = frozenset(new_dec_set) frozen_new_dec_set = frozenset(new_dec_set)
...@@ -134,7 +132,6 @@ def schedule_eqs(eqs, candidate_count=20): ...@@ -134,7 +132,6 @@ def schedule_eqs(eqs, candidate_count=20):
max_reg_count = max(len(usage), schedule[0]) max_reg_count = max(len(usage), schedule[0])
if frozen_new_dec_set not in new_schedules or max_reg_count < new_schedules[frozen_new_dec_set][0]: if frozen_new_dec_set not in new_schedules or max_reg_count < new_schedules[frozen_new_dec_set][0]:
new_schedule = list(schedule[1]) new_schedule = list(schedule[1])
new_schedule.append(definitions[dec]) new_schedule.append(definitions[dec])
new_schedules[frozen_new_dec_set] = (max_reg_count, new_schedule) new_schedules[frozen_new_dec_set] = (max_reg_count, new_schedule)
...@@ -150,9 +147,77 @@ def schedule_eqs(eqs, candidate_count=20): ...@@ -150,9 +147,77 @@ def schedule_eqs(eqs, candidate_count=20):
level += 1 level += 1
schedule = current_schedules[frozenset()] schedule = current_schedules[frozenset()]
schedule[1].reverse() schedule[1].reverse()
return (schedule[1]) return schedule[1]
def liveness_opt_transformation(eqs): def liveness_opt_transformation(eqs):
return refuse_eqs(merge_field_accesses(schedule_eqs(eqs, 3)), 1, 3) return fuse_eqs(merge_field_accesses(schedule_eqs(eqs, 30)), 1, 3)
# ---------- Utilities -----------------------------------------------------------------------------------------
def get_usage(assignments: List[Assignment]):
    """Count number of reads for all symbols in list of assignments

    Returns:
        dictionary mapping symbol to number of its reads
    """
    reg_usage = {}
    for assignment in assignments:
        read_symbols = (a for a in assignment.rhs.atoms()
                        if isinstance(a, sp.Symbol) and not isinstance(a, Field.Access))
        for sym in read_symbols:
            reg_usage[sym] = reg_usage.get(sym, 0) + 1
    return reg_usage
def get_definitions(assignments: List[Assignment]):
    """Returns dictionary mapping symbol to its defining assignment"""
    return {assignment.lhs: assignment for assignment in assignments}
def get_roots(eqs):
    """Returns all field accesses that are used as lhs in assignment (stores)

    In case there are no such store assignments, the lhs of the last
    assignment is returned as the single root (TODO try if necessary).
    """
    roots = [eq.lhs for eq in eqs if isinstance(eq.lhs, Field.Access)]
    if not roots:
        roots.append(eqs[-1].lhs)
    return roots
# ---------- Staggered kernels -----------------------------------------------------------------------------------------
def unpack_staggered_eqs(field, expressions, subexpressions):
    """Flatten a staggered (per-dimension, per-vector-component) expression
    structure into one plain assignment list.

    The subexpressions are deep-copied first; then one store assignment per
    (dim, vec) entry is appended, targeting field index (0, 0, 0, dim, vec).
    """
    result = copy.deepcopy(subexpressions)
    for dim, dim_exprs in enumerate(expressions):
        for vec, expr in enumerate(dim_exprs):
            lhs = Field.Access(field, (0, 0, 0, dim, vec))
            result.append(Assignment(lhs, expr))
    return result
def pack_staggered_eqs(eqs, field, expressions, subexpressions):
    """Inverse of unpack_staggered_eqs: regroup a flat assignment list back
    into (field, list-of-column-matrices, subexpressions).

    Store assignments (lhs is a Field.Access) are placed into per-dimension
    sp.Matrix columns according to their last two offsets; everything else
    is kept as a subexpression.
    """
    n_vec = field.shape[-1]
    n_dim = field.shape[-2]
    flat = [0] * (n_vec * n_dim)
    for eq in eqs:
        if isinstance(eq.lhs, Field.Access):
            flat[eq.lhs.offsets[-2] * n_vec + eq.lhs.offsets[-1]] = eq.rhs
    subexpressions = [eq for eq in eqs if not isinstance(eq.lhs, Field.Access)]
    matrices = [sp.Matrix(n_vec, 1, flat[dim * n_vec:(dim + 1) * n_vec])
                for dim in range(n_dim)]
    return (field, matrices, subexpressions)
This diff is collapsed.
from pygrandchem.grandchem import StaggeredKernelParams
from pystencils.simp.liveness_opts import *
from pystencils.simp.liveness_opts_exp import *
import random
import pycuda.driver as drv
import pystencils as ps
from pystencils import show_code
from timeit import default_timer as timer
import copy
optSequenceCache = {}
all_opts = [[atomize_eqs, []], [schedule_eqs, [2]], [duplicate_trivial_ops, [3, 1]],
[merge_field_accesses, []], [refuse_eqs, [1, 1]], [var_to_shmem, [4]],
[var_to_shmem_lt, [4]]]
def mutateOptSequence(seq):
    """Return a randomly mutated deep copy of an optimization sequence.

    One of five mutations is attempted per loop iteration until one succeeds:
    append a random opt, swap two opts, remove an opt, scale an opt's integer
    argument, or double/halve one block-size dimension (subject to the
    <= 512 threads-per-block and x-dimension constraints).

    The input `seq` is never modified; all mutations apply to the copy.
    """
    changed = False
    new_seq = copy.deepcopy(seq)
    while not changed:
        choice = random.randint(0, 4)
        if choice == 0:
            # Append a randomly chosen optimization step.
            new_seq.opts.append(random.choice(all_opts))
            changed = True
        elif choice == 1:
            # Swap two (possibly identical) positions in the opt list.
            if len(new_seq.opts) > 1:
                a = random.randint(0, len(new_seq.opts) - 1)
                b = random.randint(0, len(new_seq.opts) - 1)
                new_seq.opts[a], new_seq.opts[b] = new_seq.opts[b], new_seq.opts[a]
                changed = True
        elif choice == 2:
            # Drop one randomly chosen optimization step.
            if len(new_seq.opts) > 0:
                new_seq.opts.remove(random.choice(new_seq.opts))
                changed = True
        elif choice == 3:
            # Scale one integer argument of a random opt up or down.
            if len(new_seq.opts) > 0:
                opt = random.choice(new_seq.opts)
                change = random.choice([-1, 1])
                factor = 1
                if change < 0:
                    factor = random.uniform(0.3, 1.0)
                if change > 0:
                    factor = random.uniform(1.0, 3.0)
                if len(opt[1]) > 0:
                    arg = random.randint(0, len(opt[1]) - 1)
                    opt[1][arg] = int(max(0, opt[1][arg] * factor + change))
                    changed = True
        else:
            # Double or halve one dimension of the launch block size.
            # BUG FIX: this branch previously read from and wrote back to
            # `seq` (the caller's object) instead of `new_seq`, mutating the
            # input and returning an unchanged copy.
            dim = random.randint(0, 2)
            change = random.randint(0, 1)
            newBlockSize = list(new_seq.blockSize)
            if change == 0:
                newBlockSize[dim] = min(512, newBlockSize[dim] * 2)
            else:
                newBlockSize[dim] = max(1, newBlockSize[dim] // 2)
            # Accept only if total threads <= 512 and the x dimension stays
            # at least 32 (or did not shrink).
            if newBlockSize[0] * newBlockSize[1] * newBlockSize[2] <= 512 and (
                    newBlockSize[0] >= 32 or newBlockSize[0] >= new_seq.blockSize[0]):
                new_seq.blockSize = tuple(newBlockSize)
                changed = True
    return new_seq
def evolvePopulation(pop, eqs_set, dhs, staggered_params=None):
    """One generation of the genetic search over optimization sequences.

    Mutates the best current sequences, rates all candidates via
    rateSequence (which also updates the global optSequenceCache), prints a
    scoreboard, and returns the surviving population.
    """
    # Always keep a fresh default sequence in the gene pool.
    pop.append(livenessOptSequence())
    # Candidates mutated one, two and three times, taken from the best
    # (front) part of the current population.
    once_mutated = [mutateOptSequence(seq) for seq in pop[0:6]]
    twice_mutated = [mutateOptSequence(mutateOptSequence(seq)) for seq in pop[0:4]]
    thrice_mutated = [
        mutateOptSequence(mutateOptSequence(mutateOptSequence(seq))) for seq in pop[0:3]
    ]
    # set() removes duplicates — sequences must be hashable/comparable.
    new_pop = list(set(pop + once_mutated + twice_mutated + thrice_mutated))
    scores = []
    for seq in new_pop:
        # rateSequence returns (score_list, register_counts).
        scores.append((seq, *rateSequence(seq, eqs_set, dhs, staggered_params)))
    # Re-introduce the best previously cached sequence that is not in the
    # current population (elitism against forgetting good solutions).
    old_scores = []
    for s in optSequenceCache:
        if s not in new_pop:
            # NOTE(review): `s` is iterated from optSequenceCache, so this
            # inner check can never be true — dead debug code.
            if s not in optSequenceCache:
                print("Not in optSequenceCache: ")
                print(s)
                print(hash(s))
            old_scores.append((s, optSequenceCache[s][0], [0, 0]))
    old_scores.sort(key=lambda s: sum(s[1]))
    if len(old_scores) > 0: scores.append(old_scores[0])
    print()
    # Lower summed score is better.
    scores.sort(key=lambda s: sum(s[1]))
    new_pop = []
    count_old_seqs = 0
    for score in scores:
        # Debug aid: dump state if a rated sequence is somehow missing from
        # the cache (rateSequence should always have inserted it).
        if score[0] not in optSequenceCache:
            print("Everything in scores: ")
            for s in scores:
                print(s[0])
            print("Not in optSequenceCache: ")
            print(score[0])
            print(hash(score[0]))
        survive = False
        # Keep up to 10 sequences; reserve slots so at least a few
        # well-measured ("old", rated > 3 times) sequences survive.
        if (len(new_pop) < 4 or count_old_seqs < 3) and len(new_pop) < 10:
            if optSequenceCache[score[0]][1] > 3:
                count_old_seqs += 1
            new_pop.append(score[0])
            survive = True
        # Scoreboard line: scores, register counts, measurement count,
        # '*' marks survivors.
        print("".join(["{:6.2f} ".format(sc) for sc in score[1]]) + "(" +
              "".join(["{:3d} ".format(sc) for sc in score[2]]) + "): " +
              "{:2d}".format(optSequenceCache[score[0]][1]) + (" * " if survive else " ") +
              str(score[0]))
    print()
    return new_pop
def rateSequence(seq, eqs_set, dh, staggered_params=None):
    """Benchmark an optimization sequence on all equation sets and return
    (running-average scores, register counts per kernel).

    Results are accumulated in the global optSequenceCache as
    [avg_scores, measurement_count]; after more than 10 measurements the
    cached average is returned without re-benchmarking.
    """
    if seq not in optSequenceCache:
        optSequenceCache[seq] = [[], 0]
    cache_entry = optSequenceCache[seq]
    if cache_entry[1] > 10:
        # Well-measured sequence: trust the cached average.
        return (cache_entry[0], [0, 0])
    print(cache_entry[1], end=" ")
    print(seq)
    # NOTE(review): start/end time the transformation step but the elapsed
    # value is never used — leftover instrumentation.
    start = timer()
    transformed_eqs_set = [seq.applyOpts(eqs) for eqs in eqs_set]
    end = timer()
    kernel_results = [
        bench_kernel(eqs, dh, seq.blockSize, staggered_params) for eqs in transformed_eqs_set
    ]
    kernel_registers = [k[1] for k in kernel_results]
    # Score = raw runtimes plus a penalty that grows with sequence length
    # beyond 3 opts (discourages needlessly long sequences).
    result = [k[0] for k in kernel_results
              ] + [k[0] * max(0.0, (len(seq.opts) - 3) * 0.1) for k in kernel_results]
    if cache_entry[1] == 0:
        cache_entry[0] = result
    else:
        # Incremental running average over all measurements so far.
        for i in range(0, len(result)):
            cache_entry[0][i] = (cache_entry[0][i] * cache_entry[1] + result[i]) / (
                cache_entry[1] + 1)
    cache_entry[1] += 1
    return cache_entry[0], kernel_registers
def bench_kernel(eqs, dh, blockSize=(64, 2, 1), staggered_params=None):
    """Compile the assignments into a GPU kernel, run it twice and time it.

    Returns:
        tuple (average runtime in ms over two runs, register count of kernel)
    """
    indexing_params = {"block_size": blockSize}
    if staggered_params is None:
        ast = ps.create_kernel(eqs, target="gpu", gpu_indexing_params=indexing_params)
    else:
        ast = ps.create_staggered_kernel(
            *pack_staggered_eqs(eqs, *staggered_params),
            target="gpu",
            gpu_indexing_params=indexing_params)
    kernel = ast.compile()

    start_evt = drv.Event()
    end_evt = drv.Event()
    start_evt.record()
    for _ in range(2):
        dh.run_kernel(kernel, timestep=1)
    end_evt.record()
    end_evt.synchronize()
    msec = start_evt.time_till(end_evt) / 2
    return msec, kernel.num_regs
# coding: utf-8
# In[32]:
import pickle
import warnings
import pystencils as ps
from pygrandchem.grandchem import GrandChemGenerator
from pygrandchem.scenarios import system_4_2, system_3_1
from pygrandchem.initialization import init_boxes, smooth_fields
from pygrandchem.scenarios import benchmark_configs
from sympy import Number, Symbol, Expr, preorder_traversal, postorder_traversal, Function, Piecewise, relational
from pystencils.simp import sympy_cse_on_assignment_list
from pystencils.simp.liveness_opts import *
from pystencils.simp.liveness_opts_exp import *
from pystencils.simp.liveness_permutations import *
import pycuda
import sys
from subprocess import run, PIPE
from pystencils import show_code
import pycuda.driver as drv
import importlib
configs = benchmark_configs()
def get_config(name):
    """Look up a benchmark configuration by name; raises KeyError if unknown."""
    return configs[name]
domain_size = (512, 512, 128)
periodicity = (True, True, False)
optimization = {'gpu_indexing_params': {"block_size": (32, 4, 2)}}
#bestSeqs = pickle.load(open('best_seq.pickle', 'rb'))
scenarios = ["42_varT_freeEnergy", "31_varT_aniso_rot"]
kernel_types = ["phi_full", "phi_partial1", "phi_partial2", "mu_full", "mu_partial1", "mu_partial2"]