kerncraft_interface.py 6.68 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
from tempfile import TemporaryDirectory
Martin Bauer's avatar
Martin Bauer committed
2

Martin Bauer's avatar
Martin Bauer committed
3
4
import sympy as sp
from collections import defaultdict
Martin Bauer's avatar
Martin Bauer committed
5
import kerncraft
Martin Bauer's avatar
Martin Bauer committed
6
import kerncraft.kernel
Julian Hammer's avatar
Julian Hammer committed
7
8
9
from typing import Optional
from kerncraft.machinemodel import MachineModel

Martin Bauer's avatar
Martin Bauer committed
10
from pystencils.kerncraft_coupling.generate_benchmark import generate_benchmark
Julian Hammer's avatar
Julian Hammer committed
11
from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFieldAccess, KernelFunction
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.field import get_layout_from_strides
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.sympyextensions import count_operations_in_ast
14
from pystencils.transformations import filtered_tree_iteration
Martin Bauer's avatar
Martin Bauer committed
15
from pystencils.utils import DotDict
16
import warnings
Martin Bauer's avatar
Martin Bauer committed
17
18


Julian Hammer's avatar
Julian Hammer committed
19
class PyStencilsKerncraftKernel(kerncraft.kernel.KernelCode):
    """
    Implementation of kerncraft's kernel interface for pystencils CPU kernels.

    Analyses a list of equations assuming they will be executed on a CPU.
    """

    # Not referenced inside this class; presumably the likwid install prefix used by callers - confirm.
    LIKWID_BASE = '/usr/local/likwid'

    def __init__(self, ast: KernelFunction, machine: Optional[MachineModel] = None,
                 assumed_layout='SoA', debug_print=False, filename=None):
        """Create a kerncraft kernel using a pystencils AST

        Args:
            ast: pystencils ast
            machine: kerncraft machine model - specify this if kernel needs to be compiled
            assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index
                    coordinates is not known. In this case either a structures of array (SoA) or
                    array of structures (AoS) layout is assumed. Any value other than 'SoA' is
                    treated as 'AoS'.
            debug_print: if True, print the extracted loop stack, sources, destinations and flop counts
            filename: stored as ``self._filename``; not used elsewhere in this class
                    (presumably consumed by kerncraft - confirm)

        Raises:
            ValueError: if the AST contains no innermost loop, or the kernel accesses no fields
        """
        # Calls Kernel.__init__ directly, bypassing KernelCode.__init__ - presumably because the latter
        # expects C source code to parse, which we do not have here (confirm against kerncraft).
        kerncraft.kernel.Kernel.__init__(self, machine)

        # Initialize state
        self.asm_block = None
        self._filename = filename

        self.kernel_ast = ast
        self.temporary_dir = TemporaryDirectory()

        # Loops
        inner_loops = [loop for loop in filtered_tree_iteration(ast, LoopOverCoordinate, stop_type=SympyAssignment)
                       if loop.is_innermost_loop]
        if not inner_loops:
            raise ValueError("No loop found in pystencils AST")
        if len(inner_loops) > 1:
            warnings.warn("pystencils AST contains multiple inner loops. "
                          "Only one can be analyzed - choosing first one")
        inner_loop = inner_loops[0]

        # Walk from the innermost loop up to the root collecting every enclosing loop, then reverse
        # so the stack is ordered outermost -> innermost.
        self._loop_stack = []
        cur_node = inner_loop
        while cur_node is not None:
            if isinstance(cur_node, LoopOverCoordinate):
                loop_counter_sym = cur_node.loop_counter_symbol
                loop_info = (loop_counter_sym.name, cur_node.start, cur_node.stop, cur_node.step)
                self._loop_stack.append(loop_info)
            cur_node = cur_node.parent
        self._loop_stack = list(reversed(self._loop_stack))

        # Data sources & destinations
        self.sources = defaultdict(list)
        self.destinations = defaultdict(list)

        def get_layout_tuple(f):
            """Return the order of f's coordinates (spatial ones first, then index ones) in memory.

            For fields with a fixed shape this is derived from the strides. Otherwise the spatial layout
            is taken from ``f.layout`` and the index coordinates are placed according to
            ``assumed_layout``: first (outermost) for SoA, last (innermost) for AoS.
            """
            if f.has_fixed_shape:
                return get_layout_from_strides(f.strides)
            layout_list = list(f.layout)
            for _ in range(f.index_dimensions):
                # Index coordinates follow the spatial ones in the coordinate list, so the next
                # coordinate id is one past the largest id used so far.
                new_coordinate = max(layout_list, default=-1) + 1
                if assumed_layout == 'SoA':
                    layout_list.insert(0, new_coordinate)
                else:
                    # AoS: the index coordinate varies fastest, i.e. goes at the very end.
                    # (list.insert(-1, x) would have placed it *before* the last spatial coordinate.)
                    layout_list.append(new_coordinate)
            return layout_list

        reads, writes = search_resolved_field_accesses_in_ast(inner_loop)
        for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]:
            for fa in accesses:
                # Spatial part: loop counter of each dimension plus the constant access offset
                coord = [sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i), positive=True, integer=True) + off
                         for i, off in enumerate(fa.offsets)]
                coord += list(fa.idx_coordinate_values)
                layout = get_layout_tuple(fa.field)
                permuted_coord = [sp.sympify(coord[i]) for i in layout]
                target_dict[fa.field.name].append(permuted_coord)

        # Variables (arrays), registered with their shape permuted into memory order
        for field in ast.fields_accessed:
            layout = get_layout_tuple(field)
            permuted_shape = [field.shape[i] for i in layout]
            self.set_variable(field.name, str(field.dtype), tuple(permuted_shape))

        # Scalar kernel parameters may be safely ignored and are not registered as variables.

        # data type: taken from the first registered variable
        if not self.variables:
            raise ValueError("pystencils AST does not access any fields")
        self.datatype = list(self.variables.values())[0][0]

        # flops - only operation types that actually occur are reported
        operation_count = count_operations_in_ast(inner_loop)
        flops = {
            '+': operation_count['adds'],
            '*': operation_count['muls'],
            '/': operation_count['divs'],
        }
        self._flops = {op: count for op, count in flops.items() if count != 0}

        self.check()

        if debug_print:
            from pprint import pprint
            print("-----------------------------  Loop Stack --------------------------")
            pprint(self._loop_stack)
            print("-----------------------------  Sources -----------------------------")
            pprint(self.sources)
            print("-----------------------------  Destinations ------------------------")
            pprint(self.destinations)
            print("-----------------------------  FLOPS -------------------------------")
            pprint(self._flops)

    def as_code(self, type_='iaca', openmp=False):
        """
        Generate and return compilable source code.

        Args:
            type_: benchmark flavour; 'likwid' enables likwid instrumentation, any other value
                   (default 'iaca') generates the benchmark without it
            openmp: currently ignored - presumably kept to match kerncraft's ``as_code`` signature
        """
        return generate_benchmark(self.kernel_ast, likwid=type_ == 'likwid')
