# cpujit.py
r"""

*pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
On Linux make sure that 'gcc' and 'g++' are installed and in your path.
On Windows a recent Visual Studio installation is required.
In case anything does not work as expected or a special compiler should be used, changes can be specified
in a configuration file.

*pystencils* looks for a configuration file in JSON format at the following locations in the listed order.

1. at the path specified in the environment variable ``PYSTENCILS_CONFIG``
2. in the current working directory for a file named ``pystencils.json``
3. or in your home directory at ``~/.config/pystencils/config.json`` (Linux) or
   ``%HOMEPATH%\.pystencils\config.json`` (Windows)

If no configuration file is found, a default configuration is created at the home directory location listed above.
To customize the configuration, run *pystencils* once, then edit the created file.
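
The lookup order above can be sketched in a few lines of Python. This is only an illustration: the helper name
``find_config_file`` and its parameters are hypothetical and not part of *pystencils*. It assumes that a path
given via ``PYSTENCILS_CONFIG`` is used without checking whether the file exists.

```python
import os


def find_config_file(home_config_path, cwd=".", environ=None):
    # Hypothetical helper, not part of pystencils: returns (path, exists).
    environ = os.environ if environ is None else environ
    # 1. the environment variable takes precedence
    if "PYSTENCILS_CONFIG" in environ:
        return environ["PYSTENCILS_CONFIG"], True
    # 2. a file named pystencils.json in the current working directory
    local_path = os.path.join(cwd, "pystencils.json")
    if os.path.exists(local_path):
        return local_path, True
    # 3. the per-user file; the flag tells the caller whether it still has to be created
    return home_config_path, os.path.exists(home_config_path)
```

When the returned flag is false, the caller would write the default configuration to the returned path.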


Compiler Config (Linux)
-----------------------

- **'os'**: should be detected automatically as 'linux'
- **'command'**: path to C++ compiler (defaults to 'g++')
- **'flags'**: space-separated list of compiler flags. Make sure to activate OpenMP in your compiler
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
  For most Linux compilers the qualifier is ``__restrict__``
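
As a rough illustration, the keys above could be assembled in Python and serialized to JSON. The values mirror the
Linux defaults used in this module; the variable names ``linux_config`` and ``config_text`` are only for this sketch.

```python
import json

# Values mirror the Linux defaults; adjust 'command' and 'flags' for a different compiler.
linux_config = {
    "compiler": {
        "os": "linux",
        "command": "g++",
        "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11",
        "restrict_qualifier": "__restrict__",
    }
}

# JSON text that could be saved as pystencils.json
config_text = json.dumps(linux_config, indent=4)
print(config_text)
```

Keys that are left out of the file fall back to their default values.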


Compiler Config (Windows)
-------------------------

*pystencils* uses the mechanism of *setuptools.msvc* to search for a compilation environment.
Then 'cl.exe' is used to compile.

- **'os'**: should be detected automatically as 'windows'
- **'msvc_version'**:  either a version number, year number, 'auto' or 'latest' for automatic detection of the latest
                      installed version, or 'setuptools' for setuptools-based detection. Alternatively, the path to
                      the folder where Visual Studio is installed. This path has to contain a file called
                      'vcvarsall.bat'
- **'arch'**: 'x86' or 'x64'
- **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
- **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
  For Windows compilers the qualifier should be ``__restrict``


Cache Config
------------

*pystencils* uses a directory to store intermediate files like the generated C++ files, compiled object files and
the shared libraries which are then loaded from Python using ctypes. The file names are SHA hashes of the
generated code. If the same kernel was already compiled, the existing object file is used - no recompilation is done.
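
A minimal sketch of this naming scheme, assuming SHA-256 and Linux file extensions; the helper ``cache_paths`` and
the directory ``/tmp/objectcache`` are hypothetical and only serve as an illustration.

```python
import hashlib
import os


def cache_paths(generated_code, cache_dir):
    # Hypothetical helper: derive the cache file names from a SHA-256 hash of the code.
    code_hash = hashlib.sha256(generated_code.encode()).hexdigest()
    base = os.path.join(cache_dir, code_hash)
    return {"source": base + ".cpp", "object": base + ".o", "library": base + ".so"}


paths = cache_paths("void kernel() {}", "/tmp/objectcache")
# Identical code yields identical names, so an existing object file can be reused.
needs_compilation = not os.path.exists(paths["object"])
```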

If 'shared_library' is specified, all kernels that are currently in the cache are compiled into a single shared library.
This mechanism can be used to run *pystencils* on systems where compilation is not possible, e.g. on clusters where
compilation on the compute nodes is not possible.
First the script is run on a system where compilation is possible (e.g. the login node) with
'read_from_shared_library=False' and with 'shared_library' set to a valid path.
All kernels generated during the run are put into the cache and at the end
compiled into the shared library. Then, the same script can be run from the compute nodes, with
'read_from_shared_library=True', such that kernels are taken from the library instead of compiling them.


- **'read_from_shared_library'**: if true, kernels are not compiled but assumed to be in the shared library
- **'object_cache'**: path to a folder where intermediate files are stored
- **'clear_cache_on_start'**: when true, the cache is cleared on each start of a *pystencils* script
- **'shared_library'**: path to a shared library file, which is created if ``read_from_shared_library=false``
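
The two-phase workflow for clusters could be configured roughly as follows. The paths are hypothetical; the sketch
only assumes that both runs point at the same shared library file and differ in 'read_from_shared_library'.

```python
import copy
import json

# Phase 1 (e.g. login node): compile kernels and collect them in the shared library.
# The paths are hypothetical placeholders.
build_config = {
    "cache": {
        "read_from_shared_library": False,
        "object_cache": "~/pystencils_cache/objectcache",
        "clear_cache_on_start": False,
        "shared_library": "~/pystencils_cache/cache.so",
    }
}

# Phase 2 (compute nodes): same paths, but kernels are read from the library.
run_config = copy.deepcopy(build_config)
run_config["cache"]["read_from_shared_library"] = True

print(json.dumps(run_config, indent=4))
```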
"""
from __future__ import print_function
import os
import subprocess
import hashlib
import json
import platform
import glob
import atexit
import shutil
import numpy as np
from appdirs import user_config_dir, user_cache_dir
from ctypes import cdll
from pystencils.backends.cbackend import generate_c, get_headers
from collections import OrderedDict
from pystencils.transformations import symbol_name_to_variable_name
from pystencils.data_types import to_ctypes, get_base_type, StructType
from pystencils.field import FieldType
from pystencils.utils import recursive_dict_update, file_handle_for_atomic_write, atomic_file_write


def make_python_function(kernel_function_node, argument_dict={}):
    """
    Creates C code from the abstract syntax tree, compiles it and makes it accessible as a Python function

    The parameters of the kernel are:
        - numpy arrays for each field used in the kernel. The keyword argument name is the name of the field
        - all symbols which are not defined in the kernel itself are expected as parameters

    :param kernel_function_node: the abstract syntax tree
    :param argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                        returned kernel functor.
    :return: kernel functor
    """
    # build up list of CType arguments
    func = compile_and_load(kernel_function_node)
    func.restype = None
    try:
        args = build_ctypes_argument_list(kernel_function_node.parameters, argument_dict)
    except KeyError:
        # not all parameters specified yet
        return make_python_function_incomplete_params(kernel_function_node, argument_dict, func)
    return lambda: func(*args)


def set_compiler_config(config):
    """
    Override the configuration provided in the config file

    Configuration of compiler parameters:
    If this function is not called the configuration is taken from a config file in JSON format which
    is searched in the following locations in the order specified:
        - at the location provided in the environment variable PYSTENCILS_CONFIG (if this variable exists)
        - a file called "pystencils.json" in the current working directory
        - the file "config.json" in the per-user configuration directory of pystencils
    If none of these files exist, the file in the per-user configuration directory is created with a default
    configuration using the GNU 'g++'

    An example JSON file with some of the possible keys. If not all keys are specified, default values are used
    ``
    {
        "compiler" :
        {
            "command": "/software/intel/2017/bin/icpc",
            "flags": "-Ofast -DNDEBUG -fPIC -march=native -fopenmp",
            "env": {
                "LM_PROJECT": "iwia"
            }
        }
    }
    ``
    """
    global _config
    _config = config.copy()


def get_configuration_file_path():
    config_path_in_home = os.path.join(user_config_dir('pystencils'), 'config.json')

    # 1) Read path from environment variable if found
    if 'PYSTENCILS_CONFIG' in os.environ:
        return os.environ['PYSTENCILS_CONFIG'], True
    # 2) Look in current directory for pystencils.json
    elif os.path.exists("pystencils.json"):
        return "pystencils.json", True
    # 3) Try config.json in the per-user configuration directory
    elif os.path.exists(config_path_in_home):
        return config_path_in_home, True
    else:
        return config_path_in_home, False


def create_folder(path, is_file):
    if is_file:
        path = os.path.split(path)[0]
    try:
        os.makedirs(path)
    except os.error:
        pass


def read_config():
    if platform.system().lower() == 'linux':
        default_compiler_config = OrderedDict([
            ('os', 'linux'),
            ('command', 'g++'),
            ('flags', '-Ofast -DNDEBUG -fPIC -march=native -fopenmp -std=c++11'),
            ('restrict_qualifier', '__restrict__')
        ])

    elif platform.system().lower() == 'windows':
        default_compiler_config = OrderedDict([
            ('os', 'windows'),
            ('msvc_version', 'latest'),
            ('arch', 'x64'),
            ('flags', '/Ox /fp:fast /openmp /arch:avx'),
            ('restrict_qualifier', '__restrict')
        ])
    default_cache_config = OrderedDict([
        ('read_from_shared_library', False),
        ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
        ('clear_cache_on_start', False),
        ('shared_library', os.path.join(user_cache_dir('pystencils'), 'cache.so')),
    ])

    default_config = OrderedDict([('compiler', default_compiler_config),
                                  ('cache', default_cache_config)])

    config_path, config_exists = get_configuration_file_path()
    config = default_config.copy()
    if config_exists:
        with open(config_path, 'r') as json_config_file:
            loaded_config = json.load(json_config_file)
        config = recursive_dict_update(config, loaded_config)
    else:
        create_folder(config_path, True)
        # write the default configuration; the context manager closes the file handle
        with open(config_path, 'w') as json_config_file:
            json.dump(config, json_config_file, indent=4)

    config['cache']['shared_library'] = os.path.expanduser(config['cache']['shared_library']).format(pid=os.getpid())
    config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())

    if config['cache']['clear_cache_on_start']:
        shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)

    create_folder(config['cache']['object_cache'], False)
    create_folder(config['cache']['shared_library'], True)

    if 'env' not in config['compiler']:
        config['compiler']['env'] = {}

    if config['compiler']['os'] == 'windows':
        from pystencils.cpu.msvc_detection import get_environment
        msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
        config['compiler']['env'].update(msvc_env)

    return config


_config = read_config()


def get_compiler_config():
    return _config['compiler']


def get_cache_config():
    return _config['cache']


def hash_to_function_name(h):
    res = "func_%s" % (h,)
    return res.replace('-', 'm')


def compile_object_cache_to_shared_library():
    compiler_config = get_compiler_config()
    cache_config = get_cache_config()

    shared_library = cache_config['shared_library']
    if len(shared_library) == 0 or cache_config['read_from_shared_library']:
        return

    config_env = compiler_config['env'] if 'env' in compiler_config else {}
    compile_environment = os.environ.copy()
    compile_environment.update(config_env)

    try:
        if compiler_config['os'] == 'windows':
            all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.obj'))
            link_cmd = ['link.exe', '/DLL', '/out:' + shared_library]
        else:
            all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.o'))
            link_cmd = [compiler_config['command'], '-shared', '-o', shared_library]

        link_cmd += all_object_files
        if len(all_object_files) > 0:
            run_compile_step(link_cmd)
    except subprocess.CalledProcessError as e:
        print(e.output)
        raise e


atexit.register(compile_object_cache_to_shared_library)


def generate_code(ast, restrict_qualifier, function_prefix, source_file):
    headers = get_headers(ast)
    headers.update(['<cmath>', '<cstdint>'])

    code = generate_c(ast)
    includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
    print(includes, file=source_file)
    print("#define RESTRICT %s" % (restrict_qualifier,), file=source_file)
    print("#define FUNC_PREFIX %s" % (function_prefix,), file=source_file)
    print('extern "C" { ', file=source_file)
    print(code, file=source_file)
    print('}', file=source_file)


def run_compile_step(command):
    compiler_config = get_compiler_config()
    config_env = compiler_config['env'] if 'env' in compiler_config else {}
    compile_environment = os.environ.copy()
    compile_environment.update(config_env)

    try:
        # on Windows the command is run through the shell
        shell = compiler_config['os'].lower() == 'windows'
        subprocess.check_output(command, env=compile_environment, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as e:
        # show the failing command and the compiler output before re-raising
        print(" ".join(command))
        print(e.output.decode('utf8'))
        raise


def compile_linux(ast, code_hash_str, src_file, lib_file):
    cache_config = get_cache_config()
    compiler_config = get_compiler_config()

    object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.o')
    if not os.path.exists(object_file):
        with file_handle_for_atomic_write(src_file) as f:
            generate_code(ast, compiler_config['restrict_qualifier'], '', f)
        with atomic_file_write(object_file) as file_name:
            compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
            compile_cmd += ['-o', file_name, src_file]
            run_compile_step(compile_cmd)

    # Linking
    with atomic_file_write(lib_file) as file_name:
        run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] +
                         compiler_config['flags'].split())


def compile_windows(ast, code_hash_str, src_file, lib_file):
    cache_config = get_cache_config()
    compiler_config = get_compiler_config()

    object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.obj')
    # Compilation
    if not os.path.exists(object_file):
        with file_handle_for_atomic_write(src_file) as f:
            generate_code(ast, compiler_config['restrict_qualifier'], '__declspec(dllexport)', f)

        # /c compiles only; /EHsc enables C++ exception handling and assumes extern "C" functions never throw
        compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
        compile_cmd += [src_file, '/Fo' + object_file]
        run_compile_step(compile_cmd)

    # Linking
    run_compile_step(['link.exe', '/DLL', '/out:' + lib_file, object_file])


def compile_and_load(ast):
    cache_config = get_cache_config()

    code_hash_str = hashlib.sha256(generate_c(ast).encode()).hexdigest()
    ast.function_name = hash_to_function_name(code_hash_str)

    src_file = os.path.join(cache_config['object_cache'], code_hash_str + ".cpp")

    if cache_config['read_from_shared_library']:
        return cdll.LoadLibrary(cache_config['shared_library'])[ast.function_name]
    else:
        if get_compiler_config()['os'].lower() == 'windows':
            lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".dll")
            if not os.path.exists(lib_file):
                compile_windows(ast, code_hash_str, src_file, lib_file)
        else:
            lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".so")
            if not os.path.exists(lib_file):
                compile_linux(ast, code_hash_str, src_file, lib_file)
        return cdll.LoadLibrary(lib_file)[ast.function_name]


def build_ctypes_argument_list(parameter_specification, argument_dict):
    argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
    ct_arguments = []
    array_shapes = set()
    index_arr_shapes = set()

    for arg in parameter_specification:
        if arg.is_field_argument:
            try:
                field_arr = argument_dict[arg.field_name]
            except KeyError:
                raise KeyError("Missing field parameter for kernel call " + arg.field_name)

            symbolic_field = arg.field
            if arg.is_field_ptr_argument:
                ct_arguments.append(field_arr.ctypes.data_as(to_ctypes(arg.dtype)))
                if symbolic_field.has_fixed_shape:
                    symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
                    if isinstance(symbolic_field.dtype, StructType):
                        symbolic_field_shape = symbolic_field_shape[:-1]
                    if symbolic_field_shape != field_arr.shape:
                        raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                         (arg.field_name, str(field_arr.shape), str(symbolic_field_shape)))
                if symbolic_field.has_fixed_shape:
                    symbolic_field_strides = tuple(int(i) * field_arr.itemsize for i in symbolic_field.strides)
                    if isinstance(symbolic_field.dtype, StructType):
                        symbolic_field_strides = symbolic_field_strides[:-1]
                    if symbolic_field_strides != field_arr.strides:
                        raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                         (arg.field_name, str(field_arr.strides), str(symbolic_field_strides)))

                if FieldType.is_indexed(symbolic_field):
                    index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
                elif not FieldType.is_buffer(symbolic_field):
                    array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

            elif arg.is_field_shape_argument:
                data_type = to_ctypes(get_base_type(arg.dtype))
                ct_arguments.append(field_arr.ctypes.shape_as(data_type))
            elif arg.is_field_stride_argument:
                data_type = to_ctypes(get_base_type(arg.dtype))
                strides = field_arr.ctypes.strides_as(data_type)
                # convert strides from bytes to multiples of the item size
                for i in range(len(field_arr.shape)):
                    assert strides[i] % field_arr.itemsize == 0
                    strides[i] //= field_arr.itemsize
                ct_arguments.append(strides)
            else:
                assert False
        else:
            try:
                param = argument_dict[arg.name]
            except KeyError:
                raise KeyError("Missing parameter for kernel call " + arg.name)
            expected_type = to_ctypes(arg.dtype)
            ct_arguments.append(expected_type(param))

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
    if len(index_arr_shapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    return ct_arguments


def make_python_function_incomplete_params(kernel_function_node, argument_dict, func):
    parameters = kernel_function_node.parameters

    cache = {}
    cache_values = []

    def wrapper(**kwargs):
        key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                         for k, v in kwargs.items()))
        try:
            args = cache[key]
            func(*args)
        except KeyError:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            args = build_ctypes_argument_list(parameters, full_arguments)
            cache[key] = args
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique
            func(*args)
    wrapper.ast = kernel_function_node
    wrapper.parameters = kernel_function_node.parameters
    return wrapper