cpujit.py 23.4 KB
Newer Older
1
r"""
2

Martin Bauer's avatar
Martin Bauer committed
3
4
5
6
7
8
*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux make sure that 'gcc' and 'g++' are installed and in your path.
On Windows a recent Visual Studio installation is required.
In case anything does not work as expected or a special compiler should be used, changes can be specified
in a configuration file.

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory for a file named ``pystencils.json``
3. or in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is created at the above mentioned location in your home.
So run *pystencils* once, then edit the created configuration file.


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space separated list of compiler flags. Make sure to activate OpenMP in your compiler
Martin Bauer's avatar
Martin Bauer committed
26
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
27
28
29
30
31
32
33
34
35
36
  For most Linux compilers the qualifier is ``__restrict__``


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
Martin Bauer's avatar
Martin Bauer committed
37
- **'msvc_version'**:  either a version number, year number, 'auto' or 'latest' for automatic detection of latest
38
39
  installed version or 'setuptools' for setuptools-based detection. Alternatively path to folder
  where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
40
41
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
Martin Bauer's avatar
Martin Bauer committed
42
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
43
44
45
  For Windows compilers the qualifier should be ``__restrict``

"""
46
import hashlib
47
import json
48
import os
Martin Bauer's avatar
Martin Bauer committed
49
import platform
50
import shutil
51
import subprocess
52
import textwrap
53
54
from collections import OrderedDict
from sysconfig import get_paths
55
56
from tempfile import TemporaryDirectory

57
import numpy as np
58
from appdirs import user_cache_dir, user_config_dir
59
60

from pystencils import FieldType
61
from pystencils.backends.cbackend import generate_c, get_headers
62
from pystencils.include import get_pystencils_include_path
63
from pystencils.kernel_wrapper import KernelWrapper
Martin Bauer's avatar
Martin Bauer committed
64
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
65

66

67
def make_python_function(kernel_function_node, custom_backend=None):
    """
    Turn a kernel AST into a callable Python object.

    C code is generated from the abstract syntax tree, compiled and loaded; all of
    that work is delegated to ``compile_and_load``.

    The parameters of the kernel are:
        - numpy arrays for each field used in the kernel. The keyword argument name is the name of the field
        - all symbols which are not defined in the kernel itself are expected as parameters

    :param kernel_function_node: the abstract syntax tree
    :param custom_backend: passed through unchanged to ``compile_and_load``
    :return: kernel functor
    """
    return compile_and_load(kernel_function_node, custom_backend)
80
81


Martin Bauer's avatar
Martin Bauer committed
82
def set_config(config):
    """
    Override the configuration provided in the config file.

    If this function is not called, the configuration is taken from a config file in JSON format which
    is searched in the following locations, in the order specified:
        - at the location given in the environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - a file called "config.json" in the per-user configuration directory of 'pystencils'
    If none of these files exists, a file with a default configuration is created in the per-user
    configuration directory.

    An example JSON file. If not all keys are specified, default values are used::

        {
            "compiler":
            {
                "command": "/software/intel/2017/bin/icpc",
                "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
                "env": {
                    "LM_PROJECT": "iwia"
                }
            }
        }

    :param config: dictionary with the sections of the configuration (e.g. 'compiler', 'cache').
                   It is copied deeply, so later in-place changes to the active configuration
                   (e.g. by ``add_or_change_compiler_flags``) never leak back into the caller's object.
    """
    global _config
    import copy

    # A shallow copy would still share the nested 'compiler' / 'cache' dictionaries with the caller.
    _config = copy.deepcopy(config)


Martin Bauer's avatar
Martin Bauer committed
113
114
def get_configuration_file_path():
    """
    Determine which JSON configuration file should be used.

    :return: tuple ``(path, exists)``. ``exists`` is reported as True without any check when the
             path comes from the environment variable ``PYSTENCILS_CONFIG``; it is False only when
             the fallback location in the per-user configuration directory does not exist yet.
    """
    home_config = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) A path given in the environment variable takes precedence
    env_config = os.environ.get('PYSTENCILS_CONFIG')
    if env_config is not None:
        return env_config, True

    # 2) Next candidate: pystencils.json in the current working directory
    local_config = "pystencils.json"
    if os.path.exists(local_config):
        return local_config, True

    # 3) Finally config.json in the per-user configuration directory, existing or not
    return home_config, os.path.exists(home_config)
127
128


