cpujit.py 28.5 KB
Newer Older
1
r"""
2

Martin Bauer's avatar
Martin Bauer committed
3
4
5
6
7
8
*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux make sure that 'gcc' and 'g++' are installed and in your path.
On Windows a recent Visual Studio installation is required.
In case anything does not work as expected or a special compiler should be used, changes can be specified
in a configuration file.

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory for a file named ``pystencils.json``
3. or in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is created at the above mentioned location in your home.
So run *pystencils* once, then edit the created configuration file.


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space separated list of compiler flags. Make sure to activate OpenMP in your compiler
Martin Bauer's avatar
Martin Bauer committed
26
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
27
28
29
30
31
32
33
34
35
36
  For most Linux compilers the qualifier is ``__restrict__``


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
Martin Bauer's avatar
Martin Bauer committed
37
- **'msvc_version'**:  either a version number, year number, 'auto' or 'latest' for automatic detection of latest
38
39
  installed version or 'setuptools' for setuptools-based detection. Alternatively path to folder
  where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
40
41
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
Martin Bauer's avatar
Martin Bauer committed
42
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
43
44
45
  For Windows compilers the qualifier should be ``__restrict``

"""
46
import hashlib
47
import json
48
import os
Martin Bauer's avatar
Martin Bauer committed
49
import platform
50
import shutil
51
import subprocess
52
import textwrap
53
54
from collections import OrderedDict
from sysconfig import get_paths
55
from tempfile import TemporaryDirectory, NamedTemporaryFile
56

57
import numpy as np
58
from appdirs import user_cache_dir, user_config_dir
59
60

from pystencils import FieldType
61
from pystencils.astnodes import LoopOverCoordinate
Michael Kuron's avatar
Michael Kuron committed
62
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
63
from pystencils.data_types import cast_func, VectorType, vector_memory_access
64
from pystencils.include import get_pystencils_include_path
65
from pystencils.kernel_wrapper import KernelWrapper
Markus Holzer's avatar
Testing    
Markus Holzer committed
66
from pystencils.utils import atomic_file_write, recursive_dict_update
67

68

69
def make_python_function(kernel_function_node, custom_backend=None):
    """
    Compiles the C code generated from an abstract syntax tree and exposes it as a Python function

    The resulting kernel expects as keyword arguments:
        - one numpy array per field used in the kernel, passed under the name of the field
        - a value for every symbol that is not defined in the kernel itself

    :param kernel_function_node: the abstract syntax tree
    :param custom_backend: use own custom printer for code generation
    :return: kernel functor
    """
    return compile_and_load(kernel_function_node, custom_backend)
83
84


Martin Bauer's avatar
Martin Bauer committed
85
def set_config(config):
    """
    Override the configuration provided in config file

    Configuration of compiler parameters:
    If this function is not called the configuration is taken from a config file in JSON format which
    is searched in the following locations in the order specified:
        - at location provided in environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - the file "config.json" in the per-user *pystencils* configuration directory
    If none of these files exist, "config.json" is created in the per-user configuration directory
    with a platform-specific default configuration.

    An example JSON file. If not all keys are specified, default values are used
    ``
    {
        "compiler":
        {
            "command": "/software/intel/2017/bin/icpc",
            "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
            "env": {
                "LM_PROJECT": "iwia"
            }
        }
    }
    ``

    :param config: dictionary with the sections 'compiler' and 'cache'. It is copied deeply, so later
                   in-place changes of the active configuration (e.g. by add_or_change_compiler_flags)
                   do not modify the dictionary passed in here.
    """
    import copy

    global _config
    # a shallow copy would share the nested 'compiler' / 'cache' dictionaries with the caller
    _config = copy.deepcopy(config)


Martin Bauer's avatar
Martin Bauer committed
116
117
def get_configuration_file_path():
    """Determine which JSON configuration file should be used.

    The candidates are, in order of precedence:
        1. the path stored in the environment variable ``PYSTENCILS_CONFIG``
        2. ``pystencils.json`` in the current working directory
        3. ``config.json`` in the per-user *pystencils* configuration directory

    :return: tuple ``(path, exists)``. A path taken from the environment variable is reported as existing
             without being checked. If no candidate exists, the path of candidate 3 is returned together
             with ``False``.
    """
    config_path_in_home = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) Path given explicitly via environment variable
    path_from_env = os.environ.get('PYSTENCILS_CONFIG')
    if path_from_env is not None:
        return path_from_env, True

    # 2) pystencils.json in the current working directory
    local_config = "pystencils.json"
    if os.path.exists(local_config):
        return local_config, True

    # 3) config.json in the per-user configuration directory, which may not exist yet
    return config_path_in_home, os.path.exists(config_path_in_home)
