Commit b8b92cdf authored by Martin Bauer's avatar Martin Bauer
Browse files

GPU liveness optimization to reduce registers

parent 7511f364
......@@ -164,6 +164,13 @@ class CBackend:
return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))
def _print_SympyAssignment(self, node):
if self._dialect == 'cuda' and isinstance(node.lhs, sp.Symbol) and node.lhs.name.startswith("shmemslot"):
result = "__shared__ volatile double %s[512]; %s[threadIdx.z * " \
"blockDim.x*blockDim.y + threadIdx.y * " \
"blockDim.x + threadIdx.x] = %s;" % \
(node.lhs.name, node.lhs.name, self.sympy_printer.doprint(node.rhs))
return result
if node.is_declaration:
data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
......@@ -254,6 +261,12 @@ class CustomSympyPrinter(CCodePrinter):
res = str(expr.evalf().num)
return res
def _print_Symbol(self, expr):
if self._dialect == 'cuda' and expr.name.startswith("shmemslot"):
return expr.name + "[threadIdx.z * blockDim.x*blockDim.y + threadIdx.y * blockDim.x + threadIdx.x]"
else:
return super(CustomSympyPrinter, self)._print_Symbol(expr)
def _print_Equality(self, expr):
"""Equality operator is not printable in default printer"""
return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'
......
from sympy import Symbol, Dummy
from pystencils import Field, Assignment
import sympy as sp
import random
import copy
from typing import List
from pystencils import Field, Assignment
def get_usage(atoms):
reg_usage = {}
for atom in atoms:
reg_usage[atom.lhs] = 0
for atom in atoms:
for arg in atom.rhs.atoms():
if isinstance(arg, Symbol) and not isinstance(arg, Field.Access):
if arg in reg_usage:
reg_usage[arg] += 1
else:
print(str(arg) + " is unsatisfied")
return reg_usage
def get_definitions(eqs):
definitions = {}
for eq in eqs:
definitions[eq.lhs] = eq
return definitions
fa_symbol_iter = sp.numbered_symbols("fa_")
def get_roots(eqs):
roots = []
for eq in eqs:
if isinstance(eq.lhs, Field.Access):
roots.append(eq.lhs)
if not roots:
roots.append(eqs[-1].lhs)
return roots
def merge_field_accesses(eqs):
def merge_field_accesses(assignments):
"""Transformation that introduces symbols for all read field accesses
for multiple read accesses only one symbol is introduced"""
field_accesses = {}
for eq in eqs:
for arg in eq.rhs.atoms():
new_eqs = copy.copy(assignments)
for assignment in new_eqs:
for arg in assignment.rhs.atoms():
if isinstance(arg, Field.Access) and arg not in field_accesses:
field_accesses[arg] = Dummy()
field_accesses[arg] = next(fa_symbol_iter)
for i in range(0, len(eqs)):
for i in range(0, len(new_eqs)):
for f, s in field_accesses.items():
if f in eqs[i].atoms():
eqs[i] = eqs[i].subs(f, s)
if f in new_eqs[i].atoms():
new_eqs[i] = new_eqs[i].subs(f, s)
for f, s in field_accesses.items():
eqs.insert(0, Assignment(s, f))
new_eqs.insert(0, Assignment(s, f))
return new_eqs
return eqs
def fuse_eqs(input_eqs, max_depth=1, max_usage=1):
"""Inserts subexpressions that are used not more than `max_usage`
def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
Args:
max_depth: complexity metric for the subexpression to insert
if max_depth is larger than the expression tree of the subexpression
the subexpressions is not inserted
Somewhat the inverse of common subexpression elimination.
"""
eqs = copy.copy(input_eqs)
usages = get_usage(eqs)
definitions = get_definitions(eqs)
def inline_trivially_schedulable(sym, depth):
if sym not in usages or usages[sym] > max_usage or depth > max_depth:
if sym not in definitions or sym not in usages or usages[sym] > max_usage or depth > max_depth:
return sym
rhs = definitions[sym].rhs
......@@ -74,13 +56,13 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
for idx, eq in enumerate(eqs):
if usages[eq.lhs] > 1 or isinstance(eq.lhs, Field.Access):
if not isinstance(eq.rhs, Symbol):
eqs[idx] = Assignment(eq.lhs,
eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args]))
if not isinstance(eq.rhs, sp.Symbol):
eqs[idx] = Assignment(
eq.lhs,
eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args]))
count = 0
while (len(eqs) != count):
while len(eqs) != count:
count = len(eqs)
usages = get_usage(eqs)
eqs = [eq for eq in eqs if usages[eq.lhs] > 0 or isinstance(eq.lhs, Field.Access)]
......@@ -88,16 +70,26 @@ def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
return eqs
def schedule_eqs(eqs, candidate_count=20):
def schedule_eqs(assignments: List[Assignment], candidate_count=20):
"""Changes order of assignments to save registers.
Args:
assignments:
candidate_count: tuning parameter, small means fast, but bad scheduling quality
1 corresponds to full greedy search
Returns:
list of re-ordered assignments
"""
if candidate_count == 0:
return eqs
return assignments
definitions = get_definitions(eqs)
definitions = get_definitions(assignments)
definition_atoms = {}
for sym, definition in definitions.items():
definition_atoms[sym] = list(definition.rhs.atoms(Symbol))
roots = get_roots(eqs)
initial_usages = get_usage(eqs)
definition_atoms[sym] = list(definition.rhs.atoms(sp.Symbol))
roots = get_roots(assignments)
initial_usages = get_usage(assignments)
level = 0
current_level_set = set([frozenset(roots)])
......@@ -111,12 +103,18 @@ def schedule_eqs(eqs, candidate_count=20):
min_regs = min([len(current_usages[dec_set]) for dec_set in current_level_set])
max_regs = max(max_regs, min_regs)
candidates = [(dec_set, len(current_usages[dec_set])) for dec_set in current_level_set]
def score_dec_set(dec_set):
score = len(current_usages[dec_set]) # current_schedules[dec_set][0]
return dec_set, score
candidates = [score_dec_set(dec_set) for dec_set in current_level_set]
random.shuffle(candidates)
candidates.sort(key=lambda d: d[1])
for dec_set, regs in candidates[:candidate_count]:
for dec in dec_set:
new_dec_set = set(dec_set)
new_dec_set.remove(dec)
......@@ -126,7 +124,7 @@ def schedule_eqs(eqs, candidate_count=20):
for arg in atoms:
if not isinstance(arg, Field.Access):
argu = usage.get(arg, initial_usages[arg]) - 1
if argu == 0:
if argu == 0 and arg in definitions:
new_dec_set.add(arg)
usage[arg] = argu
frozen_new_dec_set = frozenset(new_dec_set)
......@@ -134,7 +132,6 @@ def schedule_eqs(eqs, candidate_count=20):
max_reg_count = max(len(usage), schedule[0])
if frozen_new_dec_set not in new_schedules or max_reg_count < new_schedules[frozen_new_dec_set][0]:
new_schedule = list(schedule[1])
new_schedule.append(definitions[dec])
new_schedules[frozen_new_dec_set] = (max_reg_count, new_schedule)
......@@ -150,9 +147,77 @@ def schedule_eqs(eqs, candidate_count=20):
level += 1
schedule = current_schedules[frozenset()]
schedule[1].reverse()
return (schedule[1])
return schedule[1]
def liveness_opt_transformation(eqs):
return refuse_eqs(merge_field_accesses(schedule_eqs(eqs, 3)), 1, 3)
return fuse_eqs(merge_field_accesses(schedule_eqs(eqs, 30)), 1, 3)
# ---------- Utilities -----------------------------------------------------------------------------------------
def get_usage(assignments: List[Assignment]):
"""Count number of reads for all symbols in list of assignments
Returns:
dictionary mapping symbol to number of its reads
"""
reg_usage = {}
for assignment in assignments:
for arg in assignment.rhs.atoms():
if isinstance(arg, sp.Symbol) and not isinstance(arg, Field.Access):
if arg in reg_usage:
reg_usage[arg] += 1
else:
reg_usage[arg] = 1
return reg_usage
def get_definitions(assignments: List[Assignment]):
"""Returns dictionary mapping symbol to its defining assignment"""
definitions = {}
for assignment in assignments:
definitions[assignment.lhs] = assignment
return definitions
def get_roots(eqs):
"""Returns all field accesses that are used as lhs in assignment (stores)
In case there are no independent assignments, the last one is returned (TODO try if necessary)
"""
roots = []
for eq in eqs:
if isinstance(eq.lhs, Field.Access):
roots.append(eq.lhs)
if not roots:
roots.append(eqs[-1].lhs)
return roots
# ---------- Staggered kernels -----------------------------------------------------------------------------------------
def unpack_staggered_eqs(field, expressions, subexpressions):
eqs = copy.deepcopy(subexpressions)
for dim in range(0, len(expressions)):
for vec in range(0, len(expressions[dim])):
eqs.append(Assignment(Field.Access(field, (0, 0, 0, dim, vec)), expressions[dim][vec]))
return eqs
def pack_staggered_eqs(eqs, field, expressions, subexpressions):
new_matrix_list = [0] * (field.shape[-1] * field.shape[-2])
for eq in eqs:
if isinstance(eq.lhs, Field.Access):
new_matrix_list[eq.lhs.offsets[-2] * field.shape[-1] + eq.lhs.offsets[-1]] = eq.rhs
subexpressions = [eq for eq in eqs if not isinstance(eq.lhs, Field.Access)]
return (field, [
sp.Matrix(field.shape[-1], 1,
new_matrix_list[dim * field.shape[-1]:(dim + 1) * field.shape[-1]])
for dim in range(field.shape[-2])
], subexpressions)
import sympy
import itertools
from sympy import Symbol, Piecewise, Number, postorder_traversal, numbered_symbols
from pystencils.simp.liveness_opts import *
atom_symbol_iter = numbered_symbols("atom_")
def three_operand_form(assignments):
"""Transforms list of assignments in three operand form"""
def atomize(expr, atoms):
if len(expr.args) == 0:
return expr
atom = next(atom_symbol_iter)
if len(expr.args) == 1:
atoms.append(Assignment(atom, expr.func(atomize(expr.args[0], atoms))))
return atom
if isinstance(expr, Piecewise):
atoms.append(
Assignment(
atom,
Piecewise(*[(atomize(expr.expr, atoms), expr.cond) for expr in expr.args])))
return atom
atoms.append(Assignment(atom, expr.func(atomize(expr.args[0], atoms), atomize(expr.args[1], atoms))))
current_atom = atom
for i in range(2, len(expr.args)):
atom = next(atom_symbol_iter)
atoms.append(Assignment(atom, expr.func(atomize(expr.args[i], atoms), current_atom)))
current_atom = atom
return current_atom
atoms = []
for eq in assignments:
new_atoms = []
atomize(eq.rhs, new_atoms)
if len(new_atoms) > 0:
new_atoms[-1] = Assignment(eq.lhs, new_atoms[-1].rhs)
else:
new_atoms.append(eq)
atoms.extend(new_atoms)
return atoms
def var_to_shmem(eqs, var_count=8):
if var_count > 8:
return eqs
if var_count == 0:
return copy.copy(eqs)
for eq in eqs:
if eq.lhs.name.startswith("shmemslot"):
return eqs
usage = get_usage(eqs)
usage_list = [(s, usage[s]) for s in usage]
usage_list.sort(key=lambda s: -s[1])
vars = [Symbol("shmemslot" + str(i)) for i in range(0, var_count)]
shmem_eqs = []
for idx, eq in enumerate(eqs):
shmem_eqs.append(eq.subs([(usage_list[i][0], vars[i]) for i in range(0, var_count)]))
return shmem_eqs
def shift_fa_eqs(eqs, direction=1):
def shift_fa(expr, direction):
if isinstance(expr, Field.Access):
return expr.neighbor(0, direction)
if len(expr.args) == 0:
return expr
else:
return expr.func(*[shift_fa(arg, direction) for arg in expr.args])
new_eqs = []
for eq in eqs:
new_eqs.append(shift_fa(eq, direction))
return new_eqs
def get_steal_list(eqs, shifted_eqs):
def is_equal_arg(left_arg, right_arg, steal_list, left_def, right_def, verbose=False):
if verbose: print("is_equal_arg: IN left_arg " + str(left_arg))
if verbose: print("is_equal_arg: IN right_arg " + str(right_arg))
if verbose: print("is_equal_arg: SUB left_arg " + str(left_arg))
if verbose: print("is_equal_arg: SUB left_arg " + str(right_arg))
if isinstance(left_arg, Number): return left_arg == right_arg
if isinstance(left_arg, Field.Access): return left_arg == right_arg
if left_arg not in steal_list: return False
if verbose: print("is_equal_arg: stolen" + str(steal_list[left_arg]))
return steal_list[left_arg] == right_arg
def is_equal_expr(left_expr, right_expr, steal_list, left_def, right_def, verbose=False):
# print(str(left_expr) + " =?= " + str(right_expr))
if type(left_expr) != type(right_expr): return False
if left_expr.func != right_expr.func or len(left_expr.args) != len(right_expr.args):
return False
if len(left_expr.args) == 0:
return is_equal_arg(left_expr, right_expr, steal_list, left_def, right_def, verbose)
for left_arg_perm in itertools.permutations(left_expr.args):
equal_args = True
for idx, left_arg in enumerate(left_arg_perm):
if not is_equal_arg(left_arg, right_expr.args[idx], steal_list, left_def, right_def,
verbose):
equal_args = False
break
if equal_args: return True
return False
steal_from_e = {}
left_def = get_definitions(eqs)
right_def = get_definitions(shifted_eqs)
for lidx, asgn_left in enumerate(eqs):
verbose = False
for left_subexpr in sympy.postorder_traversal(asgn_left.rhs):
if isinstance(left_subexpr, sympy.Number) or isinstance(
left_subexpr, Field.Access) or isinstance(left_subexpr, Assignment):
continue
for ridx, asgn_right in enumerate(shifted_eqs):
for right_subexpr in sympy.postorder_traversal(asgn_right.rhs):
left_arg = left_subexpr
right_arg = right_subexpr
if isinstance(left_subexpr,
Symbol) and not isinstance(left_subexpr, Field.Access):
left_arg = left_def[left_subexpr].rhs
if isinstance(right_subexpr,
Symbol) and not isinstance(right_subexpr, Field.Access):
right_arg = right_def[right_subexpr].rhs
if is_equal_expr(left_arg, right_arg, steal_from_e, left_def, right_def,
verbose):
steal_from_e[left_subexpr] = right_subexpr
# if verbose:
print(str(left_subexpr) + " == " + str(right_subexpr))
return steal_from_e
def find_symbol(eqs, name):
for eq in eqs:
if eq.lhs.name == name: return eq.lhs
def find_expr(eqs, expr):
for idx, eq in enumerate(eqs):
for sub_expr in postorder_traversal(eq):
if sub_expr == expr:
return (idx, sub_expr, eq)
def left_steal(eqs, steal_count=2):
shifted_eqs = shift_fa_eqs(eqs)
steal_from_e = get_steal_list(eqs, shifted_eqs)
usage = get_usage(eqs)
definitions = get_definitions(eqs)
def count_nodes_up(node):
if isinstance(node, Field.Access):
return 1
if node in definitions:
node = definitions[node].rhs
node_count = 0
for arg in node.args:
if not (arg in usage and usage[arg] > 1):
node_count += count_nodes_up(arg)
return node_count + 1
new_eqs = copy.copy(eqs)
for i in range(0, steal_count):
scores = [(s, count_nodes_up(s)) for s in steal_from_e if isinstance(s, Symbol)]
scores.sort(key=lambda s: s[1], reverse=True)
print(scores[0:10])
sym_xi = scores[0][0]
print(sym_xi)
print(steal_from_e[sym_xi])
steal_src = find_expr(new_eqs, shift_fa_eqs([steal_from_e[sym_xi]], -1)[0])
shmem_var = Symbol("shmemslot" + str(i))
new_eqs.insert(steal_src[0] + 1, Assignment(shmem_var, steal_src[1]))
steal_dst = find_expr(new_eqs, sym_xi)
print(steal_dst)
print(steal_src)
print()
for idx, eq in enumerate(new_eqs):
if steal_dst[1] in eq.atoms():
new_eqs[idx] = Assignment(new_eqs[idx].lhs, new_eqs[idx].rhs.subs(
steal_dst[1], shmem_var))
new_eqs.pop(steal_dst[0])
# Ancestors of donated value cannot be stolen, therefore remove from steal list
def get_ancestor_nodes(node, definitions):
ancestors = [node]
if isinstance(node, sympy.Number):
return []
if node in definitions:
ancestors.extend(get_ancestor_nodes(definitions[node].rhs, definitions))
for arg in node.args:
ancestors.extend(get_ancestor_nodes(arg, definitions))
return ancestors
for a in get_ancestor_nodes(steal_from_e[sym_xi], definitions):
if a in steal_from_e:
steal_from_e.pop(a)
# Remove value just stolen from steal list
steal_from_e.pop(sym_xi)
return new_eqs
# eqs = atomize_eqs(eqs)
def move_forward(atoms):
reg_usage = get_usage(atoms)
i = 0
while i < len(atoms):
atom = atoms[i]
killed_regs = 0
for arg in atom.rhs.atoms():
if isinstance(arg, Field.Access) or not isinstance(arg, Symbol):
continue
reg_usage[arg] -= 1
if reg_usage[arg] == 0:
killed_regs += 1
if killed_regs == 0:
first_usage = i
for n in range(i, len(atoms)) or len(
[x for x in atoms[n].rhs.atoms() if x in atoms[i].rhs.atoms()]) != 0:
usage = atoms[n].rhs
if atom.lhs in usage.atoms():
first_usage = n
break
if first_usage - i > 5:
atoms.insert(first_usage - 1, atoms.pop(i))
for arg in atom.rhs.atoms():
if isinstance(arg, Field.Access) or not isinstance(arg, Symbol):
continue
reg_usage[arg] += 1
# print("_move " + str(i) + " " + str(first_usage) + " " +
# str(atom))
i -= 1
i += 1
return atoms
def move_backward(atoms):
reg_usage = get_usage(atoms)
i = 0
while i < len(atoms):
atom = atoms[i]
killed_regs = 0
for arg in atom.rhs.atoms():
if isinstance(arg, Field.Access) or not isinstance(arg, Symbol):
continue
reg_usage[arg] -= 1
if reg_usage[arg] == 0:
killed_regs += 1
if killed_regs > 1:
last_defined = 0
for n in range(i - 1, 0, -1):
if len([x for x in atoms[n].rhs.atoms() if x in atoms[i].rhs.atoms()
]) != 0 or atoms[n].lhs in atom.rhs.atoms():
last_defined = n
break
if i - last_defined > 5:
atoms.insert(last_defined + 1, atoms.pop(i))
# print("_move " + str(i) + " " + str(last_defined) + " " +
# str(atom) + " " + str(atoms[last_defined]))