cpujit.py 21.1 KB
Newer Older
1
r"""
2

Martin Bauer's avatar
Martin Bauer committed
3
4
5
6
7
8
*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux make sure that 'gcc' and 'g++' are installed and in your path.
On Windows a recent Visual Studio installation is required.
In case anything does not work as expected or a special compiler should be used, changes can be specified
in a configuration file.

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory for a file named ``pystencils.json``
3. or in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is created at the above mentioned location in your home.
So run *pystencils* once, then edit the created configuration file.


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space separated list of compiler flags. Make sure to activate OpenMP in your compiler
Martin Bauer's avatar
Martin Bauer committed
26
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
27
28
29
30
31
32
33
34
35
36
  For most Linux compilers the qualifier is ``__restrict__``


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
Martin Bauer's avatar
Martin Bauer committed
37
- **'msvc_version'**:  either a version number, year number, 'auto' or 'latest' for automatic detection of latest
38
39
  installed version or 'setuptools' for setuptools-based detection. Alternatively path to folder
  where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
40
41
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
Martin Bauer's avatar
Martin Bauer committed
42
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
43
44
45
  For Windows compilers the qualifier should be ``__restrict``

"""
46
import hashlib
47
import json
Martin Bauer's avatar
Martin Bauer committed
48
import os
Martin Bauer's avatar
Martin Bauer committed
49
import platform
50
import shutil
Martin Bauer's avatar
Martin Bauer committed
51
import subprocess
52
import textwrap
Martin Bauer's avatar
Martin Bauer committed
53
54
from collections import OrderedDict
from sysconfig import get_paths
55
56
from tempfile import TemporaryDirectory

57
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
58
from appdirs import user_cache_dir, user_config_dir
59

60
from pystencils.backends.cbackend import generate_c, get_headers
Martin Bauer's avatar
Martin Bauer committed
61
from pystencils.field import FieldType
62
from pystencils.include import get_pystencils_include_path
Martin Bauer's avatar
Martin Bauer committed
63
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
64

65

Martin Bauer's avatar
Martin Bauer committed
66
def make_python_function(kernel_function_node):
    """
    Turn a kernel's abstract syntax tree into a callable Python object.

    The work is delegated to `compile_and_load`. The returned callable is invoked with keyword arguments:
        - one numpy array per field used in the kernel, passed under the name of the field
        - one value for every symbol that is not defined inside the kernel itself

    :param kernel_function_node: the abstract syntax tree of the kernel
    :return: kernel functor
    """
    return compile_and_load(kernel_function_node)
79
80


Martin Bauer's avatar
Martin Bauer committed
81
def set_config(config):
    """
    Override the configuration that was read from the configuration file.

    If this function is not called, the configuration is taken from a file in JSON format which
    is searched in the following locations, in the order specified:
        - at the location given in the environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - "config.json" in the per-user configuration directory of pystencils
    If none of these files exists, a file with a default configuration is created at the last location.

    The dictionary passed here replaces the active configuration as a whole. It is not merged with
    default values, so it has to provide the 'compiler' and the 'cache' section that
    `get_compiler_config` and `get_cache_config` look up.

    A deep copy is stored: later in-place changes of the active configuration
    (for example by `add_or_change_compiler_flags`) do not leak back into the caller's dictionary,
    and later changes of the caller's dictionary do not alter the active configuration.

    An example of the 'compiler' section, written as JSON
    ``
    {
        "compiler": {
            "command": "/software/intel/2017/bin/icpc",
            "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
            "env": {
                "LM_PROJECT": "iwia"
            }
        }
    }
    ``

    :param config: dictionary with the complete configuration
    """
    global _config
    from copy import deepcopy

    # a shallow copy would share the nested 'compiler' / 'cache' dictionaries with the caller
    _config = deepcopy(config)


Martin Bauer's avatar
Martin Bauer committed
112
113
def get_configuration_file_path():
    """
    Determine which configuration file is to be used.

    :return: tuple ``(path, exists)``. For a path taken from the environment variable the flag is
             always True (the file is not checked here); for the fallback location in the user's
             configuration directory the flag tells whether the file is already there.
    """
    local_config = "pystencils.json"
    home_config = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) an explicitly configured path wins
    if 'PYSTENCILS_CONFIG' in os.environ:
        return os.environ['PYSTENCILS_CONFIG'], True

    # 2) pystencils.json in the current working directory
    if os.path.exists(local_config):
        return local_config, True

    # 3) config.json in the per-user configuration directory, whether it exists or not
    return home_config, os.path.exists(home_config)