130
131


Martin Bauer's avatar
Martin Bauer committed
132
133
def create_folder(path, is_file):
    """Create a directory including missing parent directories (best effort).

    :param path: directory to create, or path of a file whose containing directory should be created
    :param is_file: if True, ``path`` is treated as a file path and only its directory part is created
    """
    folder = os.path.dirname(path) if is_file else path
    try:
        os.makedirs(folder)
    except OSError:
        # raised when the folder already exists; every other OS error is ignored here as well
        pass


141
142
143
144
145
146
147
148
149
150
def get_llc_command():
    """Look up an executable of llvm's IR compiler llc

    The names llc, llc-10, llc-9, llc-8, llc-7 and llc-6 are probed in this order with ``shutil.which``.

    :return: the first name for which an executable was found, or None if there is none
    """
    for candidate in ('llc', 'llc-10', 'llc-9', 'llc-8', 'llc-7', 'llc-6'):
        if shutil.which(candidate):
            return candidate
    return None


Martin Bauer's avatar
Martin Bauer committed
151
def _default_compiler_config():
    """Return the default compiler configuration for the platform reported by ``platform.system()``.

    :raises ValueError: if the platform is neither linux, windows nor darwin
    """
    system = platform.system().lower()
    if system == 'linux':
        compiler_config = OrderedDict([
            ('os', 'linux'),
            ('command', 'g++'),
            ('llc_command', get_llc_command() or 'llc'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
        if platform.machine().startswith('ppc64'):
            # on ppc64 machines '-mcpu=native' is used in place of '-march=native'
            compiler_config['flags'] = compiler_config['flags'].replace('-march=native', '-mcpu=native')
    elif system == 'windows':
        compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('llc_command', get_llc_command() or 'llc'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /OpenMP /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    elif system == 'darwin':
        compiler_config = OrderedDict([
            ('os', 'darwin'),
            ('command', 'clang++'),
            ('llc_command', get_llc_command() or 'llc'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -Xclang -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
        if platform.machine() == 'arm64':
            # '-march=native' is dropped from the flags on arm64 machines
            compiler_config['flags'] = compiler_config['flags'].replace('-march=native ', '')
        # pass the first libomp library found at one of the known locations to the compiler
        for libomp in ['/opt/local/lib/libomp/libomp.dylib', '/usr/local/lib/libomp.dylib',
                       '/opt/homebrew/lib/libomp.dylib']:
            if os.path.exists(libomp):
                compiler_config['flags'] += ' ' + libomp
                break
    else:
        raise ValueError("The detection of the platform with platform.system() did not work. "
                         "Pystencils is only supported for linux, windows, and darwin platforms.")
    return compiler_config


def _compiler_config_changed(cache_status_file, compiler_config):
    """Check whether the compiler configuration differs from the one stored in ``cache_status_file``.

    :return: False if the status file does not exist or stores an equal configuration, True otherwise.
             A status file that cannot be parsed as JSON counts as a changed configuration.
    """
    if not os.path.exists(cache_status_file):
        return False
    try:
        with open(cache_status_file, 'r') as status_file:
            last_config = json.load(status_file)
    except ValueError:  # json.JSONDecodeError is a ValueError
        return True
    # plain dictionary comparison: also works for unhashable values such as the nested 'env' dictionary
    return last_config != dict(compiler_config)


def read_config():
    """Assemble the configuration from platform defaults and the JSON configuration file.

    Steps:
        - defaults for the sections 'compiler' and 'cache' are created
        - if a configuration file is found (see get_configuration_file_path) it is merged into the defaults
          with ``recursive_dict_update``, otherwise the defaults are written to the reported location
        - unless 'object_cache' is False: '~' and '{pid}' in the cache path are expanded, the cache folder is
          deleted if 'clear_cache_on_start' is set or the compiler configuration differs from the one stored
          in 'last_config.json' inside the cache folder, the folder is (re-)created and the current compiler
          configuration is stored in 'last_config.json'
        - on Windows the environment returned by the MSVC detection is added to config['compiler']['env']

    :return: dictionary with the sections 'compiler' and 'cache'
    :raises ValueError: if the platform is not supported
    """
    default_compiler_config = _default_compiler_config()

    default_cache_config = OrderedDict([
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        create_folder(config_path, True)
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

    if config['cache']['object_cache'] is not False:
        object_cache = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())
        config['cache']['object_cache'] = object_cache

        cache_status_file = os.path.join(object_cache, 'last_config.json')
        clear_cache_on_start = _compiler_config_changed(cache_status_file, config['compiler'])

        if config['cache']['clear_cache_on_start'] or clear_cache_on_start:
            shutil.rmtree(object_cache, ignore_errors=True)

        create_folder(object_cache, False)
        # write to a temporary file in the same folder first, then move it over the status file
        with NamedTemporaryFile('w', dir=os.path.dirname(cache_status_file), delete=False) as f:
            json.dump(config['compiler'], f, indent=4)
        os.replace(f.name, cache_status_file)

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        if 'env' not in config['compiler']:
            config['compiler']['env'] = {}
        config['compiler']['env'].update(msvc_env)

    return config


Martin Bauer's avatar
Martin Bauer committed
243
# Active configuration of this module: read once when the module is imported, replaced by set_config()
_config = read_config()
244
245


Martin Bauer's avatar
Martin Bauer committed
246
def get_compiler_config():
    """Return the 'compiler' section of the active configuration.

    The stored dictionary itself is returned, not a copy: changes made to it affect the active configuration.
    """
    compiler_section = _config['compiler']
    return compiler_section


Martin Bauer's avatar
Martin Bauer committed
250
def get_cache_config():
    """Return the 'cache' section of the active configuration.

    The stored dictionary itself is returned, not a copy: changes made to it affect the active configuration.
    """
    cache_section = _config['cache']
    return cache_section
Michael Kuron's avatar
Michael Kuron committed
252
253


254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def add_or_change_compiler_flags(flags):
    """Add compiler flags to the active configuration, replacing existing flags with the same prefix.

    For a flag of the form ``key=value`` every existing flag starting with ``key`` is removed, for any other
    flag every existing flag starting with the flag itself. The new flag is then appended.
    Calling this function also disables the object cache by setting 'object_cache' to False.

    :param flags: a single flag string, or a list / tuple of flag strings. Empty or whitespace-only
                  entries are skipped.
    """
    if not isinstance(flags, (list, tuple)):
        flags = [flags]

    compiler_config = get_compiler_config()
    cache_config = get_cache_config()
    cache_config['object_cache'] = False  # disable cache

    for flag in flags:
        flag = flag.strip()
        if not flag:
            # an empty prefix would match, and therefore remove, every existing flag
            continue
        # part before the first '=' (the whole flag if there is none); never an empty prefix
        base = flag.split('=')[0].strip() or flag

        new_flags = [c for c in compiler_config['flags'].split() if not c.startswith(base)]
        new_flags.append(flag)
        compiler_config['flags'] = ' '.join(new_flags)


Martin Bauer's avatar
Martin Bauer committed
274
275
def clear_cache():
    """Delete the object cache folder and create it again; does nothing if 'object_cache' is False."""
    object_cache = get_cache_config()['object_cache']
    if object_cache is False:
        return
    shutil.rmtree(object_cache, ignore_errors=True)
    create_folder(object_cache, False)
Martin Bauer's avatar
Martin Bauer committed
279
280


281
282
283
284
285
286
287
288
289
# Extraction functions for the real and imaginary part of a complex Python object
_complex_extract_functions = ('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble')

# Maps a numpy scalar type to a pair (Python C-API extraction function(s), C type name).
# For complex types the first entry is a pair of functions (real part, imaginary part).
type_mapping = {
    # floating point
    np.float32: ('PyFloat_AsDouble', 'float'),
    np.float64: ('PyFloat_AsDouble', 'double'),
    # signed integers
    np.int16: ('PyLong_AsLong', 'int16_t'),
    np.int32: ('PyLong_AsLong', 'int32_t'),
    np.int64: ('PyLong_AsLong', 'int64_t'),
    # unsigned integers
    np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
    np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
    np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
    # complex numbers
    np.complex64: (_complex_extract_functions, 'ComplexFloat'),
    np.complex128: (_complex_extract_functions, 'ComplexDouble'),
}
Martin Bauer's avatar
Martin Bauer committed
293

294
295
296
297
298
299
# C snippet (str.format template, doubled braces are literal braces) that looks up the keyword argument
# '{name}' in 'kwargs', converts it with {extract_function} and casts the result to {target_type}.
# The generated code returns NULL if the argument is missing or a Python error occurred.
template_extract_scalar = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""
Martin Bauer's avatar
Martin Bauer committed
300

301
302
303
# C snippet (str.format template, doubled braces are literal braces) that looks up the keyword argument
# '{name}' in 'kwargs' and brace-initializes a {target_type} from its real and imaginary part, extracted with
# {extract_function_real} / {extract_function_imag} and cast to {real_type}.
# The generated code returns NULL if the argument is missing or a Python error occurred.
template_extract_complex = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name}{{ ({real_type}) {extract_function_real}( obj_{name} ),
                       ({real_type}) {extract_function_imag}( obj_{name} ) }};
if( PyErr_Occurred() ) {{ return NULL; }}
"""

309
310
311
312
# C snippet (str.format template, doubled braces are literal braces) that looks up the keyword argument
# '{name}' in 'kwargs' and requests a writable buffer with stride and format information for it,
# stored in 'buffer_{name}'. The generated code returns NULL if the argument is missing or the request fails.
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""

# C snippet releasing the buffer 'buffer_{name}' obtained by template_extract_array
template_release_buffer = """
PyBuffer_Release(&buffer_{name});
"""

# C wrapper function (str.format template, doubled braces are literal braces) callable from Python:
# it rejects calls without keyword arguments, runs {pre_call_code}, calls 'kernel_{func_name}' with
# {parameters}, runs {post_call_code} and returns None.
template_function_boilerplate = """
static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
{{
    if( !kwargs || !PyDict_Check(kwargs) ) {{ 
        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
        return NULL; 
    }}
    {pre_call_code}
    kernel_{func_name}({parameters});
    {post_call_code}
    Py_RETURN_NONE;
}}
"""

# C snippet (str.format template, doubled braces are literal braces): if {cond} does not hold, a ValueError
# naming the checked property {what}, the array {name} and the {expected} value is set and NULL is returned.
template_check_array = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
    return NULL; 
}}
"""

# C snippet: if {cond} does not hold, a TypeError about differing array shapes is set and NULL is returned.
template_size_check = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
}}"""

# C code defining the extension module {module_name}: the method table filled with {method_definitions},
# the module definition structure and the module initialization function 'PyInit_{module_name}'.
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
    {method_definitions}
    {{NULL, NULL, 0, NULL}}
}};

static struct PyModuleDef module_definition = {{
    PyModuleDef_HEAD_INIT,
    "{module_name}",   /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    method_definitions
}};

PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
    return PyModule_Create(&module_definition);
}}
"""


def equal_size_check(fields):
    """Generate C code checking that the buffers of all given fields have the same spatial shape.

    :param fields: iterable of fields; the first one in iteration order is the reference every other
                   field is compared to, along each of the reference's spatial dimensions
    :return: C snippet built from ``template_size_check``, or an empty string for fewer than two fields
    """
    field_list = list(fields)
    if len(field_list) < 2:
        return ""

    reference, *others = field_list
    comparisons = []
    for other in others:
        for dim in range(reference.spatial_dimensions):
            comparisons.append(f"(buffer_{other.name}.shape[{dim}] == buffer_{reference.name}.shape[{dim}])")
    return template_size_check.format(cond=" && ".join(comparisons))


Markus Holzer's avatar
Markus Holzer committed
383
def create_function_boilerplate_code(parameter_info, name, ast_node, insert_checks=True):
    """Generate the C wrapper function that makes one kernel callable from Python.

    For every kernel parameter, code is generated that extracts the value from the keyword arguments
    (buffers for field pointers, scalars otherwise) and the expression passed on to the kernel is collected.
    The pieces are inserted into ``template_function_boilerplate``.

    :param parameter_info: parameters of the kernel, as returned by ``ast_node.get_parameters()``
    :param name: name of the wrapper function; the wrapped kernel is called as ``kernel_<name>``
    :param ast_node: kernel AST; its assignments, instruction set and ghost layers are used for the checks
    :param insert_checks: if True, code checking alignment offset, data type, item size, shape and strides
                          of the passed arrays is generated
    :return: C code of the wrapper function
    """
    pre_call_code = ""
    parameters = []
    post_call_code = ""
    # fields without fixed shape; the generated code checks that their shapes match each other
    variable_sized_normal_fields = set()
    variable_sized_index_fields = set()

    for param in parameter_info:
        if param.is_field_pointer:
            field = param.fields[0]
            # acquire the buffer before the kernel call and release it afterwards
            pre_call_code += template_extract_array.format(name=field.name)
            post_call_code += template_release_buffer.format(name=field.name)
            parameters.append(f"({str(field.dtype)} *)buffer_{field.name}.buf")

            if insert_checks:
                np_dtype = field.dtype.numpy_dtype
                item_size = np_dtype.itemsize

                # True if any assignment has a vector-typed cast_func on its left-hand side whose third
                # argument is truthy - presumably the 'aligned' flag of the vector cast, TODO confirm
                aligned = False
                if ast_node.assignments:
                    aligned = any([a.lhs.args[2] for a in ast_node.assignments
                                   if hasattr(a, 'lhs') and isinstance(a.lhs, cast_func)
                                   and hasattr(a.lhs, 'dtype') and isinstance(a.lhs.dtype, VectorType)])

                if ast_node.instruction_set and aligned:
                    byte_width = ast_node.instruction_set['width'] * item_size
                    if 'cachelineZero' in ast_node.instruction_set:
                        has_openmp, has_nontemporal = False, False
                        for loop in ast_node.atoms(LoopOverCoordinate):
                            has_openmp = has_openmp or any(['#pragma omp' in p for p in loop.prefix_lines])
                            has_nontemporal = has_nontemporal or any([a.args[0].field == field and a.args[3] for a in
                                                                      loop.atoms(vector_memory_access)])
                        # OpenMP loop combined with a flagged vector access to this field:
                        # alignment to the cacheline size is checked instead of the vector width
                        if has_openmp and has_nontemporal:
                            byte_width = ast_node.instruction_set['cachelineSize']
                    # byte offset of the position that has to be aligned, derived from the largest ghost layer count
                    offset = max(max(ast_node.ghost_layers)) * item_size
                    offset_cond = f"(((uintptr_t) buffer_{field.name}.buf) + {offset}) % ({byte_width}) == 0"

                    message = str(offset) + ". This is probably due to a different number of ghost_layers chosen for " \
                                            "the arrays and the kernel creation. If the number of ghost layers for " \
                                            "the kernel creation is not specified it will choose a suitable value " \
                                            "automatically. This value might not " \
                                            "be compatible with the allocated arrays."
                    # NOTE(review): byte_width is presumably only a non-int when it was replaced by the
                    # 'cachelineSize' entry above - confirm against the instruction set definitions
                    if type(byte_width) is not int:
                        message += " Note that when both OpenMP and non-temporal stores are enabled, alignment to the "\
                                   "cacheline size is required."
                    pre_call_code += template_check_array.format(cond=offset_cond, what="offset", name=field.name,
                                                                 expected=message)

                # data type check via the first character of the buffer format string;
                # skipped for non-builtin dtypes, non-generic fields and complex types
                if (np_dtype.isbuiltin and FieldType.is_generic(field)
                        and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
                    dtype_cond = f"buffer_{field.name}.format[0] == '{field.dtype.numpy_dtype.char}'"
                    pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
                                                                 expected=str(field.dtype.numpy_dtype))

                item_size_cond = f"buffer_{field.name}.itemsize == {item_size}"
                pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name,
                                                             expected=item_size)

                if field.has_fixed_shape:
                    shape_cond = [f"buffer_{field.name}.shape[{i}] == {s}"
                                  for i, s in enumerate(field.spatial_shape)]
                    shape_cond = " && ".join(shape_cond)
                    pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
                                                                 expected=str(field.shape))

                    # field strides are multiplied by the item size to compare them with the buffer strides;
                    # a stride mismatch is accepted along dimensions of extent <= 1
                    expected_strides = [e * item_size for e in field.spatial_strides]
                    stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
                    strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
                                                for i, s in enumerate(expected_strides)])
                    pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
                                                                 expected=str(expected_strides))
                else:
                    if FieldType.is_generic(field):
                        variable_sized_normal_fields.add(field)
                    elif FieldType.is_indexed(field):
                        variable_sized_index_fields.add(field)
        elif param.is_field_stride:
            field = param.fields[0]
            item_size = field.dtype.numpy_dtype.itemsize
            # the buffer stride is divided by the item size before it is passed to the kernel
            parameters.append(f"buffer_{field.name}.strides[{param.symbol.coordinate}] / {item_size}")
        elif param.is_field_shape:
            parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]")
        elif type(param.symbol) is CFunction:
            # CFunction symbols are not passed as arguments
            continue
        else:
            # remaining parameters are extracted from the keyword arguments as scalars
            extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
            if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating):
                pre_call_code += template_extract_complex.format(extract_function_real=extract_function[0],
                                                                 extract_function_imag=extract_function[1],
                                                                 target_type=target_type,
                                                                 real_type="float" if target_type == "ComplexFloat"
                                                                 else "double",
                                                                 name=param.symbol.name)
            else:
                pre_call_code += template_extract_scalar.format(extract_function=extract_function,
                                                                target_type=target_type,
                                                                name=param.symbol.name)

            parameters.append(param.symbol.name)

    pre_call_code += equal_size_check(variable_sized_normal_fields)
    pre_call_code += equal_size_check(variable_sized_index_fields)

    # indent the collected snippets to match the function body of the template
    pre_call_code = textwrap.indent(pre_call_code, '    ')
    post_call_code = textwrap.indent(post_call_code, '    ')
    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
                                                post_call_code=post_call_code, parameters=", ".join(parameters))
