cpujit.py 21.1 KB
Newer Older
1
r"""
2

Martin Bauer's avatar
Martin Bauer committed
3 4 5 6 7 8
*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux make sure that 'gcc' and 'g++' are installed and in your path.
On Windows a recent Visual Studio installation is required.
In case anything does not work as expected or a special compiler should be used, changes can be specified
in a configuration file.

9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory for a file named ``pystencils.json``
3. or in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is created at the above mentioned location in your home.
So run *pystencils* once, then edit the created configuration file.


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space separated list of compiler flags. Make sure to activate OpenMP in your compiler
Martin Bauer's avatar
Martin Bauer committed
26
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
27 28 29 30 31 32 33 34 35 36
  For most Linux compilers the qualifier is ``__restrict__``


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
Martin Bauer's avatar
Martin Bauer committed
37
- **'msvc_version'**:  either a version number, year number, 'auto' or 'latest' for automatic detection of latest
38 39
  installed version or 'setuptools' for setuptools-based detection. Alternatively path to folder
  where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
40 41
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
Martin Bauer's avatar
Martin Bauer committed
42
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
43 44 45
  For Windows compilers the qualifier should be ``__restrict``

"""
46
import hashlib
47
import json
Martin Bauer's avatar
Martin Bauer committed
48
import os
Martin Bauer's avatar
Martin Bauer committed
49
import platform
50
import shutil
Martin Bauer's avatar
Martin Bauer committed
51
import subprocess
52
import textwrap
Martin Bauer's avatar
Martin Bauer committed
53 54
from collections import OrderedDict
from sysconfig import get_paths
55 56
from tempfile import TemporaryDirectory

57
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
58
from appdirs import user_cache_dir, user_config_dir
59

60
from pystencils.backends.cbackend import generate_c, get_headers
Martin Bauer's avatar
Martin Bauer committed
61
from pystencils.field import FieldType
62
from pystencils.include import get_pystencils_include_path
Martin Bauer's avatar
Martin Bauer committed
63
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
64

65

Martin Bauer's avatar
Martin Bauer committed
66
def make_python_function(kernel_function_node):
    """
    Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function

    The parameters of the kernel are:
        - numpy arrays for each field used in the kernel. The keyword argument name is the name of the field
        - all symbols which are not defined in the kernel itself are expected as parameters

    :param kernel_function_node: the abstract syntax tree
    :return: kernel functor
    """
    # Thin public entry point; all the work happens in compile_and_load
    return compile_and_load(kernel_function_node)
79 80


Martin Bauer's avatar
Martin Bauer committed
81
def set_config(config):
    """
    Override the configuration provided in config file

    If this function is not called, the configuration is taken from a config file in JSON format which
    is searched in the following locations in the order specified:
        - at the location provided in the environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - a file called "config.json" in the per-user configuration directory of pystencils
    If none of these files exists, a default configuration is written to the per-user location.

    The passed dictionary replaces the active configuration as a whole (a shallow copy is stored);
    it is NOT merged with the default values. It therefore has to contain every key that is read later on,
    in particular the two top-level sections 'compiler' and 'cache'.

    Example of the 'compiler' section as it would appear in a JSON config file
    ``
    {
        "compiler":
        {
            "command": "/software/intel/2017/bin/icpc",
            "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
            "env": {
                "LM_PROJECT": "iwia"
            }
        }
    }
    ``

    :param config: dictionary with the complete configuration
    """
    global _config
    # shallow copy: nested dictionaries are still shared with the caller
    _config = config.copy()


Martin Bauer's avatar
Martin Bauer committed
112 113
def get_configuration_file_path():
    """
    Determines which configuration file should be used.

    :return: tuple ``(path, exists)``. ``exists`` is False only if neither the environment variable
             ``PYSTENCILS_CONFIG`` is set nor any candidate file was found; ``path`` is then the per-user
             location. A path taken from the environment variable is reported as existing without being checked.
    """
    config_path_in_home = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) Read path from environment variable if found
    if 'PYSTENCILS_CONFIG' in os.environ:
        return os.environ['PYSTENCILS_CONFIG'], True
    # 2) Look in current directory for pystencils.json
    if os.path.exists("pystencils.json"):
        return "pystencils.json", True
    # 3) Fall back to config.json in the per-user configuration directory, reporting whether it exists
    return config_path_in_home, os.path.exists(config_path_in_home)
