From f47d9378ea3640fa03e079b3c30c05a3111bb051 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Mon, 10 Dec 2018 13:49:55 +0100 Subject: [PATCH] Bugfix: `replace_inner_stride_with_one` did not work - bug was caused by change in stride parameter passing --- transformations.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/transformations.py b/transformations.py index 0c616a436..268158540 100644 --- a/transformations.py +++ b/transformations.py @@ -11,7 +11,7 @@ from pystencils.assignment_collection.nestedscopes import NestedScopes from pystencils.field import Field, FieldType from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \ pointer_arithmetic_func, get_type_of_expression, collate_types, create_type -from pystencils.kernelparameters import FieldPointerSymbol, FieldStrideSymbol +from pystencils.kernelparameters import FieldPointerSymbol from pystencils.slicing import normalize_slice import pystencils.astnodes as ast @@ -1001,10 +1001,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: inner_loop_counter = inner_loop_counters.pop() - stride_params = [p for p in ast_node.get_parameters() if isinstance(p.symbol, FieldStrideSymbol)] - subs_dict = {} - for stride_param in stride_params: - stride_symbol = stride_param.symbol - subs_dict.update({IndexedBase(stride_symbol, shape=(1,))[inner_loop_counter]: 1}) + parameters = ast_node.get_parameters() + stride_params = [p.symbol for p in parameters + if p.is_field_stride and p.symbol.coordinate == inner_loop_counter] + subs_dict = {stride_param: 1 for stride_param in stride_params} if subs_dict: ast_node.subs(subs_dict) -- GitLab