From 99aef3f880eb14870d1ec9ba7d54b495b9e52850 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 7 Nov 2018 16:46:56 +0100
Subject: [PATCH] Refactored buffer treatment

- put all buffer related stuff into separate functions
- should be functionally equivalent
---
 cpu/kernelcreation.py     | 29 +++++++----------
 gpucuda/kernelcreation.py | 12 +++----
 transformations.py        | 66 ++++++++++++++++++++++-----------------
 3 files changed, 53 insertions(+), 54 deletions(-)

diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index c34a83ec4..87247daba 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -3,7 +3,7 @@ from functools import partial
 from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
 from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
     add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
-    split_inner_loop, substitute_array_accesses_with_constants
+    split_inner_loop, substitute_array_accesses_with_constants, get_base_buffer_index
 from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
 from pystencils.field import Field, FieldType
 import pystencils.astnodes as ast
@@ -61,13 +61,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
 
     body = ast.Block(assignments)
     loop_order = get_optimal_loop_ordering(fields_without_buffers)
-    code, loop_strides, loop_vars = make_loop_over_domain(body, function_name, iteration_slice=iteration_slice,
+    ast_node = make_loop_over_domain(body, function_name, iteration_slice=iteration_slice,
                                                           ghost_layers=ghost_layers, loop_order=loop_order)
-    code.target = 'cpu'
+    ast_node.target = 'cpu'
 
     if split_groups:
         typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
-        split_inner_loop(code, typed_split_groups)
+        split_inner_loop(ast_node, typed_split_groups)
 
     base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
     base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
@@ -79,20 +79,13 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
                                 for field in buffers}
     base_pointer_info.update(buffer_base_pointer_info)
 
-    base_buffer_index = loop_vars[0]
-    stride = 1
-    for idx, var in enumerate(loop_vars[1:]):
-        cur_stride = loop_strides[idx]
-        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
-        base_buffer_index += var * stride
-
-    resolve_buffer_accesses(code, base_buffer_index, read_only_fields)
-
-    resolve_field_accesses(code, read_only_fields, field_to_base_pointer_info=base_pointer_info)
-    substitute_array_accesses_with_constants(code)
-    move_constants_before_loop(code)
-    code.compile = partial(make_python_function, code)
-    return code
+    if any(FieldType.is_buffer(f) for f in all_fields):
+        resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
+    resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
+    substitute_array_accesses_with_constants(ast_node)
+    move_constants_before_loop(ast_node)
+    ast_node.compile = partial(make_python_function, ast_node)
+    return ast_node
 
 
 def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 1fb637cac..da3b39df5 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -2,7 +2,8 @@ from functools import partial
 
 from pystencils.gpucuda.indexing import BlockIndexing
 from pystencils.transformations import resolve_field_accesses, add_types, parse_base_pointer_info, \
-    get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols
+    get_common_shape, substitute_array_accesses_with_constants, resolve_buffer_accesses, unify_shape_symbols, \
+    get_base_buffer_index
 from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
 from pystencils.data_types import TypedSymbol, BasicType, StructType
 from pystencils import Field, FieldType
@@ -63,16 +64,11 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
 
     coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
 
-    loop_vars = [num_buffer_accesses * i for i in indexing.coordinates]
     loop_strides = list(fields_without_buffers)[0].shape
 
-    base_buffer_index = loop_vars[0]
-    stride = 1
-    for idx, var in enumerate(loop_vars[1:]):
-        stride *= loop_strides[idx]
-        base_buffer_index += var * stride
+    if any(FieldType.is_buffer(f) for f in all_fields):
+        resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, loop_strides), read_only_fields)
 
-    resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)
     resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                            field_to_fixed_coordinates=coord_mapping)
 
diff --git a/transformations.py b/transformations.py
index b80cc37f6..7ef7edd2d 100644
--- a/transformations.py
+++ b/transformations.py
@@ -73,20 +73,6 @@ def get_common_shape(field_set):
     return shape
 
 
