From 940915cf3a1f3a338224f3fcb11579c22f8349f2 Mon Sep 17 00:00:00 2001
From: Dominik Ernst <dominik.ernst@fau.de>
Date: Wed, 31 Oct 2018 13:38:51 +0100
Subject: [PATCH] GPU Liveness Optimization

Transformations on the Sympy equation level, that aim to reduce the amount of registers the Compiler needs. Primarily aimed at GPUs. The function livenessOptTransformation is a sequence of three sub transformations, which has been determined by a genetic optimization algorithm.
The file test_liveness_opts.py is a simple code that performs the transformation and tests whether the resulting equations are still generatable and compilable. Numerical results are not checked.
---
 simp/liveness_opts.py | 158 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 158 insertions(+)
 create mode 100644 simp/liveness_opts.py

diff --git a/simp/liveness_opts.py b/simp/liveness_opts.py
new file mode 100644
index 000000000..3bee292de
--- /dev/null
+++ b/simp/liveness_opts.py
@@ -0,0 +1,158 @@
+from sympy import Symbol, Dummy
+
+from pystencils import Field, Assignment
+
+import random
+import copy
+
+
+def get_usage(atoms):
+    reg_usage = {}
+    for atom in atoms:
+        reg_usage[atom.lhs] = 0
+    for atom in atoms:
+        for arg in atom.rhs.atoms():
+            if isinstance(arg, Symbol) and not isinstance(arg, Field.Access):
+                if arg in reg_usage:
+                    reg_usage[arg] += 1
+                else:
+                    print(str(arg) + " is unsatisfied")
+    return reg_usage
+
+
+def get_definitions(eqs):
+    definitions = {}
+    for eq in eqs:
+        definitions[eq.lhs] = eq
+    return definitions
+
+
+def get_roots(eqs):
+    roots = []
+    for eq in eqs:
+        if isinstance(eq.lhs, Field.Access):
+            roots.append(eq.lhs)
+    if not roots:
+        roots.append(eqs[-1].lhs)
+    return roots
+
+
+def merge_field_accesses(eqs):
+    field_accesses = {}
+
+    for eq in eqs:
+        for arg in eq.rhs.atoms():
+            if isinstance(arg, Field.Access) and arg not in field_accesses:
+                field_accesses[arg] = Dummy()
+
+    for i in range(0, len(eqs)):
+        for f, s in field_accesses.items():
+            if f in eqs[i].atoms():
+                eqs[i] = eqs[i].subs(f, s)
+
+    for f, s in field_accesses.items():
+        eqs.insert(0, Assignment(s, f))
+
+    return eqs
+
+
+def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
+    eqs = copy.copy(input_eqs)
+    usages = get_usage(eqs)
+    definitions = get_definitions(eqs)
+
+    def inline_trivially_schedulable(sym, depth):
+
+        if sym not in usages or usages[sym] > max_usage or depth > max_depth:
+            return sym
+
+        rhs = definitions[sym].rhs
+        if len(rhs.args) == 0:
+            return rhs
+
+        return rhs.func(*[inline_trivially_schedulable(arg, depth + 1) for arg in rhs.args])
+
+    for idx, eq in enumerate(eqs):
+        if usages[eq.lhs] > 1 or isinstance(eq.lhs, Field.Access):
+            if not isinstance(eq.rhs, Symbol):
+
+                eqs[idx] = Assignment(eq.lhs,
+                                      eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args]))
+
+    count = 0
+    while (len(eqs) != count):
+        count = len(eqs)
+        usages = get_usage(eqs)
+        eqs = [eq for eq in eqs if usages[eq.lhs] > 0 or isinstance(eq.lhs, Field.Access)]
+
+    return eqs
+
+
+def schedule_eqs(eqs, candidate_count=20):
+    if candidate_count == 0:
+        return eqs
+
+    definitions = get_definitions(eqs)
+    definition_atoms = {}
+    for sym, definition in definitions.items():
+        definition_atoms[sym] = list(definition.rhs.atoms(Symbol))
+    roots = get_roots(eqs)
+    initial_usages = get_usage(eqs)
+
+    level = 0
+    current_level_set = set([frozenset(roots)])
+    current_usages = {frozenset(roots): {u: 0 for u in roots}}
+    current_schedules = {frozenset(roots): (0, [])}
+    max_regs = 0
+    while len(current_level_set) > 0:
+        new_usages = dict()
+        new_schedules = dict()
+        new_level_set = set()
+
+        min_regs = min([len(current_usages[dec_set]) for dec_set in current_level_set])
+        max_regs = max(max_regs, min_regs)
+        candidates = [(dec_set, len(current_usages[dec_set])) for dec_set in current_level_set]
+
+        random.shuffle(candidates)
+        candidates.sort(key=lambda d: d[1])
+
+        for dec_set, regs in candidates[:candidate_count]:
+            for dec in dec_set:
+                new_dec_set = set(dec_set)
+                new_dec_set.remove(dec)
+                usage = dict(current_usages[dec_set])
+                usage.pop(dec)
+                atoms = definition_atoms[dec]
+                for arg in atoms:
+                    if not isinstance(arg, Field.Access):
+                        argu = usage.get(arg, initial_usages[arg]) - 1
+                        if argu == 0:
+                            new_dec_set.add(arg)
+                        usage[arg] = argu
+                frozen_new_dec_set = frozenset(new_dec_set)
+                schedule = current_schedules[dec_set]
+                max_reg_count = max(len(usage), schedule[0])
+
+                if frozen_new_dec_set not in new_schedules or max_reg_count < new_schedules[frozen_new_dec_set][0]:
+
+                    new_schedule = list(schedule[1])
+                    new_schedule.append(definitions[dec])
+                    new_schedules[frozen_new_dec_set] = (max_reg_count, new_schedule)
+
+                if len(frozen_new_dec_set) > 0:
+                    new_level_set.add(frozen_new_dec_set)
+                new_usages[frozen_new_dec_set] = usage
+
+        current_schedules = new_schedules
+        current_usages = new_usages
+        current_level_set = new_level_set
+
+        level += 1
+
+    schedule = current_schedules[frozenset()]
+    schedule[1].reverse()
+    return (schedule[1])
+
+
+def liveness_opt_transformation(eqs):
+    return refuse_eqs(merge_field_accesses(schedule_eqs(eqs, 3)), 1, 3)
-- 
GitLab