126 127


Martin Bauer's avatar
Martin Bauer committed
128 129
def create_folder(path, is_file):
    """
    Creates a folder including all missing parent folders. This is a best-effort operation:
    any OS error (e.g. the folder already exists) is ignored.

    :param path: path of the folder, or path of a file whose containing folder should be created
    :param is_file: if True, ``path`` is treated as a file path and only its directory part is created
    """
    if is_file:
        path = os.path.dirname(path)
    try:
        os.makedirs(path)
    except OSError:
        # deliberately ignored - callers fail later with a clearer error if the folder is really unusable
        pass


Martin Bauer's avatar
Martin Bauer committed
137
def read_config():
    """
    Builds the configuration dictionary from platform dependent defaults and the user's config file.

    If a configuration file is found (see get_configuration_file_path) its content is merged into the
    defaults, otherwise the defaults are written to the per-user configuration file.
    Afterwards the object cache folder is prepared and, on Windows, the MSVC environment is added to
    the compiler environment.

    :return: dictionary with the sections 'compiler' and 'cache'
    :raises NotImplementedError: if no default compiler configuration is known for the current platform
    """
    system = platform.system().lower()
    if system == 'linux':
        default_compiler_config = OrderedDict([
            ('os', 'linux'),
            ('command', 'g++'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    elif system == 'windows':
        default_compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /openmp /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    elif system == 'darwin':
        default_compiler_config = OrderedDict([
            ('os', 'darwin'),
            ('command', 'clang++'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])
    else:
        # previously this case failed further down with an unrelated "unbound local variable" error
        raise NotImplementedError("No default compiler configuration available for platform '%s'" % (system,))

    default_cache_config = OrderedDict([
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        create_folder(config_path, True)
        # context manager makes sure the freshly written file is flushed and closed
        with open(config_path, 'w') as json_config_file:
            json.dump(config, json_config_file, indent=4)

    if config['cache']['object_cache'] is not False:
        # the configured path may contain '~' and a '{pid}' placeholder
        config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())

        if config['cache']['clear_cache_on_start']:
            # clear_cache() can not be used here: it reads the module-level configuration,
            # which is not assigned yet while this function runs at import time
            shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)

        create_folder(config['cache']['object_cache'], False)

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        if 'env' not in config['compiler']:
            config['compiler']['env'] = {}
        config['compiler']['env'].update(msvc_env)

    return config


Martin Bauer's avatar
Martin Bauer committed
196
# Active configuration of this module. It is read once when the module is imported and can be
# replaced afterwards with set_config(); get_compiler_config() and get_cache_config() read from it.
_config = read_config()
197 198


Martin Bauer's avatar
Martin Bauer committed
199
def get_compiler_config():
    """Returns the 'compiler' section of the active configuration (the dictionary itself, not a copy)."""
    return _config['compiler']


Martin Bauer's avatar
Martin Bauer committed
203
def get_cache_config():
    """Returns the 'cache' section of the active configuration (the dictionary itself, not a copy)."""
    return _config['cache']
Michael Kuron's avatar
Michael Kuron committed
205 206


207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
def add_or_change_compiler_flags(flags):
    """
    Adds flags to the compiler flags of the active configuration, replacing existing flags
    that start with the same option name.

    For a flag of the form ``name=value`` the option name is the part in front of the first '=',
    otherwise it is the complete flag. As a side effect the object cache is switched off.

    :param flags: a single flag as string, or a list/tuple of flags
    """
    if not isinstance(flags, (list, tuple)):
        flags = [flags]

    compiler_config = get_compiler_config()
    cache_config = get_cache_config()
    # The cache key is computed from the generated C code only, not from the compiler flags,
    # so cached binaries built with the old flags would be picked up again.
    cache_config['object_cache'] = False  # disable cache

    for flag in flags:
        flag = flag.strip()
        # '-march=native' -> '-march'; a flag without '=' is its own base
        base = flag.split('=')[0].strip()

        # NOTE(review): this is a prefix match, so e.g. adding '-fopenmp' also drops '-fopenmp-simd' - confirm intended
        new_flags = [c for c in compiler_config['flags'].split() if not c.startswith(base)]
        new_flags.append(flag)
        compiler_config['flags'] = ' '.join(new_flags)


Martin Bauer's avatar
Martin Bauer committed
227 228
def clear_cache():
    """Deletes the object cache folder with all its content and re-creates it empty (no-op if caching is off)."""
    cache_config = get_cache_config()
    if cache_config['object_cache'] is not False:
        # errors while deleting are ignored, e.g. if the folder does not exist yet
        shutil.rmtree(cache_config['object_cache'], ignore_errors=True)
        create_folder(cache_config['object_cache'], False)
Martin Bauer's avatar
Martin Bauer committed
232 233


234 235 236 237 238 239 240 241 242 243
# Lookup table for scalar kernel arguments:
#   numpy scalar type -> (CPython C-API function converting the Python object, C type the result is cast to)
# Entries are grouped by the conversion function they share.
type_mapping = {
    numpy_type: (extract_function, c_type)
    for extract_function, type_pairs in (
        ('PyFloat_AsDouble', ((np.float32, 'float'), (np.float64, 'double'))),
        ('PyLong_AsLong', ((np.int16, 'int16_t'), (np.int32, 'int32_t'), (np.int64, 'int64_t'))),
        ('PyLong_AsUnsignedLong', ((np.uint16, 'uint16_t'), (np.uint32, 'uint32_t'), (np.uint64, 'uint64_t'))),
    )
    for numpy_type, c_type in type_pairs
}
Martin Bauer's avatar
Martin Bauer committed
244 245


246 247 248 249 250 251
# The following strings are str.format templates for the C source of the generated extension module.
# Literal braces of the C code are therefore written doubled ('{{' and '}}').

# Reads a scalar keyword argument and converts it to the C type of the kernel parameter
template_extract_scalar = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""

# Reads an array keyword argument and requests a writable, strided buffer view with format information
template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""

# Counterpart of template_extract_array, emitted after the kernel call
template_release_buffer = """
PyBuffer_Release(&buffer_{name});
"""

# Python-callable wrapper around one kernel: argument extraction, kernel call, cleanup
template_function_boilerplate = """
static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
{{
    if( !kwargs || !PyDict_Check(kwargs) ) {{ 
        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
        return NULL; 
    }}
    {pre_call_code}
    kernel_{func_name}({parameters});
    {post_call_code}
    Py_RETURN_NONE;
}}
"""

# Generic check of a property of a passed array, raises ValueError on the Python side if violated
template_check_array = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
    return NULL; 
}}
"""

# Check that several arrays have the same shape
template_size_check = """
if(!({cond})) {{ 
    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
}}"""

# Method table, module definition and init function of the extension module
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
    {method_definitions}
    {{NULL, NULL, 0, NULL}}
}};

static struct PyModuleDef module_definition = {{
    PyModuleDef_HEAD_INIT,
    "{module_name}",   /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    method_definitions
}};

PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
    return PyModule_Create(&module_definition);
}}
"""


def equal_size_check(fields):
    """
    Generates C code checking that the buffers of all passed fields have the same spatial shape.

    The first field serves as reference, every other field is compared against it in each of the
    reference field's spatial dimensions.

    :param fields: iterable of fields (objects with ``name`` and ``spatial_dimensions``)
    :return: C code as string, or an empty string if there is nothing to compare
    """
    fields = list(fields)
    if len(fields) <= 1:
        return ""

    ref_field = fields[0]
    cond = ["(buffer_{field.name}.shape[{i}] == buffer_{ref_field.name}.shape[{i}])".format(ref_field=ref_field,
                                                                                            field=field_to_test, i=i)
            for field_to_test in fields[1:]
            for i in range(ref_field.spatial_dimensions)]
    if not cond:
        # no spatial dimensions: an empty condition would yield the invalid C expression 'if(!())'
        return ""
    cond = " && ".join(cond)
    return template_size_check.format(cond=cond)


def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
    """
    Generates the C wrapper function that makes one kernel callable from Python.

    For every kernel parameter the code to extract it from the keyword arguments is generated:
    field pointers are taken from the buffer of the passed array, strides and shapes are read from
    that buffer, all remaining parameters are converted as scalars according to ``type_mapping``.

    :param parameter_info: parameters of the kernel, as returned by ``get_parameters()`` of the kernel AST
    :param name: name of the wrapper function; the wrapped kernel is expected to be called ``kernel_<name>``
    :param insert_checks: if True, code checking data type, item size, shape and strides of the passed
                          arrays is generated as well
    :return: C code of the wrapper function as string
    """
    # NOTE(review): the generated checks 'return NULL' without releasing buffers acquired before them;
    #               PyBuffer_Release is only emitted after the kernel call
    pre_call_code = ""
    parameters = []
    post_call_code = ""
    variable_sized_normal_fields = set()
    variable_sized_index_fields = set()

    for param in parameter_info:
        if param.is_field_pointer:
            field = param.fields[0]
            pre_call_code += template_extract_array.format(name=field.name)
            post_call_code += template_release_buffer.format(name=field.name)
            parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name))

            if insert_checks:
                np_dtype = field.dtype.numpy_dtype
                item_size = np_dtype.itemsize

                # the format character is compared only for builtin numpy types of generic fields
                if np_dtype.isbuiltin and FieldType.is_generic(field):
                    dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
                                                                                format=field.dtype.numpy_dtype.char)
                    pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
                                                                 expected=str(field.dtype.numpy_dtype))

                item_size_cond = "buffer_{name}.itemsize == {size}".format(name=field.name, size=item_size)
                pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name,
                                                             expected=item_size)

                if field.has_fixed_shape:
                    shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
                                  for i, s in enumerate(field.spatial_shape)]
                    shape_cond = " && ".join(shape_cond)
                    pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
                                                                 expected=str(field.shape))

                    # buffer strides are given in bytes, field strides in elements
                    expected_strides = [e * item_size for e in field.spatial_strides]
                    # the stride of a dimension with at most one entry is never used, so any value is accepted
                    stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
                    strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
                                                for i, s in enumerate(expected_strides)])
                    pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
                                                                 expected=str(expected_strides))
                else:
                    # shapes of variable sized fields are checked against each other further down
                    if FieldType.is_generic(field):
                        variable_sized_normal_fields.add(field)
                    elif FieldType.is_indexed(field):
                        variable_sized_index_fields.add(field)
        elif param.is_field_stride:
            field = param.fields[0]
            item_size = field.dtype.numpy_dtype.itemsize
            # convert the stride from bytes to number of elements
            parameters.append("buffer_{name}.strides[{i}] / {bytes}".format(bytes=item_size, i=param.symbol.coordinate,
                                                                            name=field.name))
        elif param.is_field_shape:
            parameters.append("buffer_{name}.shape[{i}]".format(i=param.symbol.coordinate, name=param.field_name))
        else:
            extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
            pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
                                                            name=param.symbol.name)
            parameters.append(param.symbol.name)

    pre_call_code += equal_size_check(variable_sized_normal_fields)
    pre_call_code += equal_size_check(variable_sized_index_fields)

    pre_call_code = textwrap.indent(pre_call_code, '    ')
    post_call_code = textwrap.indent(post_call_code, '    ')
    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
                                                post_call_code=post_call_code, parameters=", ".join(parameters))
