liveness_opts.py 4.79 KB
Newer Older
Dominik Ernst's avatar
Dominik Ernst committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from sympy import Symbol, Dummy

from pystencils import Field, Assignment

import random
import copy


def get_usage(atoms):
    """Count how many assignments read each assigned symbol.

    Args:
        atoms: re-iterable collection of assignment-like objects exposing
            ``lhs`` and ``rhs`` (it is traversed twice).

    Returns:
        dict mapping every ``lhs`` to the number of assignments whose
        ``rhs.atoms()`` contains it. Field accesses on a right-hand side are
        never counted. A left-hand side that is never read maps to 0.

    Side effects:
        Prints ``"<symbol> is unsatisfied"`` for every plain symbol that is
        read but not assigned by any entry of ``atoms``.
    """
    # Every assigned name gets a counter up front, so unused ones report 0.
    counts = {assignment.lhs: 0 for assignment in atoms}

    for assignment in atoms:
        for operand in assignment.rhs.atoms():
            # Only plain symbols are tracked; field accesses are skipped.
            if not isinstance(operand, Symbol) or isinstance(operand, Field.Access):
                continue
            if operand not in counts:
                print(str(operand) + " is unsatisfied")
                continue
            counts[operand] += 1

    return counts


def get_definitions(eqs):
    """Index assignments by the symbol they assign.

    Args:
        eqs: iterable of assignment-like objects exposing ``lhs``.

    Returns:
        dict mapping each ``lhs`` to its assignment. If the same ``lhs``
        occurs more than once, the last assignment wins.
    """
    return {assignment.lhs: assignment for assignment in eqs}


def get_roots(eqs):
    """Return the left-hand sides that act as outputs of the assignment list.

    Every assignment whose ``lhs`` is a ``Field.Access`` contributes its
    ``lhs``. If there is no such assignment, the ``lhs`` of the last
    assignment is used as the single root instead.

    Args:
        eqs: sequence of assignment-like objects exposing ``lhs``.

    Returns:
        list of root left-hand sides in the order they appear in ``eqs``;
        an empty list when ``eqs`` is empty.
    """
    roots = [eq.lhs for eq in eqs if isinstance(eq.lhs, Field.Access)]
    # Fall back to the final assignment only if there is one; indexing
    # eqs[-1] on an empty sequence would raise IndexError.
    if not roots and eqs:
        roots.append(eqs[-1].lhs)
    return roots


def merge_field_accesses(eqs):
    """Route every field read through a temporary assigned at the list head.

    Each distinct ``Field.Access`` found among the right-hand-side atoms gets
    its own fresh ``Dummy``. All assignments containing that access are
    rewritten to use the dummy, and one ``Assignment(dummy, access)`` per
    access is placed at the front of the list.

    Args:
        eqs: list of assignments. It is modified in place.

    Returns:
        The same list object that was passed in.
    """
    # One temporary per distinct field access seen on any right-hand side.
    temporaries = {}
    for eq in eqs:
        for atom in eq.rhs.atoms():
            if isinstance(atom, Field.Access) and atom not in temporaries:
                temporaries[atom] = Dummy()

    # NOTE(review): the substitution is applied to the whole assignment, not
    # just its rhs, so an access that is also written on a left-hand side
    # would presumably be replaced there as well - confirm this is intended.
    for position, eq in enumerate(eqs):
        rewritten = eq
        for access, temp in temporaries.items():
            if access in rewritten.atoms():
                rewritten = rewritten.subs(access, temp)
        eqs[position] = rewritten

    # Repeated insertion at index 0 leaves the loads in reverse discovery
    # order; build them in discovery order, flip, and splice in one go.
    loads = [Assignment(temp, access) for access, temp in temporaries.items()]
    loads.reverse()
    eqs[:0] = loads

    return eqs


def refuse_eqs(input_eqs, max_depth=0, max_usage=1):
    """Inline rarely used temporaries back into their users and drop dead ones.

    Only assignments whose ``lhs`` is used more than once, or whose ``lhs``
    is a ``Field.Access``, are rewritten. In such an assignment each direct
    operand of the rhs that is a defined symbol with at most ``max_usage``
    uses is replaced by its own definition. This recurses into the operands
    of the inlined definition while the depth does not exceed ``max_depth``.
    Afterwards, assignments whose ``lhs`` is no longer read (and is not a
    ``Field.Access``) are removed until the list stops shrinking.

    Args:
        input_eqs: assignments to process; a shallow copy is worked on, so
            the container passed in is left unchanged.
        max_depth: maximum inlining depth. Direct operands of a rewritten rhs
            are depth 0, so ``0`` inlines exactly one level.
        max_usage: a symbol is inlined only if it is used at most this often.

    Returns:
        A new list with the rewritten, still-needed assignments.
    """
    eqs = copy.copy(input_eqs)
    usages = get_usage(eqs)
    definitions = get_definitions(eqs)

    def inline_trivially_schedulable(sym, depth):
        # Leave the operand alone when it is not a defined symbol, is used too
        # often, or the depth budget is exhausted.
        if sym not in usages or usages[sym] > max_usage or depth > max_depth:
            return sym

        rhs = definitions[sym].rhs
        # Atomic definitions (no args) cannot be rebuilt via func(*args).
        if not rhs.args:
            return rhs

        return rhs.func(*[inline_trivially_schedulable(arg, depth + 1) for arg in rhs.args])

    for idx, eq in enumerate(eqs):
        if usages[eq.lhs] > 1 or isinstance(eq.lhs, Field.Access):
            # Skip every atomic rhs, not only symbols: numbers also have empty
            # args, and rebuilding them as rhs.func() with no arguments fails.
            if eq.rhs.args:
                eqs[idx] = Assignment(eq.lhs,
                                      eq.rhs.func(*[inline_trivially_schedulable(arg, 0) for arg in eq.rhs.args]))

    # Removing one dead assignment can make its operands dead too, so repeat
    # until a pass removes nothing.
    count = 0
    while len(eqs) != count:
        count = len(eqs)
        usages = get_usage(eqs)
        eqs = [eq for eq in eqs if usages[eq.lhs] > 0 or isinstance(eq.lhs, Field.Access)]

    return eqs


