From b5f4dc959dedb6f62cff33f10977433f89f98985 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 14 Jun 2018 11:17:07 +0200
Subject: [PATCH] Reverted boundary optimization

- not generic enough - does not work if there a multiple blocks
---
 boundaries/boundaryhandling.py |  9 +++-
 cpu/cpujit.py                  | 18 ++++---
 llvm/llvmjit.py                | 93 +++++++++++++++++++++++++++++++++-
 3 files changed, 108 insertions(+), 12 deletions(-)

diff --git a/boundaries/boundaryhandling.py b/boundaries/boundaryhandling.py
index e82239f22..b38d86656 100644
--- a/boundaries/boundaryhandling.py
+++ b/boundaries/boundaryhandling.py
@@ -202,9 +202,14 @@ class BoundaryHandling:
 
         for b in self._data_handling.iterate(gpu=self._target == 'gpu'):
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
+                kwargs[self._field_name] = b[self._field_name]
                 kwargs['indexField'] = idx_arr
-                kernel = self._boundary_object_to_boundary_info[b_obj].kernel
-                self._data_handling.run_kernel(kernel, **kwargs)
+                data_used_in_kernel = (p.field_name
+                                       for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
+                                       if p.is_field_ptr_argument and p.field_name not in kwargs)
+                kwargs.update({name: b[name] for name in data_used_in_kernel})
+
+                self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
 
     def geometry_to_vtk(self, file_name='geometry', boundaries='all', ghost_layers=False):
         """
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index de832d240..62ed51cda 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -62,7 +62,7 @@ from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.utils import file_handle_for_atomic_write, atomic_file_write
 
 
-def make_python_function(kernel_function_node, argument_dict=None):
+def make_python_function(kernel_function_node):
     """
     Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function
 
@@ -71,13 +71,9 @@ def make_python_function(kernel_function_node, argument_dict=None):
         - all symbols which are not defined in the kernel itself are expected as parameters
 
     :param kernel_function_node: the abstract syntax tree