Martin Bauer's avatar
Martin Bauer committed
394 395


396 397 398 399
def create_module_boilerplate_code(module_name, names):
    """
    Generates method table, module definition and init function of the extension module.

    :param module_name: name of the extension module
    :param names: names of the wrapper functions that the module should export
    :return: C code as string
    """
    # braces are doubled because the line is itself a format template
    method_definition = '{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
    method_definitions = "\n".join([method_definition.format(name=name) for name in names])
    return template_module_boilerplate.format(module_name=module_name, method_definitions=method_definitions)
Martin Bauer's avatar
Martin Bauer committed
400

401

402 403
def load_kernel_from_file(module_name, function_name, path):
    """
    Imports a module from a file and returns one attribute of it.

    If importing fails with an ImportError a warning is issued and the import is retried once
    after waiting one second; a second failure is propagated to the caller.

    :param module_name: name under which the module is imported
    :param function_name: name of the attribute (the kernel function) to return
    :param path: path to the module file
    :return: the attribute ``function_name`` of the loaded module
    """
    from importlib.util import spec_from_file_location, module_from_spec

    def load_module():
        spec = spec_from_file_location(name=module_name, location=path)
        mod = module_from_spec(spec)
        spec.loader.exec_module(mod)
        return mod

    try:
        mod = load_module()
    except ImportError:
        import time
        import warnings
        warnings.warn("Could not load " + path + ", trying one more time...")
        time.sleep(1)
        mod = load_module()

    return getattr(mod, function_name)
