diff --git a/transformations.py b/transformations.py
index 0c616a436d10dbb5ded66dee2f997e998eb6306d..2681585408892a2994e43e98bea4fc9b5ccab28a 100644
--- a/transformations.py
+++ b/transformations.py
@@ -11,7 +11,7 @@ from pystencils.assignment_collection.nestedscopes import NestedScopes
 from pystencils.field import Field, FieldType
 from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
     pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
-from pystencils.kernelparameters import FieldPointerSymbol, FieldStrideSymbol
+from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils.slicing import normalize_slice
 import pystencils.astnodes as ast
 
@@ -1001,10 +1001,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
 
     inner_loop_counter = inner_loop_counters.pop()
 
-    stride_params = [p for p in ast_node.get_parameters() if isinstance(p.symbol, FieldStrideSymbol)]
-    subs_dict = {}
-    for stride_param in stride_params:
-        stride_symbol = stride_param.symbol
-        subs_dict.update({IndexedBase(stride_symbol, shape=(1,))[inner_loop_counter]: 1})
+    parameters = ast_node.get_parameters()
+    stride_params = [p.symbol for p in parameters
+                     if p.is_field_stride and p.symbol.coordinate == inner_loop_counter]
+    subs_dict = {stride_param: 1 for stride_param in stride_params}
     if subs_dict:
         ast_node.subs(subs_dict)