Martin Bauer's avatar
Martin Bauer committed
490
491


492
493
494
495
def create_module_boilerplate_code(module_name, names):
    """Generate the C code that defines the extension module and registers the wrapper functions.

    :param module_name: name inserted for the module and its ``PyInit_`` function
    :param names: names of the wrapper functions; one method table entry is generated per name
    :return: C code built from ``template_module_boilerplate``
    """
    table_entries = (f'{{"{function_name}", (PyCFunction){function_name}, METH_VARARGS | METH_KEYWORDS, ""}},'
                     for function_name in names)
    return template_module_boilerplate.format(module_name=module_name,
                                              method_definitions="\n".join(table_entries))
Martin Bauer's avatar
Martin Bauer committed
496

497

498
499
def load_kernel_from_file(module_name, function_name, path):
    """Import a module from a file and return one of its attributes.

    If the import raises an ImportError, a warning is issued and the import is tried a second time
    after waiting five seconds; an error of the second attempt is propagated.

    :param module_name: name under which the module is imported
    :param function_name: name of the attribute (the kernel function) to return
    :param path: path of the file to import
    :return: the attribute ``function_name`` of the imported module
    """
    from importlib.util import spec_from_file_location, module_from_spec

    def import_module():
        spec = spec_from_file_location(name=module_name, location=path)
        module = module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    try:
        mod = import_module()
    except ImportError:
        import time
        import warnings
        warnings.warn(f"Could not load {path}, trying one more time in 5 seconds ...")
        time.sleep(5)
        mod = import_module()

    return getattr(mod, function_name)