418

Martin Bauer's avatar
Martin Bauer committed
419

Martin Bauer's avatar
Martin Bauer committed
420 421 422 423 424
def run_compile_step(command):
    """
    Runs a compiler or linker command in the environment given by the compiler configuration.

    The environment is the current process environment, extended/overridden by the optional 'env'
    entry of the compiler configuration. If the command fails, the command line and the captured
    output are printed before the error is re-raised.

    :param command: command and its arguments as list of strings
    :raises subprocess.CalledProcessError: if the command returns a non-zero exit code
    """
    compiler_config = get_compiler_config()
    config_env = compiler_config.get('env', {})
    compile_environment = os.environ.copy()
    compile_environment.update(config_env)

    # on Windows the command is executed through the shell
    shell = compiler_config['os'].lower() == 'windows'
    try:
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as e:
        print(" ".join(command))
        # compiler output is not guaranteed to be valid UTF-8; a decoding error here
        # would replace the actual compile error with a UnicodeDecodeError
        print(e.output.decode('utf8', errors='replace'))
        raise
433 434


435 436 437
class ExtensionModuleCode:
    """Collects kernels and writes the C source of a Python extension module containing them."""

    def __init__(self, module_name='generated'):
        """
        :param module_name: name of the extension module, used in the module definition and init function
        """
        self.module_name = module_name

        # kernel ASTs and the names under which they are exported; both lists are kept in sync
        self._ast_nodes = []
        self._function_names = []

    def add_function(self, ast, name=None):
        """
        Registers a kernel for the module.

        :param ast: abstract syntax tree of the kernel
        :param name: name of the Python-callable function, defaults to ``ast.function_name``
        """
        self._ast_nodes.append(ast)
        self._function_names.append(name if name is not None else ast.function_name)

    def write_to_file(self, restrict_qualifier, function_prefix, file):
        """
        Writes the complete C source of the module.

        :param restrict_qualifier: value of the ``RESTRICT`` macro
        :param function_prefix: value of the ``FUNC_PREFIX`` macro
        :param file: file-like object the code is written to
        """
        headers = {'<math.h>', '<stdint.h>'}
        for ast in self._ast_nodes:
            headers.update(get_headers(ast))
        # sorted for a reproducible output; Python.h is placed first
        header_list = sorted(headers)
        header_list.insert(0, '"Python.h"')

        includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
        print(includes, file=file)
        print("\n", file=file)
        print("#define RESTRICT %s" % (restrict_qualifier,), file=file)
        print("#define FUNC_PREFIX %s" % (function_prefix,), file=file)
        print("\n", file=file)

        for ast, name in zip(self._ast_nodes, self._function_names):
            # The C kernel is emitted as 'kernel_<name>', the Python wrapper gets the plain name.
            # The AST is renamed only temporarily; finally restores it even if code generation fails.
            old_name = ast.function_name
            ast.function_name = "kernel_" + name
            try:
                print(generate_c(ast), file=file)
                print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
            finally:
                ast.function_name = old_name
        print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
