From ca31d3cc3bb2aa0f69554f59a1bdffb84e10ddaa Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 30 Apr 2018 11:16:27 +0200
Subject: [PATCH] Additional tests

---
 astnodes.py               |  4 ++--
 cpu/kernelcreation.py     |  6 ++++--
 gpucuda/cudajit.py        |  1 +
 gpucuda/kernelcreation.py | 16 ++++++++++------
 transformations.py        | 18 ++++++++++++------
 5 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 7ed51daac..9ad56cdcd 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -17,12 +17,12 @@ class Node:
     @property
     def args(self) -> List[NodeOrExpr]:
         """Returns all arguments/children of this node."""
-        return []
+        raise NotImplementedError()
 
     @property
     def symbols_defined(self) -> Set[sp.Symbol]:
         """Set of symbols which are defined by this node."""
-        return set()
+        raise NotImplementedError()
 
     @property
     def undefined_symbols(self) -> Set[sp.Symbol]:
diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index 557f817eb..b9abd61a8 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -70,10 +70,12 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
         split_inner_loop(code, typed_split_groups)
 
     base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
-    base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order, field)
+    base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
+                                                             field.spatial_dimensions, field.index_dimensions)
                          for field in fields_without_buffers}
 
-    buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0], field)
+    buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0],
+                                                                    field.spatial_dimensions, field.index_dimensions)
                                 for field in buffers}
     base_pointer_info.update(buffer_base_pointer_info)
 
diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py
index 544adea58..1127ecadc 100644
--- a/gpucuda/cudajit.py
+++ b/gpucuda/cudajit.py
@@ -61,6 +61,7 @@ def make_python_function(kernel_function_node, argument_dict=None):
         # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
     wrapper.ast = kernel_function_node
     wrapper.parameters = kernel_function_node.parameters
+    wrapper.num_regs = func.num_regs
     return wrapper
 
 
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index bb60dedfa..1fb637cac 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -56,8 +56,10 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
     ast = KernelFunction(block, function_name=function_name, ghost_layers=ghost_layers, backend='gpucuda')
     ast.global_variables.update(indexing.index_variables)
 
-    base_pointer_info = [['spatialInner0']]
-    base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields}
+    base_pointer_spec = [['spatialInner0']]
+    base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
+                                                         f.spatial_dimensions, f.index_dimensions)
+                         for f in all_fields}
 
     coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
 
@@ -71,7 +73,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
         base_buffer_index += var * stride
 
     resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)
-    resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_infos,
+    resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                            field_to_fixed_coordinates=coord_mapping)
 
     substitute_array_accesses_with_constants(ast)
@@ -131,13 +133,15 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel
     ast.global_variables.update(indexing.index_variables)
 
     coord_mapping = indexing.coordinates
-    base_pointer_info = [['spatialInner0']]
-    base_pointer_infos = {f.name: parse_base_pointer_info(base_pointer_info, [2, 1, 0], f) for f in all_fields}
+    base_pointer_spec = [['spatialInner0']]
+    base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
+                                                         f.spatial_dimensions, f.index_dimensions)
+                         for f in all_fields}
 
     coord_mapping = {f.name: coord_mapping for f in index_fields}
     coord_mapping.update({f.name: coordinate_typed_symbols for f in non_index_fields})
     resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=coord_mapping,
-                           field_to_base_pointer_info=base_pointer_infos)
+                           field_to_base_pointer_info=base_pointer_info)
     substitute_array_accesses_with_constants(ast)
 
     # add the function which determines #blocks and #threads as additional member to KernelFunction node
diff --git a/transformations.py b/transformations.py
index 727aae027..946886b31 100644
--- a/transformations.py
+++ b/transformations.py
@@ -191,7 +191,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
     return new_ptr, offset
 
 
-def parse_base_pointer_info(base_pointer_specification, loop_order, field):
+def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
     """
     Creates base pointer specification for :func:`resolve_field_accesses` function.
 
@@ -210,10 +210,16 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
     Args:
         base_pointer_specification: nested list with above specifications
         loop_order: list with ordering of loops from outer to inner
-        field:
+        spatial_dimensions: number of spatial dimensions
+        index_dimensions: number of index dimensions
 
     Returns:
         list of tuples that can be passed to :func:`resolve_field_accesses`
+
+    Examples:
+        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
+        ...                         spatial_dimensions=3, index_dimensions=1)
+        [[0], [3], [1, 2]]
     """
     result = []
     specified_coordinates = set()
@@ -222,7 +228,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
         new_group = []
 
         def add_new_element(elem):
-            if elem >= field.spatial_dimensions + field.index_dimensions:
+            if elem >= spatial_dimensions + index_dimensions:
                 raise ValueError("Coordinate %d does not exist" % (elem,))
             new_group.append(elem)
             if elem in specified_coordinates:
@@ -240,19 +246,19 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, field):
                     index = int(element[len("Outer"):])
                     add_new_element(loop_order[-index])
                 elif element == "all":
-                    for i in range(field.spatial_dimensions):
+                    for i in range(spatial_dimensions):
                         add_new_element(i)
                 else:
                     raise ValueError("Could not parse " + element)
             elif element.startswith("index"):
                 index = int(element[len("index"):])
-                add_new_element(field.spatial_dimensions + index)
+                add_new_element(spatial_dimensions + index)
             else:
                 raise ValueError("Unknown specification %s" % (element,))
 
         result.append(new_group)
 
-    all_coordinates = set(range(field.spatial_dimensions + field.index_dimensions))
+    all_coordinates = set(range(spatial_dimensions + index_dimensions))
     rest = all_coordinates - specified_coordinates
     if rest:
         result.append(list(rest))
-- 
GitLab