514

Martin Bauer's avatar
Martin Bauer committed
515

Martin Bauer's avatar
Martin Bauer committed
516
517
518
519
520
def run_compile_step(command):
    """Run one compiler or linker command.

    The command runs in a copy of the current environment, extended by the 'env' entry of the compiler
    configuration (if present). On Windows the command is executed through the shell.
    If the command fails, the command line and its combined stdout / stderr output are printed.

    :param command: command and its arguments as a list of strings
    :raises subprocess.CalledProcessError: if the command returns a non-zero exit code
    """
    compiler_config = get_compiler_config()
    compile_environment = os.environ.copy()
    compile_environment.update(compiler_config.get('env', {}))
    shell = compiler_config['os'].lower() == 'windows'
    try:
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as e:
        print(" ".join(command))
        # compiler output is not guaranteed to be valid UTF-8; a decoding error must not hide the compile error
        print(e.output.decode('utf8', errors='replace'))
        raise
528
529


530
class ExtensionModuleCode:
    """Collects kernel ASTs and generates the source code of one Python extension module containing them.

    Usage: register kernels with ``add_function``, call ``create_code_string`` once, then use
    ``get_hash_of_code`` and ``write_to_file``. The generated module is named after the hash of its code.
    """

    def __init__(self, module_name='generated', custom_backend=None):
        """
        :param module_name: stored as attribute ``module_name``
        :param custom_backend: custom printer passed on to ``generate_c``
        """
        self.module_name = module_name

        self._ast_nodes = []
        self._function_names = []
        self._custom_backend = custom_backend
        self._code_string = str()
        self._code_hash = None

    def add_function(self, ast, name=None):
        """Register a kernel.

        :param ast: abstract syntax tree of the kernel
        :param name: name of the Python-callable function, defaults to ``ast.function_name``
        """
        self._ast_nodes.append(ast)
        self._function_names.append(name if name is not None else ast.function_name)

    def create_code_string(self, restrict_qualifier, function_prefix):
        """Generate the code of the module for all registered kernels.

        If the generation fails, no code is stored and the function names of the ASTs are left unchanged.

        :param restrict_qualifier: value of the ``RESTRICT`` macro in the generated code
        :param function_prefix: value of the ``FUNC_PREFIX`` macro in the generated code
        """
        self._code_string = str()
        self._code_hash = None

        headers = {'<math.h>', '<stdint.h>'}
        for ast in self._ast_nodes:
            headers.update(get_headers(ast))
        header_list = sorted(headers)
        header_list.insert(0, '"Python.h"')

        # The content of included headers found in '../include' relative to this file enters the hash,
        # so a change in one of them results in a different module name.
        include_dir = os.path.join(os.path.dirname(__file__), '..', 'include')
        header_digests = []
        for header in header_list:
            header_path = os.path.join(include_dir, header[1:-1])  # strip the enclosing <> or ""
            if os.path.exists(header_path):
                with open(header_path, 'rb') as header_file:
                    header_digests.append(hashlib.sha256(header_file.read()).digest())
        header_hash = b''.join(header_digests)

        code_parts = [
            "\n".join(f"#include {include_file}" for include_file in header_list),
            "\n",
            f"#define RESTRICT {restrict_qualifier} \n",
            f"#define FUNC_PREFIX {function_prefix}",
            "\n",
        ]

        for ast, name in zip(self._ast_nodes, self._function_names):
            # the kernel is generated as 'kernel_<name>', the wrapper function calling it as '<name>'
            old_name = ast.function_name
            ast.function_name = "kernel_" + name
            try:
                code_parts.append(generate_c(ast, custom_backend=self._custom_backend))
                code_parts.append(create_function_boilerplate_code(ast.get_parameters(), name, ast))
            finally:
                # restore the name even if the code generation raised
                ast.function_name = old_name

        code = "".join(code_parts)
        # the hash covers everything except the module boilerplate, which itself contains the hash
        code_hash = "mod_" + hashlib.sha256(code.encode() + header_hash).hexdigest()
        code += create_module_boilerplate_code(code_hash, self._function_names)

        self._code_hash = code_hash
        self._code_string = code

    def get_hash_of_code(self):
        """Return the module name derived from the hash of the generated code."""
        assert self._code_string, "The code must be generated first"
        return self._code_hash

    def write_to_file(self, file):
        """Write the generated code, followed by a newline, to the open text file ``file``."""
        assert self._code_string, "The code must be generated first"
        print(self._code_string, file=file)
