from tempfile import TemporaryDirectory

import sympy as sp
from collections import defaultdict
import kerncraft
import kerncraft.kernel
from typing import Optional
from kerncraft.machinemodel import MachineModel

from pystencils.kerncraft_coupling.generate_benchmark import generate_benchmark
from pystencils.astnodes import LoopOverCoordinate, SympyAssignment, ResolvedFieldAccess, KernelFunction
from pystencils.field import get_layout_from_strides
from pystencils.sympyextensions import count_operations_in_ast
from pystencils.transformations import filtered_tree_iteration
from pystencils.utils import DotDict
import warnings


class PyStencilsKerncraftKernel(kerncraft.kernel.KernelCode):
    """
    Implementation of kerncraft's kernel interface for pystencils CPU kernels.
    Analyses a list of equations assuming they will be executed on a CPU
    """
    LIKWID_BASE = '/usr/local/likwid'

    def __init__(self, ast: KernelFunction, machine: Optional[MachineModel] = None,
                 assumed_layout='SoA', debug_print=False, filename=None):
        """Create a kerncraft kernel using a pystencils AST

        Args:
            ast: pystencils ast
            machine: kerncraft machine model - specify this if kernel needs to be compiled
            assumed_layout: either 'SoA' or 'AoS' - if fields have symbolic sizes the layout of the index
                    coordinates is not known. In this case either a structures of array (SoA) or
                    array of structures (AoS) layout is assumed. Any value other than 'SoA' is
                    treated as 'AoS'.
            debug_print: if True, pretty-print the extracted loop stack, sources, destinations
                    and flop counts after the analysis
            filename: stored unchanged in ``self._filename``

        Raises:
            ValueError: if the AST contains no innermost loop
        """
        # Only the base ``Kernel`` initializer is invoked here (not ``KernelCode.__init__``);
        # all analysis state is filled in below from the pystencils AST.
        kerncraft.kernel.Kernel.__init__(self, machine)

        # Initialize state
        self.asm_block = None
        self._filename = filename

        self.kernel_ast = ast
        self.temporary_dir = TemporaryDirectory()

        # Loops
        inner_loops = [loop for loop in filtered_tree_iteration(ast, LoopOverCoordinate, stop_type=SympyAssignment)
                       if loop.is_innermost_loop]
        if not inner_loops:
            raise ValueError("No loop found in pystencils AST")
        if len(inner_loops) > 1:
            warnings.warn("pystencils AST contains multiple inner loops. "
                          "Only one can be analyzed - choosing first one")
        inner_loop = inner_loops[0]

        # Walk from the innermost loop up through its parents, collecting every enclosing loop.
        loops_inner_to_outer = []
        cur_node = inner_loop
        while cur_node is not None:
            if isinstance(cur_node, LoopOverCoordinate):
                loop_counter_sym = cur_node.loop_counter_symbol
                # The step is fixed to 1 on purpose.
                # If the correct step were to be provided, all access within that step length will
                # also need to be passed to kerncraft: cur_node.step)
                loops_inner_to_outer.append((loop_counter_sym.name, cur_node.start, cur_node.stop, 1))
            cur_node = cur_node.parent
        # The loop stack is stored outermost loop first.
        self._loop_stack = loops_inner_to_outer[::-1]

        # Data sources & destinations
        self.sources = defaultdict(list)
        self.destinations = defaultdict(list)

        def get_layout_tuple(f):
            """Return the coordinate permutation used to order accesses/shape of field ``f``."""
            if f.has_fixed_shape:
                return get_layout_from_strides(f.strides)

            layout_list = list(f.layout)
            for _ in range(f.index_dimensions):
                index_coordinate = max(layout_list) + 1
                if assumed_layout == 'SoA':
                    layout_list.insert(0, index_coordinate)
                else:
                    # ``list.insert(-1, x)`` would place x *before* the last entry instead of at
                    # the end, so the index coordinate is appended explicitly.
                    # NOTE(review): assumes the first / last layout positions are the two extremes
                    # that correspond to SoA / AoS - confirm against the layout convention.
                    layout_list.append(index_coordinate)
            return layout_list

        reads, writes = search_resolved_field_accesses_in_ast(inner_loop)
        for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]:
            for fa in accesses:
                # Spatial coordinates: loop counter symbol plus the access offset ...
                coord = [sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i), positive=True, integer=True) + off
                         for i, off in enumerate(fa.offsets)]
                # ... followed by the index coordinate values.
                coord += list(fa.idx_coordinate_values)
                layout = get_layout_tuple(fa.field)
                permuted_coord = [sp.sympify(coord[i]) for i in layout]
                target_dict[fa.field.name].append(permuted_coord)

        # Variables (arrays)
        for field in ast.fields_accessed:
            layout = get_layout_tuple(field)
            permuted_shape = tuple(field.shape[i] for i in layout)
            self.set_variable(field.name, str(field.dtype), permuted_shape)

        # Scalars may be safely ignored
        # for param in ast.get_parameters():
        #     if not param.is_field_parameter:
        #         # self.set_variable(param.symbol.name, str(param.symbol.dtype), None)
        #         self.sources[param.symbol.name] = [None]

        # data type: taken from the first registered variable
        self.datatype = list(self.variables.values())[0][0]

        # flops: operations that do not occur (count of zero) are left out
        operation_count = count_operations_in_ast(inner_loop)
        all_flops = {
            '+': operation_count['adds'],
            '*': operation_count['muls'],
            '/': operation_count['divs'],
        }
        self._flops = {op: count for op, count in all_flops.items() if count != 0}

        self.check()

        if debug_print:
            from pprint import pprint
            print("-----------------------------  Loop Stack --------------------------")
            pprint(self._loop_stack)
            print("-----------------------------  Sources -----------------------------")
            pprint(self.sources)
            print("-----------------------------  Destinations ------------------------")
            pprint(self.destinations)
            print("-----------------------------  FLOPS -------------------------------")
            pprint(self._flops)

    def as_code(self, type_='iaca', openmp=False):
        """
        Generate and return compilable source code.

        :param type_: can be iaca or likwid. Only the value 'likwid' changes the generated
                      benchmark (it is passed on as ``likwid=True``).
        :param openmp: accepted for interface compatibility; currently not forwarded to the
                       benchmark generator
        """
        return generate_benchmark(self.kernel_ast, likwid=type_ == 'likwid')