Martin Bauer's avatar
Martin Bauer committed
129
130
def create_folder(path, is_file):
    """
    Create a directory including all missing parents, ignoring any OS-level failure.

    :param path: directory to create, or a file path whose parent directory should be created
    :param is_file: if True, ``path`` names a file and only its directory part is created
    """
    folder = os.path.dirname(path) if is_file else path
    try:
        os.makedirs(folder)
    except OSError:
        # Best effort: an already existing folder (or any other OS error) is silently ignored
        pass


138
139
140
141
142
143
144
145
146
147
def get_llc_command():
    """Look up an executable of LLVM's IR compiler llc.

    The names llc, llc-10, llc-9, llc-8, llc-7 and llc-6 are tried in this order.

    :return: the first name that is found in PATH, or None if none of them is found
    """
    for executable in ('llc', 'llc-10', 'llc-9', 'llc-8', 'llc-7', 'llc-6'):
        if shutil.which(executable):
            return executable
    return None


Martin Bauer's avatar
Martin Bauer committed
148
def read_config():
    """
    Build the configuration from platform defaults and the JSON configuration file.

    Steps:
        1. choose default compiler settings for the current platform (Linux, Windows or macOS)
        2. if a configuration file exists, merge its content over the defaults; otherwise write
           the defaults to the fallback location reported by ``get_configuration_file_path``
        3. unless the object cache is disabled (``False``), expand ``~`` and the ``{pid}``
           placeholder in its path, optionally clear it, and make sure the folder exists
        4. on Windows, add the MSVC environment to the compiler's ``'env'`` entry

    :return: configuration dictionary with the sections ``'compiler'`` and ``'cache'``
    :raises NotImplementedError: if no default compiler configuration exists for the platform
    """
    system = platform.system().lower()
    if system == 'linux':
        default_compiler_config = OrderedDict([
            ('os', 'linux'),
            ('command', 'g++'),
            ('llc_command', get_llc_command() or 'llc'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    elif system == 'windows':
        default_compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('llc_command', get_llc_command() or 'llc'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /openmp /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    elif system == 'darwin':
        default_compiler_config = OrderedDict([
            ('os', 'darwin'),
            ('command', 'clang++'),
            ('llc_command', get_llc_command() or 'llc'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -Xclang -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    else:
        # Without this branch the code below failed with an UnboundLocalError
        raise NotImplementedError(f"No default compiler configuration available for platform "
                                  f"'{platform.system()}' (supported: Linux, Windows, Darwin)")

    default_cache_config = OrderedDict([
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        # First run: persist the defaults so that the user has a file to edit
        create_folder(config_path, True)
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

    if config['cache']['object_cache'] is not False:
        config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())

        if config['cache']['clear_cache_on_start']:
            # clear_cache() must not be used here: it reads the module-level configuration, which is
            # only assigned after this function has returned during import. Remove the folder of the
            # configuration being built instead; it is re-created right below.
            shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)

        create_folder(config['cache']['object_cache'], False)

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        if 'env' not in config['compiler']:
            config['compiler']['env'] = {}
        config['compiler']['env'].update(msvc_env)

    return config


Martin Bauer's avatar
Martin Bauer committed
211
# Active configuration of this module. It is read once when the module is imported;
# set_config() rebinds it, and get_compiler_config() / get_cache_config() hand out its sections.
_config = read_config()
212
213


Martin Bauer's avatar
Martin Bauer committed
214
def get_compiler_config():
    """Return the ``'compiler'`` section of the active configuration.

    The section itself is returned, not a copy, so changing it changes the active configuration.
    """
    active_config = _config
    return active_config['compiler']


Martin Bauer's avatar
Martin Bauer committed
218
def get_cache_config():
    """Return the ``'cache'`` section of the active configuration.

    The section itself is returned, not a copy, so changing it changes the active configuration.
    """
    active_config = _config
    return active_config['cache']
Michael Kuron's avatar
Michael Kuron committed
220
221


222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def add_or_change_compiler_flags(flags):
    """
    Add compiler flags to the active configuration, replacing existing flags with the same key.

    The key of a flag is the part before the first ``=`` (the whole flag if there is none), so
    ``-march=skylake`` replaces ``-march=native`` and ``-DFOO=1`` replaces ``-DFOO``. Flags that
    merely share a prefix with the new flag (e.g. ``-fopenmp-simd`` when adding ``-fopenmp``) are kept.
    Empty flags are ignored.

    As a side effect the object cache is disabled.

    :param flags: a single flag string, or a list/tuple of flag strings
    """
    if not isinstance(flags, (list, tuple)):
        flags = [flags]

    compiler_config = get_compiler_config()
    cache_config = get_cache_config()
    # Disable the cache; presumably because objects compiled with other flags must not be reused
    cache_config['object_cache'] = False

    if not flags:
        return

    current_flags = compiler_config['flags'].split()
    for flag in flags:
        flag = flag.strip()
        if not flag:
            # An empty flag has an empty key; with prefix matching it used to wipe all flags
            continue
        key = flag.split('=')[0].strip()
        # Compare keys exactly instead of by prefix, so unrelated flags survive
        current_flags = [existing for existing in current_flags if existing.split('=')[0] != key]
        current_flags.append(flag)
    compiler_config['flags'] = ' '.join(current_flags)


Martin Bauer's avatar
Martin Bauer committed
242
243
def clear_cache():
    """Delete the object cache folder with all its content and create it again, empty.

    Nothing happens when the object cache is disabled (its configured value is ``False``).
    """
    cache_dir = get_cache_config()['object_cache']
    if cache_dir is False:
        return
    shutil.rmtree(cache_dir, ignore_errors=True)
    create_folder(cache_dir, False)
Martin Bauer's avatar
Martin Bauer committed
247
248


249
250
251
252
253
254
255
256
257
# Maps a numpy scalar type to a pair (extract function, C type) used to convert scalar kernel
# arguments in the generated wrapper code:
#   - extract function: name of the CPython C-API function that converts the Python object;
#     for complex types a pair of names (real part, imaginary part)
#   - C type: name of the type the converted value is stored in
type_mapping = {
    np.float32: ('PyFloat_AsDouble', 'float'),
    np.float64: ('PyFloat_AsDouble', 'double'),
    np.int16: ('PyLong_AsLong', 'int16_t'),
    np.int32: ('PyLong_AsLong', 'int32_t'),
    np.int64: ('PyLong_AsLong', 'int64_t'),
    np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
    np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
    np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
    np.complex64: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexFloat'),
    np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
Martin Bauer's avatar
Martin Bauer committed
261
262


263
264
265
266
267
268
# C snippet (str.format template; doubled braces become literal braces) that reads the keyword
# argument {name} from the 'kwargs' dictionary, converts it with {extract_function} and stores it
# in a variable {name} of type {target_type}. A missing argument sets a TypeError; any pending
# Python error makes the generated function return NULL.
template_extract_scalar = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""
Martin Bauer's avatar
Martin Bauer committed
269

270
271
272
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Michael Kuron's avatar
Michael Kuron committed
273
274
{target_type} {name}{{ ({real_type}) {extract_function_real}( obj_{name} ),
                       ({real_type}) {extract_function_imag}( obj_{name} ) }};
275
276
277
if( PyErr_Occurred() ) {{ return NULL; }}
"""

278
279
280
281
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
282
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
283
284
285
286
287
288
289
290
291
292
if (buffer_{name}_res == -1) {{ return NULL; }}
"""

template_release_buffer = """
PyBuffer_Release(&buffer_{name});
"""

template_function_boilerplate = """
static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
{{
Martin Bauer's avatar
Martin Bauer committed
293
294
295
296
    if( !kwargs || !PyDict_Check(kwargs) ) {{ 
        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
        return NULL; 
    }}
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
    {pre_call_code}
    kernel_{func_name}({parameters});
    {post_call_code}
    Py_RETURN_NONE;
}}
"""

# C snippet for a check on an array argument: if {cond} does not hold, a ValueError naming the
# violated property {what}, the array {name} and the {expected} value is set and NULL is returned.
template_check_array = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
    return NULL; 
}}
"""

# C snippet for the shape comparison between arrays: if {cond} does not hold, a TypeError is set
# and NULL is returned. Note that this template does not end with a newline.
template_size_check = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
}}"""

# Module-level part of the generated extension: the method table filled with {method_definitions}
# (terminated by a NULL sentinel entry), the module definition and the init function
# 'PyInit_{module_name}' that creates the module.
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
    {method_definitions}
    {{NULL, NULL, 0, NULL}}
}};

static struct PyModuleDef module_definition = {{
    PyModuleDef_HEAD_INIT,
    "{module_name}",   /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    method_definitions
}};

PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
    return PyModule_Create(&module_definition);
}}
"""


def equal_size_check(fields):
    """
    Generate C code that verifies that the buffers of all given fields have the same shape.

    Every field after the first one is compared to the first one, axis by axis, over the
    number of spatial dimensions of the first field.

    :param fields: iterable of fields; each needs the attributes ``name`` and ``spatial_dimensions``
    :return: the filled-in size check template, or an empty string for fewer than two fields
    """
    fields = list(fields)
    if len(fields) < 2:
        return ""

    ref_field, *fields_to_test = fields
    conditions = []
    for field_to_test in fields_to_test:
        for axis in range(ref_field.spatial_dimensions):
            conditions.append(f"(buffer_{field_to_test.name}.shape[{axis}] == "
                              f"buffer_{ref_field.name}.shape[{axis}])")
    return template_size_check.format(cond=" && ".join(conditions))


def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
    """
    Generate the C wrapper function that makes the kernel ``kernel_<name>`` callable from Python.

    For every kernel parameter, code is generated that extracts the value from the keyword
    arguments (buffer views for field pointers, converted scalars for other symbols), and an
    expression for the corresponding argument of the kernel call is collected.

    :param parameter_info: iterable of kernel parameter descriptions (as obtained from the
                           ``get_parameters()`` method of the AST)
    :param name: name of the wrapper function; the wrapped kernel is called ``kernel_<name>``
    :param insert_checks: if True, checks of data type, item size, shape and strides of the passed
                          arrays are generated as well
    :return: C source code of the wrapper function

    NOTE(review): the generated checks return NULL directly while the buffer releases are only
    emitted into the code after the kernel call, so buffers acquired before a failing check
    appear not to be released.
    """
    pre_call_code = ""  # C code executed before the kernel call: extraction and checks
    parameters = []  # C expressions passed to the kernel, in parameter order
    post_call_code = ""  # C code executed after the kernel call: buffer releases
    variable_sized_normal_fields = set()
    variable_sized_index_fields = set()

    for param in parameter_info:
        if param.is_field_pointer:
            field = param.fields[0]
            pre_call_code += template_extract_array.format(name=field.name)
            post_call_code += template_release_buffer.format(name=field.name)
            parameters.append(f"({str(field.dtype)} *)buffer_{field.name}.buf")

            if insert_checks:
                np_dtype = field.dtype.numpy_dtype
                item_size = np_dtype.itemsize

                # The format character is only compared for builtin, non-complex dtypes of generic fields
                if (np_dtype.isbuiltin and FieldType.is_generic(field)
                        and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
                    dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
                                                                                format=field.dtype.numpy_dtype.char)
                    pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
                                                                 expected=str(field.dtype.numpy_dtype))

                item_size_cond = f"buffer_{field.name}.itemsize == {item_size}"
                pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name,
                                                             expected=item_size)

                if field.has_fixed_shape:
                    shape_cond = [f"buffer_{field.name}.shape[{i}] == {s}"
                                  for i, s in enumerate(field.spatial_shape)]
                    shape_cond = " && ".join(shape_cond)
                    pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
                                                                 expected=str(field.shape))

                    # Buffer strides are given in bytes, field strides are multiplied by the item size.
                    # The stride of an axis with at most one entry is not checked.
                    expected_strides = [e * item_size for e in field.spatial_strides]
                    stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
                    strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
                                                for i, s in enumerate(expected_strides)])
                    pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
                                                                 expected=str(expected_strides))
                else:
                    # Fields without fixed shape are collected, separately for generic and indexed
                    # fields, for the shape comparison generated after the loop
                    if FieldType.is_generic(field):
                        variable_sized_normal_fields.add(field)
                    elif FieldType.is_indexed(field):
                        variable_sized_index_fields.add(field)
        elif param.is_field_stride:
            field = param.fields[0]
            item_size = field.dtype.numpy_dtype.itemsize
            # The byte stride of the buffer is divided by the item size before it is passed on
            parameters.append("buffer_{name}.strides[{i}] / {bytes}".format(bytes=item_size, i=param.symbol.coordinate,
                                                                            name=field.name))
        elif param.is_field_shape:
            parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]")
        else:
            # Any other parameter is a scalar that is looked up by the name of its symbol;
            # dtypes without an entry in type_mapping raise a KeyError here
            extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
            if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
                pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
                                                                 extract_function_imag=extract_function[1],
                                                                 target_type=target_type,
                                                                 real_type="float" if target_type == "ComplexFloat"
                                                                           else "double",
                                                                 name=param.symbol.name)
            else:
                pre_call_code += template_extract_scalar.format(extract_function=extract_function,
                                                                target_type=target_type,
                                                                name=param.symbol.name)

            parameters.append(param.symbol.name)

    # Both sets stay empty when insert_checks is False, so no shape comparison is generated then
    pre_call_code += equal_size_check(variable_sized_normal_fields)
    pre_call_code += equal_size_check(variable_sized_index_fields)

    pre_call_code = textwrap.indent(pre_call_code, '    ')
    post_call_code = textwrap.indent(post_call_code, '    ')
    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
                                                post_call_code=post_call_code, parameters=", ".join(parameters))
Martin Bauer's avatar
Martin Bauer committed
430
431


432
433
434
435
def create_module_boilerplate_code(module_name, names):
    """
    Generate the C code that defines the extension module.

    :param module_name: name of the extension module
    :param names: names of the wrapper functions; each gets one entry in the method table
    :return: the filled-in module boilerplate template
    """
    table_entries = [f'{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
                     for name in names]
    return template_module_boilerplate.format(module_name=module_name,
                                              method_definitions="\n".join(table_entries))
Martin Bauer's avatar
Martin Bauer committed
436

437

438
439
def load_kernel_from_file(module_name, function_name, path):
    """
    Import the module stored at ``path`` and return one of its attributes.

    If the import fails with an ImportError, a warning is issued and the import is tried a
    second time after waiting for one second; a second failure is propagated.

    :param module_name: name under which the module is imported
    :param function_name: name of the attribute (the kernel function) to return
    :param path: location of the module file
    :return: the attribute ``function_name`` of the loaded module
    :raises ImportError: if the module cannot be loaded, also when no import spec can be
                         created for ``path``
    :raises AttributeError: if the module has no attribute ``function_name``
    """
    from importlib.util import module_from_spec, spec_from_file_location

    def load_module():
        spec = spec_from_file_location(name=module_name, location=path)
        if spec is None or spec.loader is None:
            # Fail with a meaningful error instead of an AttributeError on None
            raise ImportError(f"Cannot create an import spec for {path}", name=module_name, path=path)
        module = module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    try:
        mod = load_module()
    except ImportError:
        import time
        import warnings
        warnings.warn(f"Could not load {path}, trying one more time...")
        time.sleep(1)
        mod = load_module()

    return getattr(mod, function_name)
454

Martin Bauer's avatar
Martin Bauer committed
455

Martin Bauer's avatar
Martin Bauer committed
456
457
458
459
460
def run_compile_step(command):
    """
    Run one compiler or linker command.

    The command runs in a copy of the current environment, extended by the ``'env'`` entry of
    the compiler configuration if there is one. On Windows it is executed through the shell.
    If the command fails, the command line and its combined stdout/stderr output are printed
    before the error is re-raised.

    :param command: the command and its arguments as list of strings
    :raises subprocess.CalledProcessError: if the command exits with a non-zero status
    """
    compiler_config = get_compiler_config()
    compile_environment = os.environ.copy()
    compile_environment.update(compiler_config.get('env', {}))
    use_shell = compiler_config['os'].lower() == 'windows'
    try:
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=use_shell)
    except subprocess.CalledProcessError as e:
        print(" ".join(command))
        # Compiler output is not guaranteed to be valid UTF-8; a decoding error here
        # would hide the actual compile error
        print(e.output.decode('utf8', errors='replace'))
        raise
468
469


470
class ExtensionModuleCode:
    """Collects kernel ASTs and writes the C++ source of a Python extension module for them."""

    def __init__(self, module_name='generated', custom_backend=None):
        """
        :param module_name: name of the generated extension module
        :param custom_backend: passed on to ``generate_c`` when the kernels are written
        """
        self.module_name = module_name

        self._ast_nodes = []
        self._function_names = []
        self._custom_backend = custom_backend

    def add_function(self, ast, name=None):
        """
        Register a kernel that should become part of the module.

        :param ast: abstract syntax tree of the kernel
        :param name: name of the function in the module; defaults to ``ast.function_name``
        """
        self._ast_nodes.append(ast)
        self._function_names.append(name if name is not None else ast.function_name)

    def write_to_file(self, restrict_qualifier, function_prefix, file):
        """
        Write the complete source code of the module.

        The output consists of the includes, the definitions of the macros ``RESTRICT`` and
        ``FUNC_PREFIX``, for every registered kernel its code followed by its Python wrapper
        function, and finally the module definition.

        :param restrict_qualifier: text the macro ``RESTRICT`` is defined as
        :param function_prefix: text the macro ``FUNC_PREFIX`` is defined as
        :param file: file-like object the code is written to
        """
        headers = {'<math.h>', '<stdint.h>'}
        for ast in self._ast_nodes:
            headers.update(get_headers(ast))

        # Python.h is included first, all other headers follow in sorted order
        header_list = ['"Python.h"'] + sorted(headers)

        includes = "\n".join(f"#include {include_file}" for include_file in header_list)
        print(includes, file=file)
        print("\n", file=file)
        print(f"#define RESTRICT {restrict_qualifier}", file=file)
        print(f"#define FUNC_PREFIX {function_prefix}", file=file)
        print("\n", file=file)

        for ast, name in zip(self._ast_nodes, self._function_names):
            # The kernel is generated under the name the wrapper function calls. The AST is
            # renamed only temporarily; the original name is restored even if generation fails.
            old_name = ast.function_name
            ast.function_name = "kernel_" + name
            try:
                print(generate_c(ast, custom_backend=self._custom_backend), file=file)
                print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
            finally:
                ast.function_name = old_name
        print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
Martin Bauer's avatar
Martin Bauer committed
504
505


506
def compile_module(code, code_hash, base_dir):
    """
    Compile and link an extension module, unless its object file already exists.

    Source, object and library file are all placed in ``base_dir`` and named after ``code_hash``.

    NOTE(review): only the object file is checked. If it exists, neither compiling nor linking
    is done, and the library path is returned whether or not the library file exists.

    :param code: object whose ``write_to_file`` method writes the source code of the module
    :param code_hash: base name (without suffix) of all generated files
    :param base_dir: folder where the files are created
    :return: path of the library file
    """
    compiler_config = get_compiler_config()
    include_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]

    windows = compiler_config['os'].lower() == 'windows'
    if windows:
        function_prefix, lib_suffix, object_suffix = '__declspec(dllexport)', '.pyd', '.obj'
    else:
        function_prefix, lib_suffix, object_suffix = '', '.so', '.o'

    src_file, lib_file, object_file = (os.path.join(base_dir, code_hash + suffix)
                                       for suffix in (".cpp", lib_suffix, object_suffix))

    if os.path.exists(object_file):
        return lib_file

    with file_handle_for_atomic_write(src_file) as f:
        code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)

    user_flags = compiler_config['flags'].split()

    # Compiling
    if windows:
        run_compile_step(['cl.exe', '/c', '/EHsc', *user_flags, *include_flags, src_file, '/Fo' + object_file])
    else:
        with atomic_file_write(object_file) as file_name:
            run_compile_step([compiler_config['command'], '-c', *user_flags, *include_flags,
                              '-o', file_name, src_file])

    # Linking
    if windows:
        import sysconfig
        config_vars = sysconfig.get_config_vars()
        py_lib = os.path.join(config_vars["installed_base"], "libs",
                              f"python{config_vars['py_version_nodot']}.lib")
        run_compile_step(['link.exe', py_lib, '/DLL', '/out:' + lib_file, object_file])
    else:
        # On macOS two additional linker arguments are passed
        darwin_args = ['-undefined', 'dynamic_lookup'] if platform.system().lower() == 'darwin' else []
        with atomic_file_write(lib_file) as file_name:
            run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name,
                              *darwin_args, *user_flags])

    return lib_file


558
def compile_and_load(ast, custom_backend=None):
    """
    Compile the kernel given as AST into an extension module, load it and wrap the result.

    The module is named "mod_" followed by the SHA-256 hex digest of the C code generated for
    the kernel. It is built in the object cache folder, or in a temporary directory when the
    object cache is disabled (its configured value is ``False``).

    :param ast: abstract syntax tree of the kernel
    :param custom_backend: passed on to the code generation
    :return: ``KernelWrapper`` around the loaded kernel function
    """
    cache_config = get_cache_config()

    c_code = generate_c(ast, dialect='c', custom_backend=custom_backend)
    code_hash_str = "mod_" + hashlib.sha256(c_code.encode()).hexdigest()
    code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
    code.add_function(ast, ast.function_name)

    def build_and_load(directory):
        lib_file = compile_module(code, code_hash_str, directory)
        return load_kernel_from_file(code_hash_str, ast.function_name, lib_file)

    if cache_config['object_cache'] is not False:
        result = build_and_load(cache_config['object_cache'])
    else:
        # Without cache the module is built in a folder that is removed right after loading
        with TemporaryDirectory() as base_dir:
            result = build_and_load(base_dir)

    return KernelWrapper(result, ast.get_parameters(), ast)