r"""

*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux make sure that 'gcc' and 'g++' are installed and in your path.
On Windows a recent Visual Studio installation is required.
In case anything does not work as expected or a special compiler should be used, changes can be specified
in a configuration file.

*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory for a file named ``pystencils.json``
3. or in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is created at the above mentioned location in your home.
So run *pystencils* once, then edit the created configuration file.


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space separated list of compiler flags. Make sure to activate OpenMP in your compiler
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
  For most Linux compilers the qualifier is ``__restrict__``


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
- **'msvc_version'**:  either a version number, year number, 'auto' or 'latest' for automatic detection of latest
  installed version or 'setuptools' for setuptools-based detection. Alternatively path to folder
  where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
  For Windows compilers the qualifier should be ``__restrict``

"""
46
import hashlib
47
import json
48
import os
Martin Bauer's avatar
Martin Bauer committed
49
import platform
50
import shutil
51
import subprocess
52
import textwrap
53
54
from collections import OrderedDict
from sysconfig import get_paths
55
56
from tempfile import TemporaryDirectory

57
import numpy as np
58
from appdirs import user_cache_dir, user_config_dir
59
60

from pystencils import FieldType
61
from pystencils.backends.cbackend import generate_c, get_headers
62
from pystencils.include import get_pystencils_include_path
63
from pystencils.kernel_wrapper import KernelWrapper
Martin Bauer's avatar
Martin Bauer committed
64
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
65

66

67
def make_python_function(kernel_function_node, custom_backend=None):
    """
    Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function

    The parameters of the kernel are:
        - numpy arrays for each field used in the kernel. The keyword argument name is the name of the field
        - all symbols which are not defined in the kernel itself are expected as parameters

    :param kernel_function_node: the abstract syntax tree
    :param custom_backend: optional backend object, forwarded unchanged to ``compile_and_load``
    :return: kernel functor
    """
    return compile_and_load(kernel_function_node, custom_backend)
80
81


Martin Bauer's avatar
Martin Bauer committed
82
def set_config(config):
    """
    Override the configuration provided in config file

    Configuration of compiler parameters:
    If this function is not called the configuration is taken from a config file in JSON format which
    is searched in the following locations in the order specified:
        - at location provided in environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - a file called "config.json" in the per-user configuration directory of 'pystencils'
    If none of these files exist, "config.json" is created in the per-user configuration directory
    with a default configuration (using the GNU 'g++' on Linux)

    An example JSON file. If not all keys are specified, default values are used
    ``
    {
        "compiler" :
        {
            "command": "/software/intel/2017/bin/icpc",
            "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
            "env": {
                "LM_PROJECT": "iwia"
            }
        }
    }
    ``

    :param config: dictionary with the same layout as the JSON file. Only a shallow copy is stored,
                   so the nested 'compiler' / 'cache' dictionaries stay shared with the caller.
    """
    global _config
    _config = config.copy()


Martin Bauer's avatar
Martin Bauer committed
113
114
def get_configuration_file_path():
    """Determine which configuration file should be used.

    The lookup order is:
        1. the path stored in the environment variable ``PYSTENCILS_CONFIG``
           (returned as-is, without checking that the file exists)
        2. a file ``pystencils.json`` in the current working directory
        3. ``config.json`` in the per-user configuration directory of 'pystencils'

    :return: tuple ``(path, exists)``. ``exists`` is False only when none of the candidates was found;
             in that case ``path`` is the location in the per-user configuration directory.
    """
    config_path_in_home = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) Read path from environment variable if found
    if 'PYSTENCILS_CONFIG' in os.environ:
        return os.environ['PYSTENCILS_CONFIG'], True
    # 2) Look in current directory for pystencils.json
    if os.path.exists("pystencils.json"):
        return "pystencils.json", True
    # 3) Fall back to the per-user configuration directory
    return config_path_in_home, os.path.exists(config_path_in_home)
127
128