126
127


Martin Bauer's avatar
Martin Bauer committed
128
129
def create_folder(path, is_file):
    """
    Make sure a folder exists, creating missing parent folders as needed.

    The operation is best-effort: if the folder cannot be created a warning is issued instead of
    raising, so that the caller can continue and fail at the point where the folder is really needed.

    :param path: path of the folder, or path of a file inside the folder if ``is_file`` is True
    :param is_file: if True, ``path`` names a file and the folder containing it is created
    """
    import warnings

    if is_file:
        path = os.path.dirname(path)

    # a bare file name has no folder component - the current directory needs no creation
    if not path:
        return

    try:
        # exist_ok covers the expected "already there" case without hiding real failures
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        warnings.warn("Could not create folder '{}': {}".format(path, e))


Martin Bauer's avatar
Martin Bauer committed
137
def read_config():
    """
    Build the active configuration from platform defaults and the user's configuration file.

    Default values for the 'compiler' and 'cache' sections are chosen depending on
    ``platform.system()``. If a configuration file exists (see `get_configuration_file_path`),
    its content is combined with the defaults via `recursive_dict_update`; otherwise the defaults
    are written to the reported path so that they can be edited afterwards.

    Side effects: may create the configuration file; creates the object cache folder (emptying it
    first if 'clear_cache_on_start' is set); for Windows configurations the environment returned by
    the MSVC detection is added to ``config['compiler']['env']``.

    :return: configuration dictionary with the sections 'compiler' and 'cache'
    """
    system_name = platform.system().lower()
    if system_name == 'windows':
        default_compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /openmp /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    elif system_name == 'darwin':
        default_compiler_config = OrderedDict([
            ('os', 'darwin'),
            ('command', 'clang++'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    else:
        # Linux - and, as a fallback, every other platform: without a default the name below
        # would be unbound and the module could not even be imported to apply a user configuration
        default_compiler_config = OrderedDict([
            ('os', system_name or 'linux'),
            ('command', 'g++'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])

    default_cache_config = OrderedDict([
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        create_folder(config_path, True)
        # context manager so the file is flushed and closed deterministically
        with open(config_path, 'w') as json_config_file:
            json.dump(config, json_config_file, indent=4)

    if config['cache']['object_cache'] is not False:
        config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())

        if config['cache']['clear_cache_on_start']:
            # clear_cache() cannot be used here: it reads the module-level configuration,
            # which does not exist yet while this function is still building it
            shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)

        create_folder(config['cache']['object_cache'], False)

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        config['compiler'].setdefault('env', {}).update(msvc_env)

    return config


Martin Bauer's avatar
Martin Bauer committed
196
# Active configuration, loaded once when the module is imported; set_config() replaces it
_config = read_config()
197
198


Martin Bauer's avatar
Martin Bauer committed
199
def get_compiler_config():
    """
    Access the 'compiler' section of the active configuration.

    :return: the section object itself (not a copy), so changes made to it affect later compilations
    """
    compiler_section = _config['compiler']
    return compiler_section


Martin Bauer's avatar
Martin Bauer committed
203
def get_cache_config():
    """
    Access the 'cache' section of the active configuration.

    :return: the section object itself (not a copy), so changes made to it affect later compilations
    """
    cache_section = _config['cache']
    return cache_section
Michael Kuron's avatar
Michael Kuron committed
205
206


207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def add_or_change_compiler_flags(flags):
    """
    Add flags to the compiler configuration, replacing existing flags of the same kind.

    For every new flag the part in front of the first '=' (or the complete flag if it has no '=')
    is used as prefix: all currently configured flags starting with that prefix are dropped before
    the new flag is appended, e.g. '-march=haswell' replaces '-march=native'.
    Empty or whitespace-only entries are ignored.

    Calling this function switches off the object cache.

    :param flags: a single flag as string, or a list/tuple of flags
    """
    if not isinstance(flags, (list, tuple)):
        flags = [flags]

    compiler_config = get_compiler_config()
    cache_config = get_cache_config()
    # NOTE(review): presumably required because cached modules are keyed by the generated code only,
    # so objects built with the old flags would be picked up again - confirm
    cache_config['object_cache'] = False  # disable cache

    for flag in flags:
        flag = flag.strip()
        if not flag:
            # an empty prefix matches every configured flag and would wipe all of them
            continue
        base = flag.split('=')[0].strip()

        new_flags = [c for c in compiler_config['flags'].split() if not c.startswith(base)]
        new_flags.append(flag)
        compiler_config['flags'] = ' '.join(new_flags)


Martin Bauer's avatar
Martin Bauer committed
227
228
def clear_cache():
    """
    Empty the object cache.

    The cache folder is removed together with its content (errors during removal are ignored)
    and created again. Nothing happens if the object cache is switched off.
    """
    cache_dir = get_cache_config()['object_cache']
    if cache_dir is False:
        return

    shutil.rmtree(cache_dir, ignore_errors=True)
    create_folder(cache_dir, False)
Martin Bauer's avatar
Martin Bauer committed
232
233


234
235
236
237
238
239
240
241
242
243
# Maps the numpy scalar type of a kernel parameter to a pair
# (name of the CPython API function that converts the Python object, C type the result is cast to).
# create_function_boilerplate_code inserts both entries into template_extract_scalar.
type_mapping = {
    np.float32: ('PyFloat_AsDouble', 'float'),
    np.float64: ('PyFloat_AsDouble', 'double'),
    np.int16: ('PyLong_AsLong', 'int16_t'),
    np.int32: ('PyLong_AsLong', 'int32_t'),
    np.int64: ('PyLong_AsLong', 'int64_t'),
    np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
    np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
    np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
}
Martin Bauer's avatar
Martin Bauer committed
244
245


246
247
248
249
250
251
# C snippet (str.format template, literal braces are doubled) that fetches the keyword argument
# {name} from the 'kwargs' dict, converts it with {extract_function} and casts it to {target_type}.
# The generated code returns NULL if the argument is missing or the conversion set a Python error.
template_extract_scalar = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""
Martin Bauer's avatar
Martin Bauer committed
252

253
254
255
256
# C snippet (str.format template, literal braces are doubled) that fetches the keyword argument
# {name} from the 'kwargs' dict and requests a writable, strided buffer with format information
# from it. The buffer is stored in 'buffer_{name}'; the generated code returns NULL if the
# argument is missing or the object does not provide such a buffer.
# Every buffer obtained here has to be released again, see template_release_buffer.
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""

# C snippet that releases the buffer 'buffer_{name}' obtained by template_extract_array
template_release_buffer = """
PyBuffer_Release(&buffer_{name});
"""

# C wrapper function (str.format template, literal braces are doubled) exposed to Python.
# It rejects calls without keyword arguments, runs {pre_call_code} (argument extraction and
# checks), calls 'kernel_{func_name}' with {parameters}, runs {post_call_code} and returns None.
template_function_boilerplate = """
static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
{{
    if( !kwargs || !PyDict_Check(kwargs) ) {{ 
        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
        return NULL; 
    }}
    {pre_call_code}
    kernel_{func_name}({parameters});
    {post_call_code}
    Py_RETURN_NONE;
}}
"""

# C snippet (str.format template) that raises a Python ValueError and returns NULL unless {cond}
# holds; {what}, {name} and {expected} are only used for the error message.
# NOTE(review): the early 'return NULL' does not release buffers acquired before the check;
# within this file PyBuffer_Release is only emitted after the kernel call - confirm whether intended.
template_check_array = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
    return NULL; 
}}
"""

# C snippet (str.format template) that raises a Python TypeError and returns NULL unless {cond}
# holds; equal_size_check fills {cond} with pairwise comparisons of buffer shapes.
template_size_check = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
}}"""

# C code (str.format template, literal braces are doubled) that defines the extension module:
# the method table built from {method_definitions}, the module definition and the
# 'PyInit_{module_name}' entry point that creates the module.
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
    {method_definitions}
    {{NULL, NULL, 0, NULL}}
}};

