r"""
*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux, make sure that 'gcc' and 'g++' are installed and on your PATH.
On Windows, a recent Visual Studio installation is required.
If anything does not work as expected, or a different compiler should be used, the settings can be changed
in a configuration file.

*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory, in a file named ``pystencils.json``
3. in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is written to the location in your home directory
listed above. To customize the settings, run *pystencils* once, then edit the created configuration file.
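
The lookup order above can be sketched as follows. This is a simplified, standalone illustration of the logic
in ``get_configuration_file_path`` further down in this module; the helper name ``find_config_file`` and its
explicit ``environ``, ``cwd`` and ``home_config_path`` parameters are hypothetical and only make the sketch
testable. As in the module code, a path given via ``PYSTENCILS_CONFIG`` is accepted without checking that the
file exists.

```python
import os


def find_config_file(environ, cwd, home_config_path):
    # Hypothetical helper: returns (path, exists) following the documented order.
    # 1) explicit path from the environment variable (not checked for existence)
    if 'PYSTENCILS_CONFIG' in environ:
        return environ['PYSTENCILS_CONFIG'], True
    # 2) pystencils.json in the current working directory
    local_config = os.path.join(cwd, 'pystencils.json')
    if os.path.exists(local_config):
        return local_config, True
    # 3) per-user configuration file
    if os.path.exists(home_config_path):
        return home_config_path, True
    # nothing found: the per-user location is where a default config gets written
    return home_config_path, False
```

For example, ``find_config_file({'PYSTENCILS_CONFIG': '/tmp/my_config.json'}, '.', home)`` returns
``('/tmp/my_config.json', True)`` regardless of the other two locations.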


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space-separated list of compiler flags. Make sure to activate OpenMP in your compiler.
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
  For most Linux compilers the qualifier is ``__restrict__``.
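
To illustrate how these keys are used, the following standalone sketch assembles the compile and link commands
in the same way as ``compile_linux`` further down in this module (which additionally goes through the
atomic-write helpers from ``pystencils.utils``). The function names ``linux_build_commands`` and
``restrict_define`` are hypothetical, and the flags in the example config are made up.

```python
def linux_build_commands(compiler_config, src_file, object_file, lib_file):
    # 'flags' is a single string; it is split on whitespace into separate arguments
    flags = compiler_config['flags'].split()
    # compile only (-c) into an object file
    compile_cmd = [compiler_config['command'], '-c'] + flags + ['-o', object_file, src_file]
    # link the object file into a shared library, passing the same flags again
    link_cmd = [compiler_config['command'], '-shared', object_file, '-o', lib_file] + flags
    return compile_cmd, link_cmd


def restrict_define(compiler_config):
    # each generated source file defines this macro after its includes,
    # so kernels can use RESTRICT independently of the compiler
    return '#define RESTRICT %s' % (compiler_config['restrict_qualifier'],)


config = {'command': 'g++', 'flags': '-O3 -fPIC -fopenmp', 'restrict_qualifier': '__restrict__'}
compile_cmd, link_cmd = linux_build_commands(config, 'kernel.cpp', 'kernel.o', 'kernel.so')
print(' '.join(compile_cmd))  # g++ -c -O3 -fPIC -fopenmp -o kernel.o kernel.cpp
print(' '.join(link_cmd))     # g++ -shared kernel.o -o kernel.so -O3 -fPIC -fopenmp
print(restrict_define(config))
```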


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
- **'msvc_version'**: either a version number, a year number, 'auto' or 'latest' for automatic detection of the
  latest installed version, or 'setuptools' for setuptools-based detection. Alternatively, the path to the
  folder where Visual Studio is installed; this folder has to contain a file called 'vcvarsall.bat'.
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe'; make sure OpenMP is activated.
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
  For Windows compilers the qualifier should be ``__restrict``.
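
A corresponding standalone sketch for Windows, mirroring ``compile_windows`` and ``run_compile_step`` further
down in this module: the tools are fixed to 'cl.exe' and 'link.exe', and they run in an environment where the
entries of the compiler config's 'env' dictionary (which receives the variables detected for the selected MSVC
installation) override the current process environment. The names ``windows_build_commands`` and
``compile_environment`` are hypothetical, and the paths in the example are made up.

```python
import os


def windows_build_commands(compiler_config, src_file, object_file, lib_file):
    flags = compiler_config['flags'].split()
    # /c: compile only; /EHsc: standard C++ exception handling,
    # extern "C" functions are assumed not to throw; /Fo: object file name
    compile_cmd = ['cl.exe', '/c', '/EHsc'] + flags + [src_file, '/Fo' + object_file]
    # /DLL: build a dynamic-link library; /out: its file name
    link_cmd = ['link.exe', '/DLL', '/out:' + lib_file, object_file]
    return compile_cmd, link_cmd


def compile_environment(compiler_config, base_environment=None):
    # start from the current (or given) environment and overlay the configured entries
    environment = dict(os.environ if base_environment is None else base_environment)
    environment.update(compiler_config.get('env', {}))
    return environment


config = {'flags': '/Ox /fp:fast /openmp /arch:avx', 'env': {'LIB': 'C:/msvc/lib'}}
compile_cmd, link_cmd = windows_build_commands(config, 'kernel.cpp', 'kernel.obj', 'kernel.dll')
env = compile_environment(config, {'PATH': 'C:/tools', 'LIB': 'old'})
```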


Cache Config
------------

*pystencils* uses a directory to store intermediate files such as the generated C++ files, compiled object files
and shared libraries, which are then loaded from Python using ctypes. The file names are SHA-256 hashes of the
generated code. If the same kernel has already been compiled, the existing object file is reused and no
recompilation takes place.
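
A minimal sketch of this naming scheme, modeled on ``compile_and_load`` and ``hash_to_function_name`` further
down in this module; the helper ``cache_file_names`` is hypothetical, and the code strings stand in for the
output of the C backend (``generate_c``). Because the names depend only on the generated code, generating the
same kernel again yields the same file names, which is what makes the cache lookup work.

```python
import hashlib
import os


def cache_file_names(generated_code, object_cache, windows=False):
    # the SHA-256 hex digest of the generated C code identifies the kernel
    code_hash = hashlib.sha256(generated_code.encode()).hexdigest()
    object_ext, library_ext = ('.obj', '.dll') if windows else ('.o', '.so')
    return {
        'function_name': 'func_' + code_hash,  # name of the exported kernel function
        'source': os.path.join(object_cache, code_hash + '.cpp'),
        'object': os.path.join(object_cache, code_hash + object_ext),
        'library': os.path.join(object_cache, code_hash + library_ext),
    }


first = cache_file_names('void kernel(double * x) { x[0] = 1.0; }', 'objectcache')
second = cache_file_names('void kernel(double * x) { x[0] = 1.0; }', 'objectcache')
other = cache_file_names('void kernel(double * x) { x[0] = 2.0; }', 'objectcache')
```

Here ``first`` and ``second`` refer to the same files, while ``other`` gets different ones.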

If 'shared_library' is specified, all object files that are in the cache are linked into a single shared library
when the Python interpreter exits.
This mechanism can be used to run *pystencils* on systems where compilation is not possible, e.g. on clusters
where the compute nodes cannot compile code.
First, the script is run on a system where compilation is possible (e.g. the login node) with
'read_from_shared_library=False' and with 'shared_library' set to a valid path.
All kernels generated during the run are put into the cache and, at the end of the run, combined into the
shared library. Then the same script can be run on the compute nodes with 'read_from_shared_library=True',
so that kernels are taken from the library instead of being compiled.
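
The decision made for each kernel can be sketched as follows. It mirrors the branch in ``compile_and_load``
further down in this module, but the function ``library_to_load`` and its return convention (a path plus a
flag telling whether compilation is still needed) are hypothetical. Only the Linux ``.so`` extension is assumed,
and the example paths are made up.

```python
import os


def library_to_load(cache_config, code_hash):
    # compute nodes: take the kernel from the combined library, never compile
    if cache_config['read_from_shared_library']:
        return cache_config['shared_library'], False
    # otherwise every kernel gets its own library in the object cache,
    # which has to be compiled if it does not exist yet
    lib_file = os.path.join(cache_config['object_cache'], code_hash + '.so')
    return lib_file, not os.path.exists(lib_file)


compute_node = {'read_from_shared_library': True,
                'shared_library': '/shared/kernels.so',
                'object_cache': '/nonexistent/objectcache'}
login_node = dict(compute_node, read_from_shared_library=False)
```

With ``login_node``, each kernel is compiled on demand into the object cache; with ``compute_node``, every
kernel is looked up in ``/shared/kernels.so`` without invoking a compiler.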


- **'read_from_shared_library'**: if true, kernels are not compiled but loaded from the shared library
- **'object_cache'**: path to a folder where intermediate files are stored
- **'clear_cache_on_start'**: if true, the cache is cleared each time a *pystencils* script starts
- **'shared_library'**: path to a shared library file, which is created if 'read_from_shared_library' is false

In 'object_cache' and 'shared_library', ``~`` is expanded to the home directory and the placeholder ``{pid}``
is replaced by the id of the current process.
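
``read_config`` (below) post-processes the 'object_cache' and 'shared_library' paths before they are used.
A standalone sketch of that step, with the hypothetical helper name ``resolve_cache_path``; a ``{pid}``
placeholder can, for example, give each process its own object cache. Since ``str.format`` is used, any other
literal braces in a path would have to be doubled.

```python
import os


def resolve_cache_path(path):
    # expand '~' first, then substitute the current process id for '{pid}'
    return os.path.expanduser(path).format(pid=os.getpid())


per_process_cache = resolve_cache_path('/tmp/pystencils/objectcache_{pid}')
```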
"""
from __future__ import print_function

import os
import subprocess
import hashlib
import json
import platform
import glob
import atexit
import shutil
import numpy as np
from appdirs import user_config_dir, user_cache_dir
from ctypes import cdll
from pystencils.backends.cbackend import generate_c, get_headers
from collections import OrderedDict
from pystencils.transformations import symbol_name_to_variable_name
from pystencils.data_types import to_ctypes, get_base_type, StructType
from pystencils.field import FieldType
from pystencils.utils import recursive_dict_update, file_handle_for_atomic_write, atomic_file_write


def make_python_function(kernel_function_node, argument_dict=None):
    """
    Creates C code from the abstract syntax tree, compiles it and makes it accessible as a Python function.

    The parameters of the kernel are:
        - numpy arrays for each field used in the kernel. The keyword argument name is the name of the field
        - all symbols which are not defined in the kernel itself are expected as parameters

    :param kernel_function_node: the abstract syntax tree
    :param argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                          returned kernel functor.
    :return: kernel functor; if all parameters are fixed by argument_dict, it takes no arguments
    """
    if argument_dict is None:
        argument_dict = {}

    func = compile_and_load(kernel_function_node)
    func.restype = None
    try:
        # build up list of ctypes arguments from the already fixed parameters
        args = build_ctypes_argument_list(kernel_function_node.parameters, argument_dict)
    except KeyError:
        # not all parameters specified yet
        return make_python_function_incomplete_params(kernel_function_node, argument_dict, func)
    return lambda: func(*args)


def set_config(config):
    """
    Override the configuration provided in the config file.

    Configuration of compiler parameters:
    If this function is not called, the configuration is taken from a config file in JSON format which
    is searched in the following locations in the order specified:
        - at the location provided in the environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - "config.json" in the user configuration directory, e.g. ~/.config/pystencils/config.json on Linux
    If none of these files exist, the last one is created with a default configuration
    (using the GNU 'g++' on Linux).

    The dictionary passed to this function replaces the current configuration as a whole, so it has to contain
    both the 'compiler' and the 'cache' section. In a config file, keys that are not specified take their
    default values. An example JSON config file:

    ::

        {
            "compiler": {
                "command": "/software/intel/2017/bin/icpc",
                "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
                "env": {
                    "LM_PROJECT": "iwia"
                }
            }
        }
    """
    global _config
    _config = config.copy()


def get_configuration_file_path():
    config_path_in_home = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) Read path from environment variable if found
    if 'PYSTENCILS_CONFIG' in os.environ:
        return os.environ['PYSTENCILS_CONFIG'], True
    # 2) Look in current directory for pystencils.json
    elif os.path.exists("pystencils.json"):
        return "pystencils.json", True
    # 3) Try config.json in the user configuration directory
    elif os.path.exists(config_path_in_home):
        return config_path_in_home, True
    else:
        return config_path_in_home, False


def create_folder(path, is_file):
    # if path refers to a file, create its parent directory instead
    if is_file:
        path = os.path.split(path)[0]
    try:
        os.makedirs(path)
    except os.error:
        # ignore errors, e.g. when the directory already exists
        pass


def read_config():
    if platform.system().lower() == 'linux':
        default_compiler_config = OrderedDict([
            ('os', 'linux'),
            ('command', 'g++'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])

    elif platform.system().lower() == 'windows':
        default_compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /openmp /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    else:
        raise NotImplementedError("No default compiler configuration for operating system " + platform.system())

    default_cache_config = OrderedDict([
        ('read_from_shared_library', False),
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
        ('shared_library', os.path.join(user_cache_dir('pystencils'), 'cache.so')),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        create_folder(config_path, True)
        with open(config_path, 'w') as json_config_file:
            json.dump(config, json_config_file, indent=4)

    # expand '~' and substitute the process id for an optional '{pid}' placeholder
    config['cache']['shared_library'] = os.path.expanduser(config['cache']['shared_library']).format(pid=os.getpid())
    config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())

    if config['cache']['clear_cache_on_start']:
        # clear_cache() cannot be used here: it reads the module-level _config,
        # which is only assigned after read_config() has returned
        shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)

    create_folder(config['cache']['object_cache'], False)
    create_folder(config['cache']['shared_library'], True)

    if 'env' not in config['compiler']:
        config['compiler']['env'] = {}

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        config['compiler']['env'].update(msvc_env)

    return config


_config = read_config()


def get_compiler_config():
    return _config['compiler']


def get_cache_config():
    return _config['cache']


def hash_to_function_name(h):
    res = "func_%s" % (h,)
    return res.replace('-', 'm')


def clear_cache():
    cache_config = get_cache_config()
    shutil.rmtree(cache_config['object_cache'], ignore_errors=True)
    create_folder(cache_config['object_cache'], False)


def compile_object_cache_to_shared_library():
    compiler_config = get_compiler_config()
    cache_config = get_cache_config()

    shared_library = cache_config['shared_library']
    if len(shared_library) == 0 or cache_config['read_from_shared_library']:
        return

    if compiler_config['os'] == 'windows':
        all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.obj'))
        link_cmd = ['link.exe', '/DLL', '/out:' + shared_library]
    else:
        all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.o'))
        link_cmd = [compiler_config['command'], '-shared', '-o', shared_library]

    link_cmd += all_object_files
    if len(all_object_files) > 0:
        # run_compile_step sets up the compiler environment and prints the failing command and its output
        run_compile_step(link_cmd)


atexit.register(compile_object_cache_to_shared_library)


def generate_code(ast, restrict_qualifier, function_prefix, source_file):
    headers = get_headers(ast)
    headers.update(['<cmath>', '<cstdint>'])

    code = generate_c(ast)
    includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
    print(includes, file=source_file)
    print("#define RESTRICT %s" % (restrict_qualifier,), file=source_file)
    print("#define FUNC_PREFIX %s" % (function_prefix,), file=source_file)
    print('extern "C" { ', file=source_file)
    print(code, file=source_file)
    print('}', file=source_file)


def run_compile_step(command):
    compiler_config = get_compiler_config()
    config_env = compiler_config['env'] if 'env' in compiler_config else {}
    compile_environment = os.environ.copy()
    compile_environment.update(config_env)

    try:
        # on Windows, the command is run through the shell
        shell = compiler_config['os'].lower() == 'windows'
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as e:
        print(" ".join(command))
        print(e.output.decode('utf8'))
        raise


def compile_linux(ast, code_hash_str, src_file, lib_file):
    cache_config = get_cache_config()
    compiler_config = get_compiler_config()

    object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.o')
    # Compilation
    if not os.path.exists(object_file):
        with file_handle_for_atomic_write(src_file) as f:
            generate_code(ast, compiler_config['restrict_qualifier'], '', f)
        with atomic_file_write(object_file) as file_name:
            compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
            compile_cmd += ['-o', file_name, src_file]
            run_compile_step(compile_cmd)

    # Linking
    with atomic_file_write(lib_file) as file_name:
        run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] +
                         compiler_config['flags'].split())


def compile_windows(ast, code_hash_str, src_file, lib_file):
    cache_config = get_cache_config()
    compiler_config = get_compiler_config()

    object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.obj')
    # Compilation
    if not os.path.exists(object_file):
        with file_handle_for_atomic_write(src_file) as f:
            generate_code(ast, compiler_config['restrict_qualifier'], '__declspec(dllexport)', f)

        # /c compiles only; /EHsc enables standard C++ exception handling and assumes that
        # extern "C" functions never throw
        compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
        compile_cmd += [src_file, '/Fo' + object_file]
        run_compile_step(compile_cmd)

    # Linking
    run_compile_step(['link.exe', '/DLL', '/out:' + lib_file, object_file])


def compile_and_load(ast):
    cache_config = get_cache_config()

    code_hash_str = hashlib.sha256(generate_c(ast).encode()).hexdigest()
    ast.function_name = hash_to_function_name(code_hash_str)

    src_file = os.path.join(cache_config['object_cache'], code_hash_str + ".cpp")

    if cache_config['read_from_shared_library']:
        return cdll.LoadLibrary(cache_config['shared_library'])[ast.function_name]
    else:
        if get_compiler_config()['os'].lower() == 'windows':
            lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".dll")
            if not os.path.exists(lib_file):
                compile_windows(ast, code_hash_str, src_file, lib_file)
        else:
            lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".so")
            if not os.path.exists(lib_file):
                compile_linux(ast, code_hash_str, src_file, lib_file)
        return cdll.LoadLibrary(lib_file)[ast.function_name]


def build_ctypes_argument_list(parameter_specification, argument_dict):
    argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
    ct_arguments = []
    array_shapes = set()
    index_arr_shapes = set()

    for arg in parameter_specification:
        if arg.is_field_argument:
            try:
                field_arr = argument_dict[arg.field_name]
            except KeyError:
                raise KeyError("Missing field parameter for kernel call " + arg.field_name)

            symbolic_field = arg.field
            if arg.is_field_ptr_argument:
                ct_arguments.append(field_arr.ctypes.data_as(to_ctypes(arg.dtype)))
                if symbolic_field.has_fixed_shape:
                    # for struct-typed fields, the last symbolic dimension is excluded from the comparison
                    symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
                    if isinstance(symbolic_field.dtype, StructType):
                        symbolic_field_shape = symbolic_field_shape[:-1]
                    if symbolic_field_shape != field_arr.shape:
                        raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                         (arg.field_name, str(field_arr.shape), str(symbolic_field_shape)))

                    symbolic_field_strides = tuple(int(i) * field_arr.itemsize for i in symbolic_field.strides)
                    if isinstance(symbolic_field.dtype, StructType):
                        symbolic_field_strides = symbolic_field_strides[:-1]
                    if symbolic_field_strides != field_arr.strides:
                        raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                         (arg.field_name, str(field_arr.strides), str(symbolic_field_strides)))

                if FieldType.is_indexed(symbolic_field):
                    index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
                elif not FieldType.is_buffer(symbolic_field):
                    array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

            elif arg.is_field_shape_argument:
                data_type = to_ctypes(get_base_type(arg.dtype))
                ct_arguments.append(field_arr.ctypes.shape_as(data_type))
            elif arg.is_field_stride_argument:
                # numpy strides are given in bytes, the kernel expects strides in elements
                data_type = to_ctypes(get_base_type(arg.dtype))
                strides = field_arr.ctypes.strides_as(data_type)
                for i in range(len(field_arr.shape)):
                    assert strides[i] % field_arr.itemsize == 0
                    strides[i] //= field_arr.itemsize
                ct_arguments.append(strides)
            else:
                assert False
        else:
            try:
                param = argument_dict[arg.name]
            except KeyError:
                raise KeyError("Missing parameter for kernel call " + arg.name)
            expected_type = to_ctypes(arg.dtype)
            ct_arguments.append(expected_type(param))

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
    if len(index_arr_shapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    return ct_arguments


def make_python_function_incomplete_params(kernel_function_node, argument_dict, func):
    parameters = kernel_function_node.parameters

    cache = {}
    cache_values = []

    def wrapper(**kwargs):
        # arrays are identified by data pointer, strides and shape; all other values by object id
        key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                         for k, v in kwargs.items()))
        try:
            args = cache[key]
        except KeyError:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            args = build_ctypes_argument_list(parameters, full_arguments)
            cache[key] = args
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique
        func(*args)

    wrapper.ast = kernel_function_node
    wrapper.parameters = kernel_function_node.parameters
    return wrapper