-def get_field_accesses(expr, result=set()):
-    if isinstance(expr, Field.Access):
-        result.add(expr)
-        for o in expr.offsets:
-            get_field_accesses(o, result)
-        for i in expr.index:
-            get_field_accesses(i, result)
-    elif hasattr(expr, 'atoms'):
-        new_accesses = expr.atoms(Field.Access)
-        result.update(new_accesses)
-        for a in new_accesses:
-            get_field_accesses(a, result)
-
-
 def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layers=None, loop_order=None):
     """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
 
@@ -103,14 +89,12 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
         :class:`LoopOverCoordinate` instance with nested loops, ordered according to field layouts
     """
     # find correct ordering by inspecting participating FieldAccesses
-    field_accesses = set()
-    get_field_accesses(body, field_accesses)
+    field_accesses = body.atoms(Field.Access)
     field_accesses = {e for e in field_accesses if not e.is_absolute_access}
 
     # exclude accesses to buffers from field_list, because buffers are treated separately
     field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)]
     fields = set(field_list)
-    num_buffer_accesses = len(field_accesses) - len(field_list)
 
     if loop_order is None:
         loop_order = get_optimal_loop_ordering(fields)
@@ -127,11 +111,6 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
     if isinstance(ghost_layers, int):
         ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
 
-    def get_loop_stride(loop_begin, loop_end, step):
-        return (loop_end - loop_begin) / step
-
-    loop_strides = []
-    loop_vars = []
     current_body = body
     for i, loop_coordinate in enumerate(reversed(loop_order)):
         if iteration_slice is None:
@@ -139,24 +118,19 @@ def make_loop_over_domain(body, function_name, iteration_slice=None, ghost_layer
             end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
             new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
             current_body = ast.Block([new_loop])
-            loop_strides.append(get_loop_stride(begin, end, 1))
-            loop_vars.append(new_loop.loop_counter_symbol)
         else:
             slice_component = iteration_slice[loop_coordinate]
             if type(slice_component) is slice:
                 sc = slice_component
                 new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                 current_body = ast.Block([new_loop])
-                loop_strides.append(get_loop_stride(sc.start, sc.stop, sc.step))
-                loop_vars.append(new_loop.loop_counter_symbol)
             else:
                 assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                                                  sp.sympify(slice_component))
                 current_body.insert_front(assignment)
 
-    loop_vars = [num_buffer_accesses * var for var in loop_vars]
     ast_node = ast.KernelFunction(current_body, ghost_layers=ghost_layers, function_name=function_name, backend='cpu')
-    return ast_node, loop_strides, loop_vars
+    return ast_node
 
 
 def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
@@ -341,7 +315,43 @@ def substitute_array_accesses_with_constants(ast_node):
             substitute_array_accesses_with_constants(a)
 
 
+def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
+    """Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.
+
+    Args:
+        ast_node: ast before any field accesses are resolved
+        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
+                       for GPU kernels: list of 'loop counters' from inner to outer loop
+        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default
+
+    Returns:
+        base buffer index - required by 'resolve_buffer_accesses' function
+    """
+    if loop_counters is None or loop_iterations is None:
+        loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)]
+        loops.reverse()
+        parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
+        assert len(loops) == len(parents_of_innermost_loop)
+        assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
+
+        loop_iterations = [(l.stop - l.start) / l.step for l in loops]
+        loop_counters = [l.loop_counter_symbol for l in loops]
+
+    field_accesses = ast_node.atoms(Field.Access)
+    buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
+    loop_counters = [v * len(buffer_accesses) for v in loop_counters]
+
+    base_buffer_index = loop_counters[0]
+    stride = 1
+    for idx, var in enumerate(loop_counters[1:]):
+        cur_stride = loop_iterations[idx]
+        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
+        base_buffer_index += var * stride
+    return base_buffer_index
+
+
 def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
+
     def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
         if isinstance(expr, Field.Access):
             field_access = expr
-- 
GitLab