class KerncraftParameters(DotDict):
    """Parameter container that is pre-populated with default settings.

    Keyword arguments override the defaults listed in ``__init__``; keys that are not
    passed keep their default value. Additional, unknown keys are stored as given.
    """

    def __init__(self, **kwargs):
        """Store ``kwargs`` and fill in a default for every setting not supplied by the caller."""
        super().__init__(**kwargs)
        defaults = {
            'asm_block': 'auto',
            'asm_increment': 0,
            'cores': 1,
            'cache_predictor': 'SIM',
            'verbose': 0,
            'pointer_increment': 'auto',
            'iterations': 10,
            'unit': 'cy/CL',
            'ignore_warnings': True,
        }
        # Defaults must not clobber values the caller passed explicitly.
        for key, value in defaults.items():
            if key not in kwargs:
                self[key] = value


# ------------------------------------------- Helper functions ---------------------------------------------------------


def search_resolved_field_accesses_in_ast(ast):
    """Collect all resolved field accesses below ``ast``, split into reads and writes.

    Every ``SympyAssignment`` reachable through ``.args`` contributes the
    ``ResolvedFieldAccess`` atoms of its left-hand side to the writes and those of its
    right-hand side to the reads. The search does not descend into assignments.

    Returns:
        tuple ``(read_accesses, write_accesses)`` of two sets
    """
    read_accesses = set()
    write_accesses = set()

    # Explicit work-list instead of recursion; the results are sets, so the
    # order in which nodes are visited does not matter.
    pending = [ast]
    while pending:
        node = pending.pop()
        if isinstance(node, SympyAssignment):
            write_accesses.update(node.lhs.atoms(ResolvedFieldAccess))
            read_accesses.update(node.rhs.atoms(ResolvedFieldAccess))
        else:
            pending.extend(node.args)

    return read_accesses, write_accesses