From 74567900ea73f39e06d10ceed0dc24308e611e40 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Wed, 29 Mar 2023 09:39:04 +0200
Subject: [PATCH] Replace get common shape

---
 pystencils/gpucuda/kernelcreation.py |  8 ++++----
 pystencils/transformations.py        | 24 ++++++++++++++----------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index 7bc7fb9dd..96a531383 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -13,7 +13,7 @@ from pystencils.node_collection import NodeCollection
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.transformations import (
-    get_base_buffer_index, get_common_shape, parse_base_pointer_info,
+    get_base_buffer_index, get_common_field, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
 
 
@@ -45,8 +45,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
         num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field))
 
     # common shape and field to from the iteration space
-    common_shape = get_common_shape(fields_without_buffers)
-    common_field = list(sorted(fields_without_buffers, key=lambda e: str(e)))[0]
+    common_field = get_common_field(fields_without_buffers)
+    common_shape = common_field.spatial_shape
 
     if iteration_slice is None:
         # determine iteration slice from ghost layers
@@ -160,7 +160,7 @@ def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCol
                                 iteration_slice=[slice(None, None, None)] * len(idx_field.spatial_shape))
 
     function_body = Block(coordinate_symbol_assignments + assignments)
-    function_body = indexing.guard(function_body, get_common_shape(index_fields))
+    function_body = indexing.guard(function_body, get_common_field(index_fields).spatial_shape)
     ast = KernelFunction(function_body, Target.GPU, Backend.CUDA, make_python_function,
                          None, function_name, assignments=assignments)
     ast.global_variables.update(indexing.index_variables)
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index eb5c4f77e..02806d622 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -122,9 +122,10 @@ def unify_shape_symbols(body, common_shape, fields):
         body.subs(substitutions)
 
 
-def get_common_shape(field_set):
-    """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
-    ValueError is raised"""
+def get_common_field(field_set):
+    """Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
+        representative field, that can be used for shape information etc. in the kernel creation.
+        If the fields have different shapes ValueError is raised"""
     nr_of_fixed_shaped_fields = 0
     for f in field_set:
         if f.has_fixed_shape:
@@ -142,8 +143,9 @@ def get_common_shape(field_set):
         if len(shape_set) != 1:
             raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
 
-    shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0]
-    return shape
+    # Sort the fields by their name to ensure that always the same field is returned
+    reference_field = list(sorted(field_set, key=lambda e: str(e)))[0]
+    return reference_field
 
 
 def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
@@ -178,13 +180,15 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
 
     if absolut_accesses_only:
         absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
-        shape = get_common_shape(absolut_access_fields)
+        common_field = get_common_field(absolut_access_fields)
+        common_shape = common_field.spatial_shape
     else:
-        shape = get_common_shape(fields)
-    unify_shape_symbols(body, common_shape=shape, fields=fields)
+        common_field = get_common_field(fields)
+        common_shape = common_field.spatial_shape
+    unify_shape_symbols(body, common_shape=common_shape, fields=fields)
 
     if iteration_slice is not None:
-        iteration_slice = normalize_slice(iteration_slice, shape)
+        iteration_slice = normalize_slice(iteration_slice, common_shape)
 
     if ghost_layers is None:
         if absolut_accesses_only:
@@ -199,7 +203,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
     for i, loop_coordinate in enumerate(reversed(loop_order)):
         if iteration_slice is None:
             begin = ghost_layers[loop_coordinate][0]
-            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
+            end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
             new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
             current_body = ast.Block([new_loop])
         else:
-- 
GitLab