Martin Bauer's avatar
Martin Bauer committed
581
582


Markus Holzer's avatar
Markus Holzer committed
583
584
585
586
def compile_module(code, code_hash, base_dir, compile_flags=None):
    """Compile and link the generated extension module inside ``base_dir``.

    Source, object and library file are all named after ``code_hash`` and differ only in their suffix.
    The compile step is skipped when the object file already exists. The link step runs whenever the
    object file was just (re)built or the library file is missing, so that a left-over object file
    without a matching library (e.g. from a run that was interrupted between the two steps) is still
    linked instead of returning the path of a library that does not exist.

    Args:
        code: object providing ``write_to_file(file)``, used to write the C++ source file
        code_hash: base name (without suffix) of all files that are created
        base_dir: directory in which source, object and library file are placed
        compile_flags: optional list of additional flags passed to the compile step

    Returns:
        path of the library file (``.pyd`` on Windows, ``.so`` otherwise)
    """
    if compile_flags is None:
        compile_flags = []

    compiler_config = get_compiler_config()
    extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] + compile_flags
    config_flags = compiler_config['flags'].split()

    windows = compiler_config['os'].lower() == 'windows'
    if windows:
        lib_suffix = '.pyd'
        object_suffix = '.obj'
    else:
        lib_suffix = '.so'
        object_suffix = '.o'

    src_file = os.path.join(base_dir, code_hash + ".cpp")
    lib_file = os.path.join(base_dir, code_hash + lib_suffix)
    object_file = os.path.join(base_dir, code_hash + object_suffix)

    # Compiling
    needs_compile = not os.path.exists(object_file)
    if needs_compile:
        # mode 'x' fails if the file already exists; an existing source file is reused as it is
        try:
            with open(src_file, 'x') as f:
                code.write_to_file(f)
        except FileExistsError:
            pass

        if windows:
            compile_cmd = ['cl.exe', '/c', '/EHsc'] + config_flags
            compile_cmd += [*extra_flags, src_file, '/Fo' + object_file]
            run_compile_step(compile_cmd)
        else:
            with atomic_file_write(object_file) as file_name:
                compile_cmd = [compiler_config['command'], '-c'] + config_flags
                compile_cmd += [*extra_flags, '-o', file_name, src_file]
                run_compile_step(compile_cmd)

    # Linking
    # A fresh object file always gets relinked; an existing object file is linked only if its
    # library is missing.
    if needs_compile or not os.path.exists(lib_file):
        if windows:
            import sysconfig
            config_vars = sysconfig.get_config_vars()
            py_lib = os.path.join(config_vars["installed_base"], "libs",
                                  f"python{config_vars['py_version_nodot']}.lib")
            run_compile_step(['link.exe', py_lib, '/DLL', '/out:' + lib_file, object_file])
        else:
            with atomic_file_write(lib_file) as file_name:
                link_cmd = [compiler_config['command'], '-shared', object_file, '-o', file_name]
                if platform.system().lower() == 'darwin':
                    # on macOS the symbols of the Python interpreter are resolved at load time
                    link_cmd += ['-undefined', 'dynamic_lookup']
                run_compile_step(link_cmd + config_flags)

    return lib_file


