diff --git a/transformations.py b/transformations.py index 0c616a436d10dbb5ded66dee2f997e998eb6306d..2681585408892a2994e43e98bea4fc9b5ccab28a 100644 --- a/transformations.py +++ b/transformations.py @@ -11,7 +11,7 @@ from pystencils.assignment_collection.nestedscopes import NestedScopes from pystencils.field import Field, FieldType from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \ pointer_arithmetic_func, get_type_of_expression, collate_types, create_type -from pystencils.kernelparameters import FieldPointerSymbol, FieldStrideSymbol +from pystencils.kernelparameters import FieldPointerSymbol from pystencils.slicing import normalize_slice import pystencils.astnodes as ast @@ -1001,10 +1001,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: inner_loop_counter = inner_loop_counters.pop() - stride_params = [p for p in ast_node.get_parameters() if isinstance(p.symbol, FieldStrideSymbol)] - subs_dict = {} - for stride_param in stride_params: - stride_symbol = stride_param.symbol - subs_dict.update({IndexedBase(stride_symbol, shape=(1,))[inner_loop_counter]: 1}) + parameters = ast_node.get_parameters() + stride_params = [p.symbol for p in parameters + if p.is_field_stride and p.symbol.coordinate == inner_loop_counter] + subs_dict = {stride_param: 1 for stride_param in stride_params} if subs_dict: ast_node.subs(subs_dict)