diff --git a/pystencils/field.py b/pystencils/field.py
index 911ae827de1f7e386458f90e69599561712cf853..69adfcb8acda5a4af24fd3de97aaa6940c9ea6bd 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -329,10 +329,10 @@ class Field(AbstractField):
         self._layout = normalize_layout(layout)
         self.shape = shape
         self.strides = strides
-        self.latex_name = None  # type: Optional[str]
-        self.coordinate_origin = sp.Matrix(tuple(
+        self.latex_name: Optional[str] = None
+        self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple(
             0 for _ in range(self.spatial_dimensions)
-        ))  # type: tuple[float,sp.Symbol]
+        ))  # type
         self.coordinate_transform = sp.eye(self.spatial_dimensions)
         if field_type == FieldType.STAGGERED:
             assert self.staggered_stencil
@@ -433,7 +433,7 @@ class Field(AbstractField):
             return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
         elif len(index_shape) == 3:
             return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])]
-                             for j in range(index_shape[1])] for i in range(index_shape[0])])
+                               for j in range(index_shape[1])] for i in range(index_shape[0])])
         else:
             raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
 
@@ -454,7 +454,7 @@ class Field(AbstractField):
             return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
         elif self.index_dimensions == 2:
             return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
-                             for i in range(self.index_shape[0])])
+                              for i in range(self.index_shape[0])])
         else:
             raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
 
@@ -529,7 +529,7 @@ class Field(AbstractField):
                 prefactor = -1
         if neighbor not in self.staggered_stencil:
             raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
-                             self.staggered_stencil_name))
+                                                                                    self.staggered_stencil_name))
 
         offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
 
@@ -563,7 +563,7 @@ class Field(AbstractField):
             return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
         elif self.index_dimensions == 3:
             return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
-                             for i in range(self.index_shape[1])])
+                              for i in range(self.index_shape[1])])
         else:
             raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
 
@@ -627,10 +627,23 @@ class Field(AbstractField):
     def index_to_physical(self, index_coordinates, staggered=False):
         if staggered:
             index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates])
-        return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
+        if hasattr(self.coordinate_transform, '__call__'):
+            return self.coordinate_transform(self.coordinate_origin + index_coordinates)
+        else:
+            return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
 
     def physical_to_index(self, physical_coordinates, staggered=False):
-        rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
+        if hasattr(self.coordinate_transform, '__call__'):
+            if hasattr(self.coordinate_transform, 'inv'):
+                return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
+            else:
+                idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
+                rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
+                assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
+                return rtn
+
+        else:
+            rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
         if staggered:
             rtn = sp.Matrix([i - 0.5 for i in rtn])