Martin Bauer's avatar
Martin Bauer committed
468 469


470 471 472 473 474
class KernelWrapper:
    """Callable that bundles a compiled kernel with its parameter description and its AST."""

    def __init__(self, kernel, parameters, ast_node):
        """
        :param kernel: the compiled kernel, a callable accepting keyword arguments
        :param parameters: parameters of the kernel
        :param ast_node: abstract syntax tree the kernel was generated from
        """
        self.kernel = kernel
        self.parameters = parameters
        self.ast = ast_node

    def __call__(self, **kwargs):
        """Calls the kernel; only keyword arguments are accepted and passed on unchanged."""
        return self.kernel(**kwargs)
478

479

480
def compile_module(code, code_hash, base_dir):
    """
    Compiles and links an extension module, re-using files that already exist in ``base_dir``.

    Source, object and library file are all named after ``code_hash``. Compilation is skipped if the
    object file exists, linking is skipped if additionally the library file exists.

    :param code: object with a ``write_to_file(restrict_qualifier, function_prefix, file)`` method
                 that writes the C source of the module
    :param code_hash: base name (without extension) of all generated files
    :param base_dir: folder where the files are placed
    :return: path to the shared library
    """
    compiler_config = get_compiler_config()
    extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]

    if compiler_config['os'].lower() == 'windows':
        function_prefix = '__declspec(dllexport)'
        lib_suffix = '.pyd'
        object_suffix = '.obj'
        windows = True
    else:
        function_prefix = ''
        lib_suffix = '.so'
        object_suffix = '.o'
        windows = False

    src_file = os.path.join(base_dir, code_hash + ".cpp")
    lib_file = os.path.join(base_dir, code_hash + lib_suffix)
    object_file = os.path.join(base_dir, code_hash + object_suffix)

    needs_compile = not os.path.exists(object_file)
    if needs_compile:
        with file_handle_for_atomic_write(src_file) as f:
            code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)

        if windows:
            compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
            compile_cmd += [*extra_flags, src_file, '/Fo' + object_file]
            run_compile_step(compile_cmd)
        else:
            # NOTE(review): atomic_file_write presumably yields a temporary name that is moved to
            #               object_file on success - confirm in pystencils.utils
            with atomic_file_write(object_file) as file_name:
                compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
                compile_cmd += [*extra_flags, '-o', file_name, src_file]
                run_compile_step(compile_cmd)

    # Linking
    # An existing object file does not imply that the library exists as well (e.g. if a previous
    # link step failed), so the library is checked separately instead of returning a missing file.
    if needs_compile or not os.path.exists(lib_file):
        if windows:
            import sysconfig
            config_vars = sysconfig.get_config_vars()
            py_lib = os.path.join(config_vars["installed_base"], "libs",
                                  "python{}.lib".format(config_vars["py_version_nodot"]))
            run_compile_step(['link.exe', py_lib, '/DLL', '/out:' + lib_file, object_file])
        else:
            with atomic_file_write(lib_file) as file_name:
                run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name]
                                 + compiler_config['flags'].split())
    return lib_file


def compile_and_load(ast):
    """
    Compiles a kernel into an extension module, loads it and wraps the resulting function.

    The module is named after the SHA256 hash of the generated C code. If the object cache is enabled,
    the module is built in the cache folder, otherwise in a temporary folder that is removed again
    after the module has been loaded.

    :param ast: abstract syntax tree of the kernel
    :return: KernelWrapper around the loaded kernel function
    """
    cache_config = get_cache_config()
    # the hash serves as module name and as name of the generated files
    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c').encode()).hexdigest()
    code = ExtensionModuleCode(module_name=code_hash_str)
    code.add_function(ast, ast.function_name)

    if cache_config['object_cache'] is False:
        # the module has to be loaded while the temporary folder still exists
        with TemporaryDirectory() as base_dir:
            lib_file = compile_module(code, code_hash_str, base_dir)
            result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
    else:
        lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'])
        result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)

    return KernelWrapper(result, ast.get_parameters(), ast)