# simplificationstrategy.py
import sympy as sp
from collections import namedtuple
from typing import Callable, Any, Optional, Sequence
from pystencils.assignment_collection.assignment_collection import AssignmentCollection


class SimplificationStrategy:
    """A simplification strategy is an ordered collection of simplification rules.

    Each simplification is a function taking an assignment collection, and returning a new simplified
    assignment collection. The strategy can nicely print intermediate simplification stages and results
    to Jupyter notebooks.
    """

    def __init__(self):
        self._rules = []

    @staticmethod
    def _rule_name(rule) -> str:
        """Returns a printable name for a rule.

        Plain functions carry ``__name__``; ``functools.partial`` objects and callable instances do not,
        so for those the name of the wrapped function (``rule.func``) or the type name is used instead.
        """
        name = getattr(rule, '__name__', None)
        if name is None:
            wrapped = getattr(rule, 'func', None)  # set on functools.partial objects
            name = getattr(wrapped, '__name__', None) or type(rule).__name__
        return name

    def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None:
        """Adds the given simplification rule to the end of the collection.

        Args:
            rule: function that rewrites/simplifies an assignment collection
        """
        self._rules.append(rule)

    @property
    def rules(self):
        """The list of rules, in the order in which they are applied."""
        return self._rules

    def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
        """Runs all rules on the given assignment collection.

        Each rule receives the result of the previous one; the result of the last rule is returned.
        """
        for rule in self._rules:
            assignment_collection = rule(assignment_collection)
        return assignment_collection

    def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
        """Same as apply"""
        return self.apply(assignment_collection)

    def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any:
        """Creates a report to be displayed as HTML in a Jupyter notebook.

        The simplification report contains the number of operations at each simplification stage together
        with the run-time the simplification took.

        Args:
            assignment_collection: the collection the rules are applied to, one after the other

        Returns:
            object with ``__str__`` and ``_repr_html_``; its ``elements`` attribute holds one row for the
            original collection followed by one row per rule
        """
        import timeit

        ReportElement = namedtuple('ReportElement', ['simplificationName', 'runtime', 'adds', 'muls', 'divs', 'total'])
        column_titles = ['Name', 'Runtime', 'Adds', 'Muls', 'Divs', 'Total']

        class Report:
            def __init__(self):
                self.elements = []

            def add(self, element):
                self.elements.append(element)

            def __str__(self):
                try:
                    # the package exposes a function of the same name; the module itself is not callable
                    from tabulate import tabulate
                except ImportError:
                    # plain comma separated fallback, columns in the same order as the tuple fields
                    lines = [", ".join(column_titles)]
                    lines.extend(",".join(str(item) for item in e) for e in self.elements)
                    return "\n".join(lines) + "\n"
                return tabulate(self.elements, headers=column_titles)

            def _repr_html_(self):
                import html

                header = "".join("<th>%s</th>" % (title,) for title in column_titles)
                line = "<tr><td>{simplificationName}</td>" \
                       "<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td>  <td>{total}</td> </tr>"

                rows = []
                for e in self.elements:
                    # noinspection PyProtectedMember
                    fields = e._asdict()
                    # names such as '<lambda>' would otherwise be interpreted as HTML tags
                    fields['simplificationName'] = html.escape(str(fields['simplificationName']))
                    rows.append(line.format(**fields))
                return '<table style="border:none">' + "<tr>" + header + "</tr>" + "".join(rows) + "</table>"

        def make_element(name, runtime, collection):
            op = collection.operation_count
            total = op['adds'] + op['muls'] + op['divs']
            return ReportElement(name, runtime, op['adds'], op['muls'], op['divs'], total)

        report = Report()
        report.add(make_element("OriginalTerm", '-', assignment_collection))
        for rule in self._rules:
            start_time = timeit.default_timer()
            assignment_collection = rule(assignment_collection)
            end_time = timeit.default_timer()
            time_str = "%.2f ms" % ((end_time - start_time) * 1000,)
            report.add(make_element(self._rule_name(rule), time_str, assignment_collection))
        return report

    def show_intermediate_results(self, assignment_collection: AssignmentCollection,
                                  symbols: Optional[Sequence[sp.Symbol]] = None) -> Any:
        """Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.

        The rules are applied lazily, i.e. each time the returned object is converted to text or HTML.

        Args:
            assignment_collection: the collection to apply the rules to
            symbols: if not None (and not empty), only the assignments are shown that have one of these
                     symbols as left hand side
        """
        rule_name = self._rule_name

        class IntermediateResults:
            def __init__(self, strategy, collection, restrict_symbols):
                self.strategy = strategy
                self.assignment_collection = collection
                self.restrict_symbols = restrict_symbols

            def _stages(self):
                """Yields (title, collection) for the initial collection and after every rule."""
                collection = self.assignment_collection
                yield "Initial Version", collection
                for rule in self.strategy.rules:
                    collection = rule(collection)
                    yield rule_name(rule), collection

            def __str__(self):
                def print_assignment_collection(title, c):
                    if self.restrict_symbols:
                        body = "\n".join(str(e) for e in c.new_filtered(self.restrict_symbols).main_assignments)
                    else:
                        # indent every line of the collection's own text representation
                        body = " " * 3 + (" " * 3).join(str(c).splitlines(True))
                    # title on its own line, and a trailing newline so consecutive stages do not run together
                    return title + "\n" + body.rstrip("\n") + "\n"

                return "".join(print_assignment_collection(title, c) for title, c in self._stages())

            def _repr_html_(self):
                import html

                def print_assignment_collection(title, c):
                    # escape the title: rule names such as '<lambda>' would otherwise vanish as unknown tags
                    text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' \
                           % (html.escape(title), )
                    if self.restrict_symbols:
                        text += "\n".join(["$$" + sp.latex(e) + '$$'
                                           for e in c.new_filtered(self.restrict_symbols).main_assignments])
                    else:
                        # noinspection PyProtectedMember
                        text += c._repr_html_()
                    text += "</div>"
                    return text

                return "".join(print_assignment_collection(title, c) for title, c in self._stages())

        return IntermediateResults(self, assignment_collection, symbols)

    def __repr__(self):
        return "Simplification Strategy:\n" + "".join(" - %s\n" % (self._rule_name(t),) for t in self._rules)