-    :param argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
-                        returned kernel functor.
     :return: kernel functor
     """
     result = compile_and_load(kernel_function_node)
-    if argument_dict:
-        result = functools.partial(result, **argument_dict)
     return result
 
 
@@ -246,7 +242,10 @@ PyBuffer_Release(&buffer_{name});
 template_function_boilerplate = """
 static PyObject * {func_name}(PyObject * self, PyObject * args, PyObject * kwargs)
 {{
-    if( !kwargs || !PyDict_Check(kwargs) ) {{ PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); return NULL; }}
+    if( !kwargs || !PyDict_Check(kwargs) ) {{ 
+        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
+        return NULL; 
+    }}
     {pre_call_code}
     kernel_{func_name}({parameters});
     {post_call_code}
@@ -320,7 +319,9 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
 
                 shapes = ", ".join(["buffer_{name}.shape[{i}]".format(name=arg.field_name, i=i)
                                     for i in range(len(arg.field.strides))])
-                pre_call_code += "{type} {name}_shape[] = {{ {elements} }};\n".format(type=get_base_type(Field.SHAPE_DTYPE),
+
+                shape_type = get_base_type(Field.SHAPE_DTYPE)
+                pre_call_code += "{type} {name}_shape[] = {{ {elements} }};\n".format(type=shape_type,
                                                                                       name=arg.field_name,
                                                                                       elements=shapes)
 
@@ -328,7 +329,8 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
                 strides = ["buffer_{name}.strides[{i}] / {bytes}".format(i=i, name=arg.field_name, bytes=item_size)
                            for i in range(len(arg.field.strides))]
                 strides = ", ".join(strides)
-                pre_call_code += "{type} {name}_strides[] = {{ {elements} }};\n".format(type=get_base_type(Field.STRIDE_DTYPE),
+                strides_type = get_base_type(Field.STRIDE_DTYPE)
+                pre_call_code += "{type} {name}_strides[] = {{ {elements} }};\n".format(type=strides_type,
                                                                                         name=arg.field_name,
                                                                                         elements=strides)
 
diff --git a/llvm/llvmjit.py b/llvm/llvmjit.py
index c0e1238b8..efc0d9bc5 100644
--- a/llvm/llvmjit.py
+++ b/llvm/llvmjit.py
@@ -3,9 +3,98 @@ import llvmlite.binding as llvm
 import numpy as np
 import ctypes as ct
 from pystencils.data_types import create_composite_type_from_string
-from ..data_types import to_ctypes, ctypes_from_llvm
+from ..data_types import to_ctypes, ctypes_from_llvm, StructType, get_base_type
 from .llvm import generate_llvm
-from ..cpu.cpujit import build_ctypes_argument_list, make_python_function_incomplete_params
+from pystencils.transformations import symbol_name_to_variable_name
+from pystencils.field import FieldType
+
+
+def build_ctypes_argument_list(parameter_specification, argument_dict):
+    argument_dict = {symbol_name_to_variable_name(k): v for k, v in argument_dict.items()}
+    ct_arguments = []
+    array_shapes = set()
+    index_arr_shapes = set()
+
+    for arg in parameter_specification:
+        if arg.is_field_argument:
+            try:
+                field_arr = argument_dict[arg.field_name]
+            except KeyError:
+                raise KeyError("Missing field parameter for kernel call " + arg.field_name)
+
+            symbolic_field = arg.field
+            if arg.is_field_ptr_argument:
+                ct_arguments.append(field_arr.ctypes.data_as(to_ctypes(arg.dtype)))
+                if symbolic_field.has_fixed_shape:
+                    symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
+                    if isinstance(symbolic_field.dtype, StructType):
+                        symbolic_field_shape = symbolic_field_shape[:-1]
+                    if symbolic_field_shape != field_arr.shape:
+                        raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
+                                         (arg.field_name, str(field_arr.shape), str(symbolic_field.shape)))
+                if symbolic_field.has_fixed_shape:
+                    symbolic_field_strides = tuple(int(i) * field_arr.itemsize for i in symbolic_field.strides)
+                    if isinstance(symbolic_field.dtype, StructType):
+                        symbolic_field_strides = symbolic_field_strides[:-1]
+                    if symbolic_field_strides != field_arr.strides:
+                        raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
+                                         (arg.field_name, str(field_arr.strides), str(symbolic_field_strides)))
+
+                if FieldType.is_indexed(symbolic_field):
+                    index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
+                elif FieldType.is_generic(symbolic_field):
+                    array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
+
+            elif arg.is_field_shape_argument:
+                data_type = to_ctypes(get_base_type(arg.dtype))
+                ct_arguments.append(field_arr.ctypes.shape_as(data_type))
+            elif arg.is_field_stride_argument:
+                data_type = to_ctypes(get_base_type(arg.dtype))
+                strides = field_arr.ctypes.strides_as(data_type)
+                for i in range(len(field_arr.shape)):
+                    assert strides[i] % field_arr.itemsize == 0
+                    strides[i] //= field_arr.itemsize
+                ct_arguments.append(strides)
+            else:
+                assert False
+        else:
+            try:
+                param = argument_dict[arg.name]
+            except KeyError:
+                raise KeyError("Missing parameter for kernel call " + arg.name)
+            expected_type = to_ctypes(arg.dtype)
+            ct_arguments.append(expected_type(param))
+
+    if len(array_shapes) > 1:
+        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
+    if len(index_arr_shapes) > 1:
+        raise ValueError("All passed index arrays have to have the same size " + str(array_shapes))
+
+    return ct_arguments
+
+
+def make_python_function_incomplete_params(kernel_function_node, argument_dict, func):
+    parameters = kernel_function_node.parameters
+
+    cache = {}
+    cache_values = []
+
+    def wrapper(**kwargs):
+        key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
+                         for k, v in kwargs.items()))
+        try:
+            args = cache[key]
+            func(*args)
+        except KeyError:
+            full_arguments = argument_dict.copy()
+            full_arguments.update(kwargs)
+            args = build_ctypes_argument_list(parameters, full_arguments)
+            cache[key] = args
+            cache_values.append(kwargs)  # keep objects alive such that ids remain unique
+            func(*args)
+    wrapper.ast = kernel_function_node
+    wrapper.parameters = kernel_function_node.parameters
+    return wrapper
 
 
 def generate_and_jit(ast):
-- 
GitLab