Martin Bauer's avatar
Martin Bauer committed
129
130
def create_folder(path, is_file):
    """Create a directory (including missing parents) on a best-effort basis.

    :param path: directory to create, or a file path whose parent directory should be created
    :param is_file: if True, ``path`` names a file and only its directory part is created
    """
    if is_file:
        path = os.path.split(path)[0]
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        # Deliberately best-effort: callers run at import time, a failure here
        # surfaces later when the directory is actually used.
        pass


138
139
140
141
142
143
144
145
146
147
def get_llc_command():
    """Look up an executable of llvm's IR compiler llc.

    The names llc, llc-10, llc-9, llc-8, llc-7, llc-6 are probed in this order with ``shutil.which``.

    :return: the first name that is found in PATH, or None if none of them is available
    """
    for executable in ('llc', 'llc-10', 'llc-9', 'llc-8', 'llc-7', 'llc-6'):
        if shutil.which(executable):
            return executable
    return None


Martin Bauer's avatar
Martin Bauer committed
148
def read_config():
    """Build the effective configuration from platform defaults and the JSON configuration file.

    Platform specific compiler defaults are created first. If a configuration file is found
    (see ``get_configuration_file_path``) its entries override the defaults, otherwise the defaults
    are written to the per-user configuration file. Afterwards the object cache directory is
    prepared and, on Windows, the MSVC environment is merged into ``config['compiler']['env']``.

    :return: dictionary with the keys 'compiler' and 'cache'
    :raises NotImplementedError: if no default compiler configuration exists for the current platform
    """
    system = platform.system().lower()
    if system == 'linux':
        default_compiler_config = OrderedDict([
            ('os', 'linux'),
            ('command', 'g++'),
            ('llc_command', get_llc_command() or 'llc'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    elif system == 'windows':
        default_compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('llc_command', get_llc_command() or 'llc'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /openmp /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    elif system == 'darwin':
        default_compiler_config = OrderedDict([
            ('os', 'darwin'),
            ('command', 'clang++'),
            ('llc_command', get_llc_command() or 'llc'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -Xclang -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    else:
        # Previously this case failed later with an unbound local variable.
        raise NotImplementedError("pystencils has no default compiler configuration for platform '%s'"
                                  % (platform.system(),))

    default_cache_config = OrderedDict([
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        # No configuration file yet: persist the defaults so the user can edit them
        create_folder(config_path, True)
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

    if config['cache']['object_cache'] is not False:
        # The cache path may contain '~' and a '{pid}' placeholder
        config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())

        if config['cache']['clear_cache_on_start']:
            # clear_cache() cannot be used here: it reads the module-level _config,
            # which is not assigned yet while read_config() runs at import time.
            shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)

        create_folder(config['cache']['object_cache'], False)

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        if 'env' not in config['compiler']:
            config['compiler']['env'] = {}
        config['compiler']['env'].update(msvc_env)

    return config


Martin Bauer's avatar
Martin Bauer committed
211
# Module-wide configuration, read once when this module is imported; set_config() replaces it.
_config = read_config()
212
213


Martin Bauer's avatar
Martin Bauer committed
214
def get_compiler_config():
    """Return the 'compiler' section of the active configuration.

    The returned dictionary is the stored object itself, not a copy, so changes made to it
    affect subsequent compilations.
    """
    return _config['compiler']


Martin Bauer's avatar
Martin Bauer committed
218
def get_cache_config():
    """Return the 'cache' section of the active configuration.

    The returned dictionary is the stored object itself, not a copy.
    """
    return _config['cache']
Michael Kuron's avatar
Michael Kuron committed
220
221


222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def add_or_change_compiler_flags(flags):
    """Add flags to the compiler configuration, replacing earlier settings of the same option.

    A flag of the form ``key=value`` replaces every existing ``key`` / ``key=...`` entry.
    Any other flag only replaces an identical existing entry, so adding ``-I/a`` keeps ``-I/a/b``
    and adding ``-fopenmp`` keeps ``-fopenmp-simd``.

    As a side effect the object cache is disabled, since objects cached on disk were built
    with the old flags.

    :param flags: a single flag string, or a list / tuple of flag strings
    """
    if not isinstance(flags, (list, tuple)):
        flags = [flags]

    compiler_config = get_compiler_config()
    cache_config = get_cache_config()
    cache_config['object_cache'] = False  # disable cache

    for flag in flags:
        flag = flag.strip()
        key, separator, _ = flag.partition('=')
        key = key.strip()

        current_flags = compiler_config['flags'].split()
        if separator:
            # 'key=value': drop previous values of exactly this option, not of options sharing a prefix
            kept_flags = [c for c in current_flags if c != key and not c.startswith(key + '=')]
        else:
            kept_flags = [c for c in current_flags if c != flag]
        kept_flags.append(flag)
        compiler_config['flags'] = ' '.join(kept_flags)


Martin Bauer's avatar
Martin Bauer committed
242
243
def clear_cache():
    """Delete the object cache directory and recreate it empty.

    Does nothing if the object cache is disabled (``object_cache`` is False).
    """
    cache_config = get_cache_config()
    if cache_config['object_cache'] is not False:
        shutil.rmtree(cache_config['object_cache'], ignore_errors=True)
        create_folder(cache_config['object_cache'], False)
Martin Bauer's avatar
Martin Bauer committed
247
248


249
250
251
252
253
254
255
256
257
# Maps a numpy scalar type to (Python C-API extraction function(s), C/C++ type name).
# Complex types carry a pair of functions: one for the real and one for the imaginary part.
type_mapping = {
    np.float32: ('PyFloat_AsDouble', 'float'),
    np.float64: ('PyFloat_AsDouble', 'double'),
    np.int16: ('PyLong_AsLong', 'int16_t'),
    np.int32: ('PyLong_AsLong', 'int32_t'),
    np.int64: ('PyLong_AsLong', 'int64_t'),
    np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
    np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
    np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
    np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
    np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
Martin Bauer's avatar
Martin Bauer committed
261
262


263
264
265
266
267
268
# C snippet: fetch keyword argument "{name}" from the kwargs dict and convert it with
# {extract_function} to {target_type}. The generated code returns NULL if the argument is
# missing or a Python error is set after the conversion.
template_extract_scalar = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""
Martin Bauer's avatar
Martin Bauer committed
269

270
271
272
273
274
275
276
# C snippet: like the scalar extraction, but brace-initializes a {target_type} from the
# results of {extract_function_real} and {extract_function_imag}.
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ {extract_function_real}( obj_{name} ), {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""

277
278
279
280
# C snippet: fetch keyword argument "{name}" and request a writable buffer with stride and
# format information for it. The generated code returns NULL if the argument is missing or
# the buffer request fails.
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""

# C snippet: release the buffer variable named buffer_{name}.
template_release_buffer = """
PyBuffer_Release(&buffer_{name});
"""

# C wrapper function exposed to Python: requires keyword arguments, runs {pre_call_code}
# (argument extraction and checks), calls kernel_{func_name}, runs {post_call_code} and returns None.
template_function_boilerplate = """
static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
{{
    if( !kwargs || !PyDict_Check(kwargs) ) {{ 
        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
        return NULL; 
    }}
    {pre_call_code}
    kernel_{func_name}({parameters});
    {post_call_code}
    Py_RETURN_NONE;
}}
"""

# C snippet: set a ValueError naming the array and the violated property ({what})
# and return NULL unless {cond} holds.
template_check_array = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
    return NULL; 
}}
"""

# C snippet: set a TypeError and return NULL unless {cond} (the shape comparison) holds.
template_size_check = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
}}"""

# C code defining the extension module: the method table filled from {method_definitions},
# the module definition and the PyInit_{module_name} entry point.
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
    {method_definitions}
    {{NULL, NULL, 0, NULL}}
}};

static struct PyModuleDef module_definition = {{
    PyModuleDef_HEAD_INIT,
    "{module_name}",   /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    method_definitions
}};

PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
    return PyModule_Create(&module_definition);
}}
"""


def equal_size_check(fields):
    """Generate C code checking that all given fields have the same spatial shape.

    Every field is compared against the first one, in each of the first field's spatial dimensions.

    :param fields: iterable of fields; each needs the attributes ``name`` and ``spatial_dimensions``
    :return: C code raising a TypeError on mismatch, or an empty string for fewer than two fields
    """
    fields = list(fields)
    if len(fields) <= 1:
        return ""

    ref_field = fields[0]
    comparison = "(buffer_{field.name}.shape[{i}] == buffer_{ref_field.name}.shape[{i}])"
    cond = [comparison.format(ref_field=ref_field, field=field_to_test, i=i)
            for field_to_test in fields[1:]
            for i in range(ref_field.spatial_dimensions)]
    cond = " && ".join(cond)
    return template_size_check.format(cond=cond)


def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
    """Generate the C wrapper function that unpacks Python keyword arguments and calls the kernel.

    :param parameter_info: kernel parameters. Each one is handled as field pointer, field stride,
                           field shape or, if none of these, as scalar symbol.
    :param name: name of the wrapper function; the wrapped kernel is expected to be called ``kernel_<name>``
    :param insert_checks: if True, emit run-time checks for data type, item size, shape and strides
                          of the passed arrays, and an equal-shape check for variable-sized fields
    :return: C source code of the wrapper function
    """
    pre_call_code = ""
    parameters = []
    post_call_code = ""
    variable_sized_normal_fields = set()
    variable_sized_index_fields = set()

    for param in parameter_info:
        if param.is_field_pointer:
            field = param.fields[0]
            pre_call_code += template_extract_array.format(name=field.name)
            post_call_code += template_release_buffer.format(name=field.name)
            parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name))

            # NOTE(review): the checks below 'return NULL' in the generated code without releasing
            # buffers acquired before them - the release code only runs after the kernel call.
            if insert_checks:
                np_dtype = field.dtype.numpy_dtype
                item_size = np_dtype.itemsize

                # The format character is only compared for builtin, non-complex dtypes of generic fields
                if (np_dtype.isbuiltin and FieldType.is_generic(field)
                        and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
                    dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
                                                                                format=field.dtype.numpy_dtype.char)
                    pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
                                                                 expected=str(field.dtype.numpy_dtype))

                item_size_cond = "buffer_{name}.itemsize == {size}".format(name=field.name, size=item_size)
                pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name,
                                                             expected=item_size)

                if field.has_fixed_shape:
                    shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
                                  for i, s in enumerate(field.spatial_shape)]
                    shape_cond = " && ".join(shape_cond)
                    pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
                                                                 expected=str(field.shape))

                    # Buffer strides are in bytes; a stride mismatch is accepted where the extent is <= 1
                    expected_strides = [e * item_size for e in field.spatial_strides]
                    stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
                    strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
                                                for i, s in enumerate(expected_strides)])
                    pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
                                                                 expected=str(expected_strides))
                else:
                    # Variable-sized fields are collected for the equal-shape check emitted after the loop
                    if FieldType.is_generic(field):
                        variable_sized_normal_fields.add(field)
                    elif FieldType.is_indexed(field):
                        variable_sized_index_fields.add(field)
        elif param.is_field_stride:
            field = param.fields[0]
            item_size = field.dtype.numpy_dtype.itemsize
            # Buffer strides are given in bytes and are divided by the item size before being passed on
            parameters.append("buffer_{name}.strides[{i}] / {bytes}".format(bytes=item_size, i=param.symbol.coordinate,
                                                                            name=field.name))
        elif param.is_field_shape:
            parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
        else:
            extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
            if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
                pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
                                                                 extract_function_imag=extract_function[1],
                                                                 target_type=target_type,
                                                                 name=param.symbol.name)
            else:
                pre_call_code += template_extract_scalar.format(extract_function=extract_function,
                                                                target_type=target_type,
                                                                name=param.symbol.name)

            parameters.append(param.symbol.name)

    pre_call_code += equal_size_check(variable_sized_normal_fields)
    pre_call_code += equal_size_check(variable_sized_index_fields)

    pre_call_code = textwrap.indent(pre_call_code, '    ')
    post_call_code = textwrap.indent(post_call_code, '    ')
    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
                                                post_call_code=post_call_code, parameters=", ".join(parameters))
Martin Bauer's avatar
Martin Bauer committed
427
428


429
430
431
432
def create_module_boilerplate_code(module_name, names):
    """Generate the C code that defines the extension module and registers its functions.

    :param module_name: name of the generated module
    :param names: names of the wrapper functions; each becomes one entry of the method table
    :return: C source code of method table, module definition and module init function
    """
    entry_template = '{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
    table_entries = []
    for function_name in names:
        table_entries.append(entry_template.format(name=function_name))
    return template_module_boilerplate.format(method_definitions="\n".join(table_entries),
                                              module_name=module_name)
Martin Bauer's avatar
Martin Bauer committed
433

434

435
436
def load_kernel_from_file(module_name, function_name, path):
    """Import the module stored at ``path`` and return one of its attributes.

    If the import fails with an ImportError, a warning is issued and the import is retried
    once after waiting one second.

    :param module_name: name under which the module is loaded
    :param function_name: name of the attribute to fetch from the loaded module
    :param path: location of the module file
    :return: the attribute ``function_name`` of the loaded module
    """
    from importlib.util import spec_from_file_location, module_from_spec

    def load_module():
        spec = spec_from_file_location(name=module_name, location=path)
        mod = module_from_spec(spec)
        spec.loader.exec_module(mod)
        return mod

    try:
        mod = load_module()
    except ImportError:
        import time
        import warnings
        warnings.warn("Could not load " + path + ", trying on more time...")
        time.sleep(1)
        mod = load_module()

    return getattr(mod, function_name)
451

Martin Bauer's avatar
Martin Bauer committed
452

Martin Bauer's avatar
Martin Bauer committed
453
454
455
456
457
def run_compile_step(command):
    """Run one compiler or linker command.

    The command runs in a copy of the current process environment, extended by the optional
    'env' entry of the compiler configuration. On failure the command line and the captured
    output are printed before the error is re-raised.

    :param command: command and arguments as list of strings
    :raises subprocess.CalledProcessError: if the command exits with a non-zero status
    """
    compiler_config = get_compiler_config()
    compile_environment = os.environ.copy()
    compile_environment.update(compiler_config.get('env', {}))

    # A shell is only used on Windows
    shell = compiler_config['os'].lower() == 'windows'
    try:
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as e:
        print(" ".join(command))
        # errors='replace': compiler output that is not valid UTF-8 must not hide the actual failure
        print(e.output.decode('utf8', errors='replace'))
        raise
466
467


468
class ExtensionModuleCode:
    """Collects kernel ASTs and writes them out as source code of one Python extension module."""

    def __init__(self, module_name='generated', custom_backend=None):
        """
        :param module_name: name of the extension module to generate
        :param custom_backend: optional backend object passed on to ``generate_c``
        """
        self.module_name = module_name

        self._ast_nodes = []
        self._function_names = []
        self._custom_backend = custom_backend

    def add_function(self, ast, name=None):
        """Register a kernel.

        :param ast: abstract syntax tree of the kernel
        :param name: name of the function in the module, defaults to ``ast.function_name``
        """
        self._ast_nodes.append(ast)
        self._function_names.append(name if name is not None else ast.function_name)

    def write_to_file(self, restrict_qualifier, function_prefix, file):
        """Write the complete module source code to a file object.

        :param restrict_qualifier: text the macro RESTRICT is defined to
        :param function_prefix: text the macro FUNC_PREFIX is defined to
        :param file: file object opened for writing text
        """
        headers = {'<math.h>', '<stdint.h>'}
        for ast in self._ast_nodes:
            headers.update(get_headers(ast))
        # Python.h is placed in front of the sorted remaining headers
        header_list = sorted(headers)
        header_list.insert(0, '"Python.h"')

        includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
        print(includes, file=file)
        print("\n", file=file)
        print("#define RESTRICT %s" % (restrict_qualifier,), file=file)
        print("#define FUNC_PREFIX %s" % (function_prefix,), file=file)
        print("\n", file=file)

        for ast, name in zip(self._ast_nodes, self._function_names):
            # The kernel is temporarily renamed so that the wrapper can carry the public name
            old_name = ast.function_name
            ast.function_name = "kernel_" + name
            try:
                print(generate_c(ast, custom_backend=self._custom_backend), file=file)
                print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
            finally:
                # Restore the name even if code generation fails, the AST belongs to the caller
                ast.function_name = old_name
        print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
Martin Bauer's avatar
Martin Bauer committed
502
503


504
def compile_module(code, code_hash, base_dir):
    """Compile and link generated module code into a loadable extension module.

    Source, object and library file are all placed in ``base_dir`` and named after ``code_hash``.
    Compilation is skipped if the object file already exists. Linking is done after every
    compilation and also if the library file is missing.

    :param code: object with a ``write_to_file(restrict_qualifier, function_prefix, file)`` method
    :param code_hash: base name (without suffix) of all generated files
    :param base_dir: directory for the generated files
    :return: path of the library file
    """
    compiler_config = get_compiler_config()
    extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]

    if compiler_config['os'].lower() == 'windows':
        function_prefix = '__declspec(dllexport)'
        lib_suffix = '.pyd'
        object_suffix = '.obj'
        windows = True
    else:
        function_prefix = ''
        lib_suffix = '.so'
        object_suffix = '.o'
        windows = False

    src_file = os.path.join(base_dir, code_hash + ".cpp")
    lib_file = os.path.join(base_dir, code_hash + lib_suffix)
    object_file = os.path.join(base_dir, code_hash + object_suffix)

    needs_compile = not os.path.exists(object_file)
    if needs_compile:
        with file_handle_for_atomic_write(src_file) as f:
            code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)

        if windows:
            compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
            compile_cmd += [*extra_flags, src_file, '/Fo' + object_file]
            run_compile_step(compile_cmd)
        else:
            with atomic_file_write(object_file) as file_name:
                compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
                compile_cmd += [*extra_flags, '-o', file_name, src_file]
                run_compile_step(compile_cmd)

    # Linking. An existing object file does not imply an existing library (e.g. after a failed
    # link step), so a missing library is linked again instead of returning a dangling path.
    if needs_compile or not os.path.exists(lib_file):
        if windows:
            import sysconfig
            config_vars = sysconfig.get_config_vars()
            py_lib = os.path.join(config_vars["installed_base"], "libs",
                                  "python{}.lib".format(config_vars["py_version_nodot"]))
            run_compile_step(['link.exe', py_lib, '/DLL', '/out:' + lib_file, object_file])
        elif platform.system().lower() == 'darwin':
            with atomic_file_write(lib_file) as file_name:
                run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name, '-undefined',
                                  'dynamic_lookup']
                                 + compiler_config['flags'].split())
        else:
            with atomic_file_write(lib_file) as file_name:
                run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name]
                                 + compiler_config['flags'].split())
    return lib_file


556
def compile_and_load(ast, custom_backend=None):
    """Generate, compile and load the extension module for a kernel.

    The module is named after the SHA-256 hash of the generated C code. If the object cache
    is disabled, compilation takes place in a temporary directory that is removed after the
    kernel has been loaded; otherwise the cache directory is used.

    :param ast: abstract syntax tree of the kernel
    :param custom_backend: optional backend object passed on to ``generate_c``
    :return: ``KernelWrapper`` around the loaded kernel function
    """
    cache_config = get_cache_config()

    generated_code = generate_c(ast, dialect='c', custom_backend=custom_backend)
    code_hash_str = "mod_" + hashlib.sha256(generated_code.encode()).hexdigest()
    code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
    code.add_function(ast, ast.function_name)

    if cache_config['object_cache'] is False:
        # The module has to be loaded while the temporary directory still exists
        with TemporaryDirectory() as base_dir:
            lib_file = compile_module(code, code_hash_str, base_dir)
            result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
    else:
        lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'])
        result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)

    return KernelWrapper(result, ast.get_parameters(), ast)