639
def compile_and_load(ast, custom_backend=None):
    """Generate extension module code for ``ast``, compile it and wrap the loaded kernel.

    Args:
        ast: kernel AST; its ``function_name``, ``instruction_set`` and ``get_parameters()`` are used
        custom_backend: forwarded to ``ExtensionModuleCode``

    Returns:
        ``KernelWrapper`` built from the loaded kernel, the parameters of ``ast`` and ``ast`` itself

    If ``cache_config['object_cache']`` is ``False`` the module is compiled in a temporary directory
    and loaded before that directory is cleaned up; otherwise the configured cache entry is used
    as the build directory.
    """
    cache_config = get_cache_config()
    compiler_config = get_compiler_config()

    on_windows = compiler_config['os'].lower() == 'windows'
    function_prefix = '__declspec(dllexport)' if on_windows else ''

    code = ExtensionModuleCode(custom_backend=custom_backend)
    code.add_function(ast, ast.function_name)
    code.create_code_string(compiler_config['restrict_qualifier'], function_prefix)
    code_hash_str = code.get_hash_of_code()

    if ast.instruction_set and 'compile_flags' in ast.instruction_set:
        compile_flags = ast.instruction_set['compile_flags']
    else:
        compile_flags = []

    def build_and_load(base_dir):
        # compile into base_dir and load the kernel function from the resulting library
        lib_file = compile_module(code, code_hash_str, base_dir, compile_flags=compile_flags)
        return load_kernel_from_file(code_hash_str, ast.function_name, lib_file)

    object_cache = cache_config['object_cache']
    if object_cache is False:
        # loading has to happen inside the with-block, while the temporary directory still exists
        with TemporaryDirectory() as tmp_dir:
            result = build_and_load(tmp_dir)
    else:
        result = build_and_load(object_cache)

    return KernelWrapper(result, ast.get_parameters(), ast)