136

Martin Bauer's avatar
Martin Bauer committed
137
138

class KerncraftParameters(DotDict):
    """Parameter set passed to kerncraft's performance models.

    Starts from the defaults listed in ``__init__``. Any keyword argument overrides the default
    of the same name, and unknown keywords are added as extra entries.
    """

    def __init__(self, **kwargs):
        super().__init__()
        defaults = {
            'asm_block': 'auto',
            'asm_increment': 0,
            'cores': 1,
            'cache_predictor': 'SIM',
            'verbose': 0,
            'pointer_increment': 'auto',
            'iterations': 10,
            'unit': 'cy/CL',
            'ignore_warnings': True,
        }
        # Defaults first, then caller-supplied values, so explicit arguments take precedence.
        # (Previously the defaults were assigned after the kwargs and silently overwrote them.)
        for key, value in {**defaults, **kwargs}.items():
            self[key] = value
Jan Hönig's avatar
Jan Hönig committed
150

Martin Bauer's avatar
Martin Bauer committed
151
152
153
154

# ------------------------------------------- Helper functions ---------------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
155
def search_resolved_field_accesses_in_ast(ast):
    """Collect the resolved field accesses occurring in the assignments below ``ast``.

    Non-assignment nodes are expanded through their ``args``. Each ``SympyAssignment`` reached
    contributes the ``ResolvedFieldAccess`` atoms of its left-hand side as writes and those of its
    right-hand side as reads. The traversal does not descend further into an assignment.

    Returns:
        tuple ``(read_accesses, write_accesses)`` of two sets
    """
    read_accesses, write_accesses = set(), set()
    pending = [ast]
    while pending:
        node = pending.pop()
        if isinstance(node, SympyAssignment):
            write_accesses.update(node.lhs.atoms(ResolvedFieldAccess))
            read_accesses.update(node.rhs.atoms(ResolvedFieldAccess))
        else:
            pending.extend(node.args)
    return read_accesses, write_accesses