From a29833090bcff1c09168b03ca6998e8d7dc9f24f Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 13 Jun 2018 16:19:35 +0200
Subject: [PATCH] Further small-block optimizations

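Folds the extension-module JIT from cpu/cpujit_module.py back into
cpu/cpujit.py: kernels are compiled into small Python extension modules
instead of shared libraries loaded via ctypes, which makes the
'read_from_shared_library' and 'shared_library' cache options obsolete.
ParallelDataHandling now caches the per-block field arrays and reuses
them in run_kernel instead of rebuilding them on every call, and
BoundaryHandling invokes its kernels through DataHandling.run_kernel.

Compiled kernels are now called with keyword arguments only. A minimal
usage sketch (the field names 'src' and 'dst' are illustrative):

    from pystencils.cpu.cpujit import make_python_function

    kernel = make_python_function(ast)  # 'ast' is a kernel function AST
    kernel(src=src_arr, dst=dst_arr)    # arrays are passed by field name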
---
 boundaries/boundaryhandling.py        |  17 +-
 cpu/cpujit.py                         | 471 ++++++++++++++------------
 cpu/cpujit_module.py                  | 260 --------------
 datahandling/parallel_datahandling.py |  48 ++-
 datahandling/serial_datahandling.py   |   6 +-
 5 files changed, 287 insertions(+), 515 deletions(-)
 delete mode 100644 cpu/cpujit_module.py

diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py
index a5ea40a47..e82239f22 100644
--- a/boundaries/boundaryhandling.py
+++ b/boundaries/boundaryhandling.py
@@ -193,7 +193,7 @@ class BoundaryHandling:
         else:
             ff_ghost_layers = self._data_handling.ghost_layers_of_field(self.flag_interface.flag_field_name)
             for b in self._data_handling.iterate(ghost_layers=ff_ghost_layers):
-                for b_obj, setter in b[self._index_array_name].boundary_objectToDataSetter.items():
+                for b_obj, setter in b[self._index_array_name].boundary_object_to_data_setter.items():
                     self._boundary_data_initialization(b_obj, setter, **kwargs)
 
     def __call__(self, **kwargs):
@@ -202,14 +202,9 @@ class BoundaryHandling:
 
         for b in self._data_handling.iterate(gpu=self._target == 'gpu'):
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
-                kwargs[self._field_name] = b[self._field_name]
                 kwargs['indexField'] = idx_arr
-                data_used_in_kernel = (p.field_name
-                                       for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-                                       if p.is_field_ptr_argument and p.field_name not in kwargs)
-                kwargs.update({name: b[name] for name in data_used_in_kernel})
-
-                self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
+                kernel = self._boundary_object_to_boundary_info[b_obj].kernel
+                self._data_handling.run_kernel(kernel, **kwargs)
 
     def geometry_to_vtk(self, file_name='geometry', boundaries='all', ghost_layers=False):
         """
@@ -273,7 +268,7 @@ class BoundaryHandling:
 
                 boundary_data_setter = BoundaryDataSetter(idx_arr, b.offset, self.stencil, ff_ghost_layers, pdf_arr)
                 index_array_bd.boundary_object_to_index_list[b_info.boundary_object] = idx_arr
-                index_array_bd.boundary_objectToDataSetter[b_info.boundary_object] = boundary_data_setter
+                index_array_bd.boundary_object_to_data_setter[b_info.boundary_object] = boundary_data_setter
                 self._boundary_data_initialization(b_info.boundary_object, boundary_data_setter)
 
     def _boundary_data_initialization(self, boundary_obj, boundary_data_setter, **kwargs):
@@ -291,11 +286,11 @@ class BoundaryHandling:
     class IndexFieldBlockData:
         def __init__(self, *_1, **_2):
             self.boundary_object_to_index_list = {}
-            self.boundary_objectToDataSetter = {}
+            self.boundary_object_to_data_setter = {}
 
         def clear(self):
             self.boundary_object_to_index_list.clear()
-            self.boundary_objectToDataSetter.clear()
+            self.boundary_object_to_data_setter.clear()
 
         @staticmethod
         def to_cpu(gpu_version, cpu_version):
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index 35afae599..f014ec7e2 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -42,50 +42,27 @@ Then 'cl.exe' is used to compile.
 - **'restrict_qualifier'**: the restrict qualifier is not standardized across compilers.
   For Windows compilers the qualifier should be ``__restrict``
 
-
-Cache Config
-------------
-
-*pystencils* uses a directory to store intermediate files like the generated C++ files, compiled object files and
-the shared libraries which are then loaded from Python using ctypes. The file names are SHA hashes of the
-generated code. If the same kernel was already compiled, the existing object file is used - no recompilation is done.
-
-If 'shared_library' is specified, all kernels that are currently in the cache are compiled into a single shared library.
-This mechanism can be used to run *pystencils* on systems where compilation is not possible, e.g. on clusters where
-compilation on the compute nodes is not possible.
-First the script is run on a system where compilation is possible (e.g. the login node) with
-'read_from_shared_library=False' and with 'shared_library' set a valid path.
-All kernels generated during the run are put into the cache and at the end
-compiled into the shared library. Then, the same script can be run from the compute nodes, with
-'read_from_shared_library=True', such that kernels are taken from the library instead of compiling them.
-
-
-- **'read_from_shared_library'**: if true kernels are not compiled but assumed to be in the shared library
-- **'object_cache'**: path to a folder where intermediate files are stored
-- **'clear_cache_on_start'**: when true the cache is cleared on each start of a *pystencils* script
-- **'shared_library'**: path to a shared library file, which is created if 'read_from_shared_library=false'
 """
-from __future__ import print_function
 import os
-import subprocess
 import hashlib
 import json
 import platform
-import glob
-import atexit
 import shutil
+import textwrap
 import numpy as np
+import functools
+import subprocess
 from appdirs import user_config_dir, user_cache_dir
-from ctypes import cdll
-from pystencils.backends.cbackend import generate_c, get_headers
 from collections import OrderedDict
-from pystencils.transformations import symbol_name_to_variable_name
-from pystencils.data_types import to_ctypes, get_base_type, StructType
-from pystencils.field import FieldType
-from pystencils.utils import recursive_dict_update, file_handle_for_atomic_write, atomic_file_write
+from sysconfig import get_paths
+from pystencils import FieldType
+from pystencils.data_types import get_base_type
+from pystencils.backends.cbackend import generate_c, get_headers
+from pystencils.utils import recursive_dict_update, file_handle_for_atomic_write, atomic_file_write
 
 
-def make_python_function(kernel_function_node, argument_dict={}):
+def make_python_function(kernel_function_node, argument_dict=None):
     """
     Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function
 
@@ -98,18 +75,10 @@ def make_python_function(kernel_function_node, argument_dict={}):
                         returned kernel functor.
     :return: kernel functor
     """
-    from pystencils.cpu.cpujit_module import make_python_function as otherimpl
-    return otherimpl(kernel_function_node, argument_dict)
-
-    # build up list of CType arguments
-    func = compile_and_load(kernel_function_node)
-    func.restype = None
-    try:
-        args = build_ctypes_argument_list(kernel_function_node.parameters, argument_dict)
-    except KeyError:
-        # not all parameters specified yet
-        return make_python_function_incomplete_params(kernel_function_node, argument_dict, func)
-    return lambda: func(*args)
+    result = compile_and_load(kernel_function_node)
+    if argument_dict:
+        result = functools.partial(result, **argument_dict)
+    return result
 
 
 def set_config(config):
@@ -186,10 +155,8 @@ def read_config():
             ('restrict_qualifier', '__restrict')
         ])
     default_cache_config = OrderedDict([
-        ('read_from_shared_library', False),
         ('object_cache', os.path.join(user_cache_dir('pystencils'), 'objectcache')),
         ('clear_cache_on_start', False),
-        ('shared_library', os.path.join(user_cache_dir('pystencils'), 'cache.so')),
     ])
 
     default_config = OrderedDict([('compiler', default_compiler_config),
@@ -205,14 +172,12 @@ def read_config():
         create_folder(config_path, True)
         json.dump(config, open(config_path, 'w'), indent=4)
 
-    config['cache']['shared_library'] = os.path.expanduser(config['cache']['shared_library']).format(pid=os.getpid())
     config['cache']['object_cache'] = os.path.expanduser(config['cache']['object_cache']).format(pid=os.getpid())
 
     if config['cache']['clear_cache_on_start']:
         clear_cache()
 
     create_folder(config['cache']['object_cache'], False)
-    create_folder(config['cache']['shared_library'], True)
 
     if 'env' not in config['compiler']:
         config['compiler']['env'] = {}
@@ -247,49 +212,175 @@ def clear_cache():
     create_folder(cache_config['object_cache'], False)
 
 
-def compile_object_cache_to_shared_library():
-    compiler_config = get_compiler_config()
-    cache_config = get_cache_config()
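+# maps NumPy scalar types to the CPython conversion function and C type used in the generated wrapper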
+type_mapping = {
+    np.float32: ('PyFloat_AsDouble', 'float'),
+    np.float64: ('PyFloat_AsDouble', 'double'),
+    np.int16: ('PyLong_AsLong', 'int16_t'),
+    np.int32: ('PyLong_AsLong', 'int32_t'),
+    np.int64: ('PyLong_AsLong', 'int64_t'),
+    np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
+    np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
+    np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
+}
 
-    shared_library = cache_config['shared_library']
-    if len(shared_library) == 0 or cache_config['read_from_shared_library']:
-        return
 
-    config_env = compiler_config['env'] if 'env' in compiler_config else {}
-    compile_environment = os.environ.copy()
-    compile_environment.update(config_env)
+template_extract_scalar = """
+PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
+if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
+{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
+if( PyErr_Occurred() ) {{ return NULL; }}
+"""
 
-    try:
-        if compiler_config['os'] == 'windows':
-            all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.obj'))
-            link_cmd = ['link.exe', '/DLL', '/out:' + shared_library]
+template_extract_array = """
+PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
+if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
+Py_buffer buffer_{name};
+int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE);
+if (buffer_{name}_res == -1) {{ return NULL; }}
+"""
+
+template_release_buffer = """
+PyBuffer_Release(&buffer_{name});
+"""
+
+template_function_boilerplate = """
+static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
+{{
+    if( !kwargs || !PyDict_Check(kwargs) ) {{ PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); return NULL; }}
+    {pre_call_code}
+    kernel_{func_name}({parameters});
+    {post_call_code}
+    Py_RETURN_NONE;
+}}
+"""
+
+template_check_array = """
+if(!({cond})) {{ 
+    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
+    return NULL; 
+}}
+"""
+
+template_size_check = """
+if(!({cond})) {{ 
+    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
+}}"""
+
+template_module_boilerplate = """
+static PyMethodDef method_definitions[] = {{
+    {method_definitions}
+    {{NULL, NULL, 0, NULL}}
+}};
+
+static struct PyModuleDef module_definition = {{
+    PyModuleDef_HEAD_INIT,
+    "{module_name}",   /* name of module */
+    NULL,     /* module documentation, may be NULL */
+    -1,       /* size of per-interpreter state of the module,
+                 or -1 if the module keeps state in global variables. */
+    method_definitions
+}};
+
+PyMODINIT_FUNC
+PyInit_{module_name}(void)
+{{
+    return PyModule_Create(&module_definition);
+}}
+"""
+
+
+def equal_size_check(fields):
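+    """Returns C code checking that all passed variable-sized fields have the same spatial shape."""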
+    fields = list(fields)
+    if len(fields) <= 1:
+        return ""
+
+    ref_field = fields[0]
+    cond = ["({field.name}_shape[{i}] == {ref_field.name}_shape[{i}])".format(ref_field=ref_field,
+                                                                              field=field_to_test, i=i)
+            for field_to_test in fields[1:]
+            for i in range(fields[0].spatial_dimensions)]
+    cond = " && ".join(cond)
+    return template_size_check.format(cond=cond)
+
+
+def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
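+    """Generates the C wrapper that extracts keyword arguments, checks array shapes/strides and calls the kernel."""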
+    pre_call_code = ""
+    parameters = []
+    post_call_code = ""
+    variable_sized_normal_fields = set()
+    variable_sized_index_fields = set()
+
+    for arg in parameter_info:
+        if arg.is_field_argument:
+            if arg.is_field_ptr_argument:
+                pre_call_code += template_extract_array.format(name=arg.field_name)
+                post_call_code += template_release_buffer.format(name=arg.field_name)
+                parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(arg.field.dtype),
+                                                                        name=arg.field_name))
+
+                shapes = ", ".join(["buffer_{name}.shape[{i}]".format(name=arg.field_name, i=i)
+                                    for i in range(len(arg.field.strides))])
+                pre_call_code += "Py_ssize_t {name}_shape[] = {{ {elements} }};\n".format(name=arg.field_name,
+                                                                                          elements=shapes)
+
+                item_size = get_base_type(arg.dtype).numpy_dtype.itemsize
+                strides = ["buffer_{name}.strides[{i}] / {bytes}".format(i=i, name=arg.field_name, bytes=item_size)
+                           for i in range(len(arg.field.strides))]
+                strides = ", ".join(strides)
+                pre_call_code += "Py_ssize_t {name}_strides[] = {{ {elements} }};\n".format(name=arg.field_name,
+                                                                                            elements=strides)
+
+                if insert_checks and arg.field.has_fixed_shape:
+                    shape_cond = ["{name}_shape[{i}] == {s}".format(s=s, name=arg.field_name, i=i)
+                                  for i, s in enumerate(arg.field.spatial_shape)]
+                    shape_cond = " && ".join(shape_cond)
+                    pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=arg.field.name,
+                                                                 expected=str(arg.field.shape))
+
+                    strides_cond = ["({name}_strides[{i}] == {s} || {name}_shape[{i}]<=1)".format(s=s, i=i,
+                                                                                                  name=arg.field_name)
+                                    for i, s in enumerate(arg.field.spatial_strides)]
+                    strides_cond = " && ".join(strides_cond)
+                    expected_strides_str = str([e * item_size for e in arg.field.strides])
+                    pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=arg.field.name,
+                                                                 expected=expected_strides_str)
+                if insert_checks and not arg.field.has_fixed_shape:
+                    if FieldType.is_generic(arg.field):
+                        variable_sized_normal_fields.add(arg.field)
+                    elif FieldType.is_indexed(arg.field):
+                        variable_sized_index_fields.add(arg.field)
+
+            elif arg.is_field_shape_argument:
+                parameters.append("{name}_shape".format(name=arg.field_name))
+            elif arg.is_field_stride_argument:
+                parameters.append("{name}_strides".format(name=arg.field_name))
         else:
-            all_object_files = glob.glob(os.path.join(cache_config['object_cache'], '*.o'))
-            link_cmd = [compiler_config['command'], '-shared', '-o', shared_library]
+            extract_function, target_type = type_mapping[arg.dtype.numpy_dtype.type]
+            pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
+                                                            name=arg.name)
+            parameters.append(arg.name)
 
-        link_cmd += all_object_files
-        if len(all_object_files) > 0:
-            run_compile_step(link_cmd)
-    except subprocess.CalledProcessError as e:
-        print(e.output)
-        raise e
+    pre_call_code += equal_size_check(variable_sized_normal_fields)
+    pre_call_code += equal_size_check(variable_sized_index_fields)
 
+    pre_call_code = textwrap.indent(pre_call_code, '    ')
+    post_call_code = textwrap.indent(post_call_code, '    ')
+    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
+                                                post_call_code=post_call_code, parameters=", ".join(parameters))
 
-#atexit.register(compile_object_cache_to_shared_library)
 
+def create_module_boilerplate_code(module_name, names):
+    method_definition = '{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
+    method_definitions = "\n".join([method_definition.format(name=name) for name in names])
+    return template_module_boilerplate.format(module_name=module_name, method_definitions=method_definitions)
 
-def generate_code(ast, restrict_qualifier, function_prefix, source_file):
-    headers = get_headers(ast)
-    headers.update(['<math.h>', '<stdint.h>'])
 
-    code = generate_c(ast)
-    includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
-    print(includes, file=source_file)
-    print("#define RESTRICT %s" % (restrict_qualifier,), file=source_file)
-    print("#define FUNC_PREFIX %s" % (function_prefix,), file=source_file)
-    print('extern "C" { ', file=source_file)
-    print(code, file=source_file)
-    print('}', file=source_file)
+def load_kernel_from_file(module_name, function_name, path):
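+    """Imports the compiled extension module at 'path' and returns its kernel function."""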
+    from importlib.util import spec_from_file_location, module_from_spec
+    spec = spec_from_file_location(name=module_name, location=path)
+    mod = module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return getattr(mod, function_name)
 
 
 def run_compile_step(command):
@@ -307,156 +398,88 @@ def run_compile_step(command):
         raise e
 
 
-def compile_linux(ast, code_hash_str, src_file, lib_file, with_python_include_path=False):
-    cache_config = get_cache_config()
-    compiler_config = get_compiler_config()
-    extra_flags = []
-    if with_python_include_path:
-        from sysconfig import get_paths
-        extra_flags = ['-I' + get_paths()['include']]
-    object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.o')
-    if not os.path.exists(object_file):
-        with file_handle_for_atomic_write(src_file) as f:
-            generate_code(ast, compiler_config['restrict_qualifier'], '', f)
-        with atomic_file_write(object_file) as file_name:
-            compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
-            compile_cmd += [*extra_flags, '-o', file_name, src_file]
-            run_compile_step(compile_cmd)
-
-    # Linking
-    with atomic_file_write(lib_file) as file_name:
-        run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] +
-                         compiler_config['flags'].split())
-
-
-def compile_windows(ast, code_hash_str, src_file, lib_file, with_python_include_path=False):
-    cache_config = get_cache_config()
-    compiler_config = get_compiler_config()
-    extra_flags = []
-    if with_python_include_path:
-        from sysconfig import get_paths
-        extra_flags = ['/I' + get_paths()['include']]
+class ExtensionModuleCode:
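+    """Collects kernel ASTs and renders them as the source of one Python extension module."""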
+    def __init__(self, module_name='generated'):
+        self.module_name = module_name
 
-    object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.obj')
-    # Compilation
-    if not os.path.exists(object_file):
-        with file_handle_for_atomic_write(src_file) as f:
-            generate_code(ast, compiler_config['restrict_qualifier'], '__declspec(dllexport)', f)
+        self._ast_nodes = []
+        self._function_names = []
 
-        # /c compiles only, /EHsc turns of exception handling in c code
-        compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
-        compile_cmd += [*extra_flags, src_file, '/Fo' + object_file]
-        run_compile_step(compile_cmd)
+    def add_function(self, ast, name=None):
+        self._ast_nodes.append(ast)
+        self._function_names.append(name if name is not None else ast.function_name)
 
-    # Linking
-    run_compile_step(['link.exe', '/DLL', '/out:' + lib_file, object_file])
+    def write_to_file(self, restrict_qualifier, function_prefix, file):
+        headers = {'<math.h>', '<stdint.h>'}
+        for ast in self._ast_nodes:
+            headers.update(get_headers(ast))
+
+        # Python.h has to be included before any standard headers
+        print('#include "Python.h"', file=file)
+        includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
+        print(includes, file=file)
+        print("\n", file=file)
+        print("#define RESTRICT %s" % (restrict_qualifier,), file=file)
+        print("#define FUNC_PREFIX %s" % (function_prefix,), file=file)
+        print("\n", file=file)
 
-def compile_and_load(ast):
-    cache_config = get_cache_config()
+        for ast, name in zip(self._ast_nodes, self._function_names):
+            old_name = ast.function_name
+            ast.function_name = "kernel_" + name
+            print(generate_c(ast), file=file)
+            print(create_function_boilerplate_code(ast.parameters, name), file=file)
+            ast.function_name = old_name
+        print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
 
-    code_hash_str = hashlib.sha256(generate_c(ast).encode()).hexdigest()
-    ast.function_name = hash_to_function_name(code_hash_str)
 
-    src_file = os.path.join(cache_config['object_cache'], code_hash_str + ".cpp")
+class KernelWrapper:
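+    """Light-weight callable around a compiled kernel that also exposes its parameters and AST."""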
+    def __init__(self, kernel, parameters, ast_node):
+        self.kernel = kernel
+        self.parameters = parameters
+        self.ast = ast_node
 
-    if cache_config['read_from_shared_library']:
-        return cdll.LoadLibrary(cache_config['shared_library'])[ast.function_name]
-    else:
-        if get_compiler_config()['os'].lower() == 'windows':
-            lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".dll")
-            if not os.path.exists(lib_file):
-                compile_windows(ast, code_hash_str, src_file, lib_file)
-        else:
-            lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".so")
-            if not os.path.exists(lib_file):
-                compile_linux(ast, code_hash_str, src_file, lib_file)
-        return cdll.LoadLibrary(lib_file)[ast.function_name]
-
-
-def build_ctypes_argument_list(parameter_specification, argument_dict):
-    argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
-    ct_arguments = []
-    array_shapes = set()
-    index_arr_shapes = set()
+    def __call__(self, **kwargs):
+        return self.kernel(**kwargs)
 
-    for arg in parameter_specification:
-        if arg.is_field_argument:
-            try:
-                field_arr = argument_dict[arg.field_name]
-            except KeyError:
-                raise KeyError("Missing field parameter for kernel call " + arg.field_name)
 
-            symbolic_field = arg.field
-            if arg.is_field_ptr_argument:
-                ct_arguments.append(field_arr.ctypes.data_as(to_ctypes(arg.dtype)))
-                if symbolic_field.has_fixed_shape:
-                    symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
-                    if isinstance(symbolic_field.dtype, StructType):
-                        symbolic_field_shape = symbolic_field_shape[:-1]
-                    if symbolic_field_shape != field_arr.shape:
-                        raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
-                                         (arg.field_name, str(field_arr.shape), str(symbolic_field.shape)))
-                if symbolic_field.has_fixed_shape:
-                    symbolic_field_strides = tuple(int(i) * field_arr.itemsize for i in symbolic_field.strides)
-                    if isinstance(symbolic_field.dtype, StructType):
-                        symbolic_field_strides = symbolic_field_strides[:-1]
-                    if symbolic_field_strides != field_arr.strides:
-                        raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
-                                         (arg.field_name, str(field_arr.strides), str(symbolic_field_strides)))
-
-                if FieldType.is_indexed(symbolic_field):
-                    index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
-                elif FieldType.is_generic(symbolic_field):
-                    array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
+def compile_and_load(ast):
+    cache_config = get_cache_config()
+    compiler_config = get_compiler_config()
 
-            elif arg.is_field_shape_argument:
-                data_type = to_ctypes(get_base_type(arg.dtype))
-                ct_arguments.append(field_arr.ctypes.shape_as(data_type))
-            elif arg.is_field_stride_argument:
-                data_type = to_ctypes(get_base_type(arg.dtype))
-                strides = field_arr.ctypes.strides_as(data_type)
-                for i in range(len(field_arr.shape)):
-                    assert strides[i] % field_arr.itemsize == 0
-                    strides[i] //= field_arr.itemsize
-                ct_arguments.append(strides)
-            else:
-                assert False
-        else:
-            try:
-                param = argument_dict[arg.name]
-            except KeyError:
-                raise KeyError("Missing parameter for kernel call " + arg.name)
-            expected_type = to_ctypes(arg.dtype)
-            ct_arguments.append(expected_type(param))
-
-    if len(array_shapes) > 1:
-        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
-    if len(index_arr_shapes) > 1:
-        raise ValueError("All passed index arrays have to have the same size " + str(array_shapes))
-
-    return ct_arguments
-
-
-def make_python_function_incomplete_params(kernel_function_node, argument_dict, func):
-    parameters = kernel_function_node.parameters
-
-    cache = {}
-    cache_values = []
-
-    def wrapper(**kwargs):
-        key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
-                         for k, v in kwargs.items()))
-        try:
-            args = cache[key]
-            func(*args)
-        except KeyError:
-            full_arguments = argument_dict.copy()
-            full_arguments.update(kwargs)
-            args = build_ctypes_argument_list(parameters, full_arguments)
-            cache[key] = args
-            cache_values.append(kwargs)  # keep objects alive such that ids remain unique
-            func(*args)
-    wrapper.ast = kernel_function_node
-    wrapper.parameters = kernel_function_node.parameters
-    return wrapper
+    if compiler_config['os'].lower() == 'windows':
+        function_prefix = '__declspec(dllexport)'
+        lib_suffix = '.dll'
+        object_suffix = '.obj'
+        windows = True
+    else:
+        function_prefix = ''
+        lib_suffix = '.so'
+        object_suffix = '.o'
+        windows = False
+
+    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast).encode()).hexdigest()
+    code = ExtensionModuleCode(module_name=code_hash_str)
+    code.add_function(ast, ast.function_name)
+    src_file = os.path.join(cache_config['object_cache'], code_hash_str + ".cpp")
+    lib_file = os.path.join(cache_config['object_cache'], code_hash_str + lib_suffix)
+    if not os.path.exists(lib_file):
+        # MSVC expects '/I' for include paths, gcc/clang expect '-I'
+        extra_flags = [('/I' if windows else '-I') + get_paths()['include']]
+        object_file = os.path.join(cache_config['object_cache'], code_hash_str + object_suffix)
+        if not os.path.exists(object_file):
+            with file_handle_for_atomic_write(src_file) as f:
+                code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)
+            with atomic_file_write(object_file) as file_name:
+                if windows:
+                    # /c compiles without linking, /EHsc selects the standard C++ exception model
+                    compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
+                    compile_cmd += [*extra_flags, src_file, '/Fo' + file_name]
+                else:
+                    compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
+                    compile_cmd += [*extra_flags, '-o', file_name, src_file]
+                run_compile_step(compile_cmd)
+
+        # Linking
+        with atomic_file_write(lib_file) as file_name:
+            if windows:
+                run_compile_step(['link.exe', '/DLL', '/out:' + file_name, object_file])
+            else:
+                run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] +
+                                 compiler_config['flags'].split())
+
+    result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
+    return KernelWrapper(result, ast.parameters, ast)
diff --git a/cpu/cpujit_module.py b/cpu/cpujit_module.py
deleted file mode 100644
index 8231770b5..000000000
--- a/cpu/cpujit_module.py
+++ /dev/null
@@ -1,260 +0,0 @@
-import os
-import textwrap
-import hashlib
-import numpy as np
-from sysconfig import get_paths
-from pystencils import FieldType
-from pystencils.cpu.cpujit import run_compile_step
-from pystencils.data_types import get_base_type
-from pystencils.backends.cbackend import generate_c, get_headers
-from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
-
-type_mapping = {
-    np.float32: ('PyFloat_AsDouble', 'float'),
-    np.float64: ('PyFloat_AsDouble', 'double'),
-    np.int16: ('PyLong_AsLong', 'int16_t'),
-    np.int32: ('PyLong_AsLong', 'int32_t'),
-    np.int64: ('PyLong_AsLong', 'int64_t'),
-    np.uint16: ('PyLong_AsUnsignedLong', 'uint16_t'),
-    np.uint32: ('PyLong_AsUnsignedLong', 'uint32_t'),
-    np.uint64: ('PyLong_AsUnsignedLong', 'uint64_t'),
-}
-
-
-template_extract_scalar = """
-PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
-if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
-{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
-if( PyErr_Occurred() ) {{ return NULL; }}
-"""
-
-template_extract_array = """
-PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
-if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
-Py_buffer buffer_{name};
-int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE);
-if (buffer_{name}_res == -1) {{ return NULL; }}
-"""
-
-template_release_buffer = """
-PyBuffer_Release(&buffer_{name});
-"""
-
-template_function_boilerplate = """
-static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
-{{
-    if( !kwargs || !PyDict_Check(kwargs) ) {{ PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); return NULL; }}
-    {pre_call_code}
-    kernel_{func_name}({parameters});
-    {post_call_code}
-    Py_RETURN_NONE;
-}}
-"""
-
-template_check_array = """
-if(!({cond})) {{ 
-    PyErr_SetString(PyExc_ValueError, "Wrong {what} of array {name}. Expected {expected}"); 
-    return NULL; 
-}}
-"""
-
-template_size_check = """
-if(!({cond})) {{ 
-    PyErr_SetString(PyExc_TypeError, "Arrays must have same shape"); return NULL; 
-}}"""
-
-template_module_boilerplate = """
-static PyMethodDef method_definitions[] = {{
-    {method_definitions}
-    {{NULL, NULL, 0, NULL}}
-}};
-
-static struct PyModuleDef module_definition = {{
-    PyModuleDef_HEAD_INIT,
-    "{module_name}",   /* name of module */
-    NULL,     /* module documentation, may be NULL */
-    -1,       /* size of per-interpreter state of the module,
-                 or -1 if the module keeps state in global variables. */
-    method_definitions
-}};
-
-PyMODINIT_FUNC
-PyInit_{module_name}(void)
-{{
-    return PyModule_Create(&module_definition);
-}}
-"""
-
-
-def equal_size_check(fields):
-    fields = list(fields)
-    if len(fields) <= 1:
-        return ""
-
-    ref_field = fields[0]
-    cond = ["({field.name}_shape[{i}] == {ref_field.name}_shape[{i}])".format(ref_field=ref_field,
-                                                                              field=field_to_test, i=i)
-            for field_to_test in fields[1:]
-            for i in range(fields[0].spatial_dimensions)]
-    cond = " && ".join(cond)
-    return template_size_check.format(cond=cond)
-
-
-def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
-    pre_call_code = ""
-    parameters = []
-    post_call_code = ""
-    variable_sized_normal_fields = set()
-    variable_sized_index_fields = set()
-
-    for arg in parameter_info:
-        if arg.is_field_argument:
-            if arg.is_field_ptr_argument:
-                pre_call_code += template_extract_array.format(name=arg.field_name)
-                post_call_code += template_release_buffer.format(name=arg.field_name)
-                parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(arg.field.dtype),
-                                                                        name=arg.field_name))
-
-                shapes = ", ".join(["buffer_{name}.shape[{i}]".format(name=arg.field_name, i=i)
-                                    for i in range(len(arg.field.strides))])
-                pre_call_code += "Py_ssize_t {name}_shape[] = {{ {elements} }};\n".format(name=arg.field_name,
-                                                                                          elements=shapes)
-
-                item_size = get_base_type(arg.dtype).numpy_dtype.itemsize
-                strides = ["buffer_{name}.strides[{i}] / {bytes}".format(i=i, name=arg.field_name, bytes=item_size)
-                           for i in range(len(arg.field.strides))]
-                strides = ", ".join(strides)
-                pre_call_code += "Py_ssize_t {name}_strides[] = {{ {elements} }};\n".format(name=arg.field_name,
-                                                                                            elements=strides)
-
-                if insert_checks and arg.field.has_fixed_shape:
-                    shape_cond = ["{name}_shape[{i}] == {s}".format(s=s, name=arg.field_name, i=i)
-                                  for i, s in enumerate(arg.field.spatial_shape)]
-                    shape_cond = " && ".join(shape_cond)
-                    pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=arg.field.name,
-                                                                 expected=str(arg.field.shape))
-
-                    strides_cond = ["({name}_strides[{i}] == {s} || {name}_shape[{i}]<=1)".format(s=s, i=i,
-                                                                                                  name=arg.field_name)
-                                    for i, s in enumerate(arg.field.spatial_strides)]
-                    strides_cond = " && ".join(strides_cond)
-                    expected_strides_str = str([e * item_size for e in arg.field.strides])
-                    pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=arg.field.name,
-                                                                 expected=expected_strides_str)
-                if insert_checks and not arg.field.has_fixed_shape:
-                    if FieldType.is_generic(arg.field):
-                        variable_sized_normal_fields.add(arg.field)
-                    elif FieldType.is_indexed(arg.field):
-                        variable_sized_index_fields.add(arg.field)
-
-            elif arg.is_field_shape_argument:
-                parameters.append("{name}_shape".format(name=arg.field_name))
-            elif arg.is_field_stride_argument:
-                parameters.append("{name}_strides".format(name=arg.field_name))
-        else:
-            extract_function, target_type = type_mapping[arg.dtype.numpy_dtype.type]
-            pre_call_code += template_extract_scalar.format(extract_function=extract_function, target_type=target_type,
-                                                            name=arg.name)
-            parameters.append(arg.name)
-
-    pre_call_code += equal_size_check(variable_sized_normal_fields)
-    pre_call_code += equal_size_check(variable_sized_index_fields)
-
-    pre_call_code = textwrap.indent(pre_call_code, '    ')
-    post_call_code = textwrap.indent(post_call_code, '    ')
-    return template_function_boilerplate.format(func_name=name, pre_call_code=pre_call_code,
-                                                post_call_code=post_call_code, parameters=", ".join(parameters))
-
-
-def create_module_boilerplate_code(module_name, names):
-    method_definition = '{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
-    method_definitions = "\n".join([method_definition.format(name=name) for name in names])
-    return template_module_boilerplate.format(module_name=module_name, method_definitions=method_definitions)
-
-
-def load_kernel_from_file(module_name, function_name, path):
-    from importlib.util import spec_from_file_location, module_from_spec
-    spec = spec_from_file_location(name=module_name, location=path)
-    mod = module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return getattr(mod, function_name)
-
-
-class ExtensionModuleCode:
-    def __init__(self, module_name='generated'):
-        self.module_name = module_name
-
-        self._ast_nodes = []
-        self._function_names = []
-
-    def add_function(self, ast, name=None):
-        self._ast_nodes.append(ast)
-        self._function_names.append(name if name is not None else ast.function_name)
-
-    def write_to_file(self, restrict_qualifier, function_prefix, file):
-        headers = {'<math.h>', '<stdint.h>', '"Python.h"'}
-        for ast in self._ast_nodes:
-            headers.update(get_headers(ast))
-
-        includes = "\n".join(["#include %s" % (include_file,) for include_file in headers])
-        print(includes, file=file)
-        print("\n", file=file)
-        print("#define RESTRICT %s" % (restrict_qualifier,), file=file)
-        print("#define FUNC_PREFIX %s" % (function_prefix,), file=file)
-        print("\n", file=file)
-
-        for ast, name in zip(self._ast_nodes, self._function_names):
-            old_name = ast.function_name
-            ast.function_name = "kernel_" + name
-            print(generate_c(ast), file=file)
-            print(create_function_boilerplate_code(ast.parameters, name), file=file)
-            ast.function_name = old_name
-        print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
-
-
-class KernelWrapper:
-    def __init__(self, kernel, parameters, ast_node):
-        self.kernel = kernel
-        self.parameters = parameters
-        self.ast = ast_node
-
-    def __call__(self, **kwargs):
-        return self.kernel(**kwargs)
-
-
-def compile_and_load(ast):
-    from pystencils.cpu.cpujit import get_cache_config, get_compiler_config
-
-    cache_config = get_cache_config()
-    code_hash_str = "mod_" + hashlib.sha256(generate_c(ast).encode()).hexdigest()
-    code = ExtensionModuleCode(module_name=code_hash_str)
-    code.add_function(ast, ast.function_name)
-    src_file = os.path.join(cache_config['object_cache'], code_hash_str + ".cpp")
-    lib_file = os.path.join(cache_config['object_cache'], code_hash_str + ".so")
-    if not os.path.exists(lib_file):
-        compiler_config = get_compiler_config()
-        extra_flags = ['-I' + get_paths()['include']]
-        object_file = os.path.join(cache_config['object_cache'], code_hash_str + '.o')
-        if not os.path.exists(object_file):
-            with file_handle_for_atomic_write(src_file) as f:
-                code.write_to_file(compiler_config['restrict_qualifier'], '', f)
-            with atomic_file_write(object_file) as file_name:
-                compile_cmd = [compiler_config['command'], '-c'] + compiler_config['flags'].split()
-                compile_cmd += [*extra_flags, '-o', file_name, src_file]
-                run_compile_step(compile_cmd)
-
-        # Linking
-        with atomic_file_write(lib_file) as file_name:
-            run_compile_step([compiler_config['command'], '-shared', object_file, '-o', file_name] +
-                             compiler_config['flags'].split())
-
-    result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
-    return KernelWrapper(result, ast.parameters, ast)
-
-
-def make_python_function(kernel_function_node, argument_dict=None):
-    import functools
-    result = compile_and_load(kernel_function_node)
-    if argument_dict:
-        result = functools.partial(result, **argument_dict)
-    return result
diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py
index 6ab1abe0f..a5d88c6d9 100644
--- a/datahandling/parallel_datahandling.py
+++ b/datahandling/parallel_datahandling.py
@@ -135,6 +135,8 @@ class ParallelDataHandling(DataHandling):
         self._field_name_to_cpu_data_name[name] = name
         if gpu:
             self._field_name_to_gpu_data_name[name] = self.GPU_DATA_PREFIX + name
+
+        self._rebuild_data_cache()
         return self.fields[name]
 
     def has_data(self, name):
@@ -154,8 +156,15 @@ class ParallelDataHandling(DataHandling):
 
     def swap(self, name1, name2, gpu=False):
         if gpu:
+            for d in self._data_cache_gpu:
+                d[name1], d[name2] = d[name2], d[name1]
+
             name1 = self.GPU_DATA_PREFIX + name1
             name2 = self.GPU_DATA_PREFIX + name2
+        else:
+            for d in self._data_cache_cpu:
+                d[name1], d[name2] = d[name2], d[name1]
+
         for block in self.blocks:
             block[name1].swapDataPointers(block[name2])
 
@@ -213,24 +222,31 @@ class ParallelDataHandling(DataHandling):
             arr = arr[:, :, 0]
         return arr
 
-    def run_kernel(self, kernel_function, *args, **kwargs):
+    def _rebuild_data_cache(self):
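+        """Caches the field-name -> array mapping per block, so run_kernel does not rebuild it on every call."""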
+        self._data_cache_cpu = []
+        self._data_cache_gpu = []
+
+        elements = [(self._data_cache_cpu, wlb.field.toArray, self._field_name_to_cpu_data_name)]
+        if self._field_name_to_gpu_data_name:
+            elements.append((self._data_cache_gpu, wlb.cuda.toGpuArray, self._field_name_to_gpu_data_name))
+
+        for cache, to_array, name_to_data_name in elements:
+            for block in self.blocks:
+                field_args = {}
+                for field_name, data_name in name_to_data_name.items():
+                    field = self.fields[field_name]
+                    arr = to_array(block[data_name], withGhostLayers=[True, True, self.dim == 3])
+                    arr = self._normalize_arr_shape(arr, field.index_dimensions)
+                    field_args[field_name] = arr
+                cache.append(field_args)
+
+    def run_kernel(self, kernel_function, **kwargs):
         if kernel_function.ast.backend == 'gpucuda':
-            name_map = self._field_name_to_gpu_data_name
-            to_array = wlb.cuda.toGpuArray
+            for d in self._data_cache_gpu:
+                kernel_function(**{**d, **kwargs})  # explicitly passed kwargs override cached arrays
         else:
-            name_map = self._field_name_to_cpu_data_name
-            to_array = wlb.field.toArray
-        data_used_in_kernel = [(name_map[p.field_name], self.fields[p.field_name])
-                               for p in kernel_function.parameters if
-                               p.is_field_ptr_argument and p.field_name not in kwargs]
-        for block in self.blocks:
-            field_args = {}
-            for data_name, f in data_used_in_kernel:
-                arr = to_array(block[data_name], withGhostLayers=[True, True, self.dim == 3])
-                arr = self._normalize_arr_shape(arr, f.index_dimensions)
-                field_args[f.name] = arr
-            field_args.update(kwargs)
-            kernel_function(*args, **field_args)
+            for d in self._data_cache_cpu:
+                kernel_function(**{**d, **kwargs})
 
     def to_cpu(self, name):
         if name in self._custom_data_transfer_functions:
diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py
index 3a365675b..c6026a81a 100644
--- a/datahandling/serial_datahandling.py
+++ b/datahandling/serial_datahandling.py
@@ -196,10 +196,8 @@ class SerialDataHandling(DataHandling):
         return arr
 
     def swap(self, name1, name2, gpu=False):
-        if not gpu:
-            self.cpu_arrays[name1], self.cpu_arrays[name2] = self.cpu_arrays[name2], self.cpu_arrays[name1]
-        else:
-            self.gpu_arrays[name1], self.gpu_arrays[name2] = self.gpu_arrays[name2], self.gpu_arrays[name1]
+        arr = self.gpu_arrays if gpu else self.cpu_arrays
+        arr[name1], arr[name2] = arr[name2], arr[name1]
 
     def all_to_cpu(self):
         for name in (self.cpu_arrays.keys() & self.gpu_arrays.keys()) | self._custom_data_transfer_functions.keys():
-- 
GitLab