static struct PyModuleDef module_definition = {{
    PyModuleDef_HEAD_INIT,
    "{module_name}",   /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    method_definitions
}};

PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
    return PyModule_Create(&module_definition);
}}
"""


def equal_size_check(fields):
    """
    Generate C code checking that the buffers of all given fields have the same shape.

    The first field serves as reference: every other field is compared to it in each of the
    reference field's spatial dimensions.

    :param fields: iterable of fields; objects need the attributes ``name`` and ``spatial_dimensions``
    :return: C code as string, empty string if there are fewer than two fields
    """
    fields = list(fields)
    if len(fields) <= 1:
        return ""

    ref_field, other_fields = fields[0], fields[1:]
    comparison = "(buffer_{field.name}.shape[{i}] == buffer_{ref_field.name}.shape[{i}])"

    conditions = []
    for field_to_test in other_fields:
        for i in range(ref_field.spatial_dimensions):
            conditions.append(comparison.format(ref_field=ref_field, field=field_to_test, i=i))

    return template_size_check.format(cond=" && ".join(conditions))


def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
    """
    Generate the C wrapper function that makes one kernel callable from Python.

    For every kernel parameter the wrapper extracts the matching keyword argument:
        - field pointers are obtained through the buffer protocol and released after the kernel call
        - field strides and shapes are read from the buffer of the field
        - all other parameters are converted to C scalars according to `type_mapping`

    :param parameter_info: parameters of the kernel, in the order of the kernel's C signature
    :param name: name of the wrapper function; the kernel itself is called as ``kernel_<name>``
    :param insert_checks: if True, code validating data type, item size, shape and strides of the
                          passed arrays is generated, and arrays of variable size are checked to
                          have equal shape
    :return: C code of the wrapper function as string
    """
    pre_call_parts = []
    post_call_parts = []
    call_arguments = []
    variable_sized_normal_fields = set()
    variable_sized_index_fields = set()

    def add_array_check(cond, what, field_name, expected):
        pre_call_parts.append(template_check_array.format(cond=cond, what=what, name=field_name,
                                                          expected=expected))

    for param in parameter_info:
        if param.is_field_pointer:
            field = param.fields[0]
            pre_call_parts.append(template_extract_array.format(name=field.name))
            post_call_parts.append(template_release_buffer.format(name=field.name))
            call_arguments.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name))

            if not insert_checks:
                continue

            np_dtype = field.dtype.numpy_dtype
            item_size = np_dtype.itemsize

            # the format character is only compared for builtin dtypes of generic fields
            if np_dtype.isbuiltin and FieldType.is_generic(field):
                dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
                                                                            format=field.dtype.numpy_dtype.char)
                add_array_check(dtype_cond, "data type", field.name, str(field.dtype.numpy_dtype))

            item_size_cond = "buffer_{name}.itemsize == {size}".format(name=field.name, size=item_size)
            add_array_check(item_size_cond, "itemsize", field.name, item_size)

            if not field.has_fixed_shape:
                # shapes of variable-sized fields are compared against each other further down
                if FieldType.is_generic(field):
                    variable_sized_normal_fields.add(field)
                elif FieldType.is_indexed(field):
                    variable_sized_index_fields.add(field)
                continue

            shape_cond = " && ".join(["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
                                      for i, s in enumerate(field.spatial_shape)])
            add_array_check(shape_cond, "shape", field.name, str(field.shape))

            # field strides are multiplied by the item size to match the byte strides of the buffer;
            # the stride of a dimension with at most one entry is not checked
            expected_strides = [e * item_size for e in field.spatial_strides]
            stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
            strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
                                        for i, s in enumerate(expected_strides)])
            add_array_check(strides_cond, "strides", field.name, str(expected_strides))
        elif param.is_field_stride:
            field = param.fields[0]
            item_size = field.dtype.numpy_dtype.itemsize
            # the buffer stride is divided by the item size before it is handed to the kernel
            call_arguments.append("buffer_{name}.strides[{i}] / {bytes}".format(bytes=item_size,
                                                                                i=param.symbol.coordinate,
                                                                                name=field.name))
        elif param.is_field_shape:
            call_arguments.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate,
                                                                    name=param.field_name))
        else:
            extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
            pre_call_parts.append(template_extract_scalar.format(extract_function=extract_function,
                                                                 target_type=target_type,
                                                                 name=param.symbol.name))
            call_arguments.append(param.symbol.name)

    pre_call_parts.append(equal_size_check(variable_sized_normal_fields))
    pre_call_parts.append(equal_size_check(variable_sized_index_fields))

    pre_call_code = textwrap.indent("".join(pre_call_parts), '    ')
    post_call_code = textwrap.indent("".join(post_call_parts), '    ')
    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
                                                post_call_code=post_call_code,
                                                parameters=", ".join(call_arguments))
Martin Bauer's avatar
Martin Bauer committed
394
395


396
397
398
399
def create_module_boilerplate_code(module_name, names):
    """
    Generate the C code that defines the extension module and its method table.

    :param module_name: name of the extension module
    :param names: names of the wrapper functions; each is exported under its own name
    :return: C code as string
    """
    table_entry = '{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
    table_entries = [table_entry.format(name=function_name) for function_name in names]
    return template_module_boilerplate.format(module_name=module_name,
                                              method_definitions="\n".join(table_entries))
Martin Bauer's avatar
Martin Bauer committed
400

401

402
403
def load_kernel_from_file(module_name, function_name, path):
    """
    Import a module from a file and return one attribute of it.

    If the import raises an ImportError a warning is issued and, after waiting one second,
    the import is attempted a second time; a failure of the second attempt is propagated.

    :param module_name: name under which the module is imported
    :param function_name: name of the attribute to fetch from the module
    :param path: path of the file to import
    :return: the attribute ``function_name`` of the imported module
    """
    from importlib.util import spec_from_file_location, module_from_spec

    def import_module():
        spec = spec_from_file_location(name=module_name, location=path)
        module = module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    try:
        mod = import_module()
    except ImportError:
        import time
        import warnings
        warnings.warn("Could not load " + path + ", trying on more time...")
        time.sleep(1)
        mod = import_module()

    return getattr(mod, function_name)
418

Martin Bauer's avatar
Martin Bauer committed
419

Martin Bauer's avatar
Martin Bauer committed
420
421
422
423
424
def run_compile_step(command):
    """
    Run one compiler or linker invocation.

    The command is executed with the current process environment, extended by the 'env' entries
    of the compiler configuration. For Windows configurations the command is run through the shell.
    If the command fails, the command line and its combined stdout/stderr output are printed
    before the error is re-raised.

    :param command: the command line as list of strings
    :raises subprocess.CalledProcessError: if the command exits with a non-zero return code
    """
    compiler_config = get_compiler_config()
    compile_environment = os.environ.copy()
    compile_environment.update(compiler_config.get('env', {}))

    try:
        shell = compiler_config['os'].lower() == 'windows'
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as e:
        print(" ".join(command))
        # compiler output is not guaranteed to be valid UTF-8; a decoding error raised here
        # would replace the actual compile error
        print(e.output.decode('utf8', errors='replace'))
        raise
433
434


435
436
437
class ExtensionModuleCode:
    """Collects kernels and writes the C/C++ source of a Python extension module exposing them."""

    def __init__(self, module_name='generated'):
        """
        :param module_name: name of the extension module to generate
        """
        self.module_name = module_name

        self._ast_nodes = []
        self._function_names = []

    def add_function(self, ast, name=None):
        """
        Register a kernel for export.

        :param ast: abstract syntax tree of the kernel
        :param name: name of the exported function, defaults to ``ast.function_name``
        """
        self._ast_nodes.append(ast)
        self._function_names.append(name if name is not None else ast.function_name)

    def write_to_file(self, restrict_qualifier, function_prefix, file):
        """
        Write the complete source code of the extension module.

        The output consists of the includes, the RESTRICT and FUNC_PREFIX defines, for every
        registered kernel the kernel code followed by its Python wrapper function, and the
        module definition.

        :param restrict_qualifier: value of the RESTRICT define
        :param function_prefix: value of the FUNC_PREFIX define
        :param file: file-like object the code is printed to
        """
        headers = {'<math.h>', '<stdint.h>'}
        for ast in self._ast_nodes:
            headers.update(get_headers(ast))
        # sorted for a reproducible order; Python.h is placed first
        header_list = sorted(headers)
        header_list.insert(0, '"Python.h"')

        includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
        print(includes, file=file)
        print("\n", file=file)
        print("#define RESTRICT %s" % (restrict_qualifier,), file=file)
        print("#define FUNC_PREFIX %s" % (function_prefix,), file=file)
        print("\n", file=file)

        for ast, name in zip(self._ast_nodes, self._function_names):
            # the kernel is emitted as 'kernel_<name>', the name the wrapper function calls
            old_name = ast.function_name
            ast.function_name = "kernel_" + name
            try:
                print(generate_c(ast), file=file)
                print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
            finally:
                # restore the caller's AST even if code generation fails
                ast.function_name = old_name
        print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
Martin Bauer's avatar
Martin Bauer committed
468
469


470
471
472
473
474
class KernelWrapper:
    """Callable that bundles a compiled kernel with its parameters and abstract syntax tree."""

    def __init__(self, kernel, parameters, ast_node):
        """
        :param kernel: the callable that is invoked when the wrapper is called
        :param parameters: parameters of the kernel, stored as ``parameters``
        :param ast_node: abstract syntax tree of the kernel, stored as ``ast``
        """
        self.kernel, self.parameters, self.ast = kernel, parameters, ast_node

    def __call__(self, **kwargs):
        """Call the kernel with the given keyword arguments and return its result."""
        compiled_kernel = self.kernel
        return compiled_kernel(**kwargs)
478

479

480
def compile_module(code, code_hash, base_dir):
    """
    Compile and link an extension module, reusing build products already present in ``base_dir``.

    Source, object and library file are all named after ``code_hash``. If the object file exists
    the compile step is skipped; if the library exists as well nothing is built at all.

    :param code: `ExtensionModuleCode` providing the source code
    :param code_hash: base name (without suffix) of all generated files
    :param base_dir: folder in which the files are created
    :return: path of the library file
    """
    compiler_config = get_compiler_config()
    extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]

    windows = compiler_config['os'].lower() == 'windows'
    if windows:
        function_prefix = '__declspec(dllexport)'
        lib_suffix = '.pyd'
        object_suffix = '.obj'
    else:
        function_prefix = ''
        lib_suffix = '.so'
        object_suffix = '.o'

    src_file = os.path.join(base_dir, code_hash + ".cpp")
    lib_file = os.path.join(base_dir, code_hash + lib_suffix)
    object_file = os.path.join(base_dir, code_hash + object_suffix)

    def compile_object():
        with file_handle_for_atomic_write(src_file) as f:
            code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)

        if windows:
            compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
            compile_cmd += [*extra_flags, src_file, '/Fo' + object_file]
            run_compile_step(compile_cmd)
        else:
            with atomic_file_write(object_file) as file_name:
                compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
                compile_cmd += [*extra_flags, '-o', file_name, src_file]
                run_compile_step(compile_cmd)

    def link_library():
        if windows:
            import sysconfig
            config_vars = sysconfig.get_config_vars()
            py_lib = os.path.join(config_vars["installed_base"], "libs",
                                  "python{}.lib".format(config_vars["py_version_nodot"]))
            run_compile_step(['link.exe', py_lib, '/DLL', '/out:' + lib_file, object_file])
        else:
            with atomic_file_write(lib_file) as file_name:
                run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name]
                                 + compiler_config['flags'].split())

    if not os.path.exists(object_file):
        compile_object()
        link_library()
    elif not os.path.exists(lib_file):
        # an earlier run produced the object file but no library (e.g. the link step failed);
        # without relinking every later call would return the path of a missing file
        link_library()

    return lib_file


def compile_and_load(ast):
    """
    Compile a kernel into an extension module, load it and wrap the resulting function.

    The module is named "mod_" followed by the SHA-256 hex digest of the kernel's generated C code.
    If the object cache is enabled the module is built inside the cache folder, otherwise
    in a temporary folder that is removed after the module has been loaded.

    :param ast: abstract syntax tree of the kernel
    :return: `KernelWrapper` around the loaded kernel function
    """
    cache_config = get_cache_config()
    module_name = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest()

    code = ExtensionModuleCode(module_name=module_name)
    code.add_function(ast, ast.function_name)

    object_cache = cache_config['object_cache']
    if object_cache is not False:
        lib_file = compile_module(code, module_name, base_dir=object_cache)
        kernel = load_kernel_from_file(module_name, ast.function_name, lib_file)
    else:
        # the module has to be loaded while the temporary folder still exists
        with TemporaryDirectory() as base_dir:
            lib_file = compile_module(code, module_name, base_dir)
            kernel = load_kernel_from_file(module_name, ast.function_name, lib_file)

    return KernelWrapper(kernel, ast.get_parameters(), ast)