From 74567900ea73f39e06d10ceed0dc24308e611e40 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Wed, 29 Mar 2023 09:39:04 +0200 Subject: [PATCH] Replace get common shape --- pystencils/gpucuda/kernelcreation.py | 8 ++++---- pystencils/transformations.py | 24 ++++++++++++++---------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index 7bc7fb9dd..96a531383 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -13,7 +13,7 @@ from pystencils.node_collection import NodeCollection from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.transformations import ( - get_base_buffer_index, get_common_shape, parse_base_pointer_info, + get_base_buffer_index, get_common_field, parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols) @@ -45,8 +45,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field)) # common shape and field to from the iteration space - common_shape = get_common_shape(fields_without_buffers) - common_field = list(sorted(fields_without_buffers, key=lambda e: str(e)))[0] + common_field = get_common_field(fields_without_buffers) + common_shape = common_field.spatial_shape if iteration_slice is None: # determine iteration slice from ghost layers @@ -160,7 +160,7 @@ def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCol iteration_slice=[slice(None, None, None)] * len(idx_field.spatial_shape)) function_body = Block(coordinate_symbol_assignments + assignments) - function_body = indexing.guard(function_body, get_common_shape(index_fields)) + function_body = indexing.guard(function_body, get_common_field(index_fields).spatial_shape) ast = KernelFunction(function_body, Target.GPU, Backend.CUDA, make_python_function, None, function_name, assignments=assignments) ast.global_variables.update(indexing.index_variables) diff --git a/pystencils/transformations.py b/pystencils/transformations.py index eb5c4f77e..02806d622 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -122,9 +122,10 @@ def unify_shape_symbols(body, common_shape, fields): body.subs(substitutions) -def get_common_shape(field_set): - """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise - ValueError is raised""" +def get_common_field(field_set): + """Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one + representative field, that can be used for shape information etc. in the kernel creation. + If the fields have different shapes ValueError is raised""" nr_of_fixed_shaped_fields = 0 for f in field_set: if f.has_fixed_shape: @@ -142,8 +143,9 @@ def get_common_shape(field_set): if len(shape_set) != 1: raise ValueError("Differently sized field accesses in loop body: " + str(shape_set)) - shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0] - return shape + # Sort the fields by their name to ensure that always the same field is returned + reference_field = list(sorted(field_set, key=lambda e: str(e)))[0] + return reference_field def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None): @@ -178,13 +180,15 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or if absolut_accesses_only: absolut_access_fields = {e.field for e in body.atoms(Field.Access)} - shape = get_common_shape(absolut_access_fields) + common_field = get_common_field(absolut_access_fields) + common_shape = common_field.spatial_shape else: - shape = get_common_shape(fields) - unify_shape_symbols(body, common_shape=shape, fields=fields) + common_field = get_common_field(fields) + common_shape = common_field.spatial_shape + unify_shape_symbols(body, common_shape=common_shape, fields=fields) if iteration_slice is not None: - iteration_slice = normalize_slice(iteration_slice, shape) + iteration_slice = normalize_slice(iteration_slice, common_shape) if ghost_layers is None: if absolut_accesses_only: @@ -199,7 +203,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or for i, loop_coordinate in enumerate(reversed(loop_order)): if iteration_slice is None: begin = ghost_layers[loop_coordinate][0] - end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] + end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1] new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) current_body = ast.Block([new_loop]) else: -- GitLab