def schedule_eqs(eqs, candidate_count=20):
    """Reorder assignments, searching for an order with a small usage table.

    The search runs backwards from the roots reported by ``get_roots``. A
    search state is the frozenset of symbols that may be emitted next: the
    roots, plus every symbol whose readers have all been emitted already.
    Each state carries a usage table (remaining read count per symbol that
    is referenced but not yet emitted) and the best partial schedule found
    for it. The size of the usage table is the cost being minimised
    (presumably an estimate of register pressure, going by the name
    ``max_reg_count``). Per level only the ``candidate_count`` states with
    the smallest usage table are expanded; ties are broken randomly, so the
    result can differ between runs.

    Args:
        eqs: sequence of assignments.
        candidate_count: number of states expanded per level; ``0`` disables
            the search.

    Returns:
        ``eqs`` itself when the search is disabled or ``eqs`` is empty.
        Otherwise a new list with the assignments reachable from the roots,
        in forward (definition before use) order.
    """
    # Nothing to search. Without the emptiness check an empty input would
    # fail while determining the roots.
    if candidate_count == 0 or not eqs:
        return eqs

    definitions = get_definitions(eqs)
    definition_atoms = {sym: list(definition.rhs.atoms(Symbol))
                        for sym, definition in definitions.items()}
    roots = get_roots(eqs)
    initial_usages = get_usage(eqs)

    start = frozenset(roots)
    current_level_set = {start}
    current_usages = {start: {u: 0 for u in roots}}
    # state -> (largest usage-table size seen so far, assignments emitted so far)
    current_schedules = {start: (0, [])}

    while current_level_set:
        new_usages = {}
        new_schedules = {}
        new_level_set = set()

        candidates = [(dec_set, len(current_usages[dec_set])) for dec_set in current_level_set]

        # Shuffle before the stable sort so equally cheap states are ranked
        # in random order.
        random.shuffle(candidates)
        candidates.sort(key=lambda d: d[1])

        for dec_set, _ in candidates[:candidate_count]:
            schedule = current_schedules[dec_set]
            for dec in dec_set:
                # Emit `dec`: it leaves both the ready set and the usage table.
                new_dec_set = set(dec_set)
                new_dec_set.remove(dec)
                usage = dict(current_usages[dec_set])
                usage.pop(dec)

                for arg in definition_atoms[dec]:
                    # Field reads are not scheduled. Symbols without a
                    # defining assignment have nothing to schedule either,
                    # and looking them up in initial_usages would raise
                    # KeyError.
                    if isinstance(arg, Field.Access) or arg not in initial_usages:
                        continue
                    argu = usage.get(arg, initial_usages[arg]) - 1
                    if argu == 0:
                        # All readers of `arg` are emitted; it becomes ready.
                        new_dec_set.add(arg)
                    usage[arg] = argu

                frozen_new_dec_set = frozenset(new_dec_set)
                max_reg_count = max(len(usage), schedule[0])

                # Keep only the cheapest way of reaching a given state.
                if frozen_new_dec_set not in new_schedules or max_reg_count < new_schedules[frozen_new_dec_set][0]:
                    new_schedule = list(schedule[1])
                    new_schedule.append(definitions[dec])
                    new_schedules[frozen_new_dec_set] = (max_reg_count, new_schedule)

                if len(frozen_new_dec_set) > 0:
                    new_level_set.add(frozen_new_dec_set)
                new_usages[frozen_new_dec_set] = usage

        current_schedules = new_schedules
        current_usages = new_usages
        current_level_set = new_level_set

    # The empty state means everything reachable has been emitted. The list
    # was built from the roots backwards, so flip it into forward order.
    ordered = current_schedules[frozenset()][1]
    ordered.reverse()
    return ordered


def liveness_opt_transformation(eqs, *, candidate_count=3, max_depth=1, max_usage=3):
    """Run the full pipeline: schedule, merge field reads, then re-fuse.

    Args:
        eqs: sequence of assignments to transform.
        candidate_count: forwarded to ``schedule_eqs`` (states expanded per
            search level).
        max_depth: forwarded to ``refuse_eqs`` (inlining depth).
        max_usage: forwarded to ``refuse_eqs`` (maximum use count of a symbol
            that may still be inlined).

    Returns:
        The list of assignments produced by ``refuse_eqs``.

    Note:
        ``schedule_eqs`` returns its input unchanged when ``candidate_count``
        is 0, and ``merge_field_accesses`` modifies the list it receives in
        place, so in that configuration the caller's list is altered.
    """
    scheduled = schedule_eqs(eqs, candidate_count)
    merged = merge_field_accesses(scheduled)
    return refuse_eqs(merged, max_depth, max_usage)