diff --git a/pystencils/field.py b/pystencils/field.py index 911ae827de1f7e386458f90e69599561712cf853..69adfcb8acda5a4af24fd3de97aaa6940c9ea6bd 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -329,10 +329,10 @@ class Field(AbstractField): self._layout = normalize_layout(layout) self.shape = shape self.strides = strides - self.latex_name = None # type: Optional[str] - self.coordinate_origin = sp.Matrix(tuple( + self.latex_name: Optional[str] = None + self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple( 0 for _ in range(self.spatial_dimensions) - )) # type: tuple[float,sp.Symbol] + )) # type self.coordinate_transform = sp.eye(self.spatial_dimensions) if field_type == FieldType.STAGGERED: assert self.staggered_stencil @@ -433,7 +433,7 @@ class Field(AbstractField): return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])]) elif len(index_shape) == 3: return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])] - for j in range(index_shape[1])] for i in range(index_shape[0])]) + for j in range(index_shape[1])] for i in range(index_shape[0])]) else: raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions") @@ -454,7 +454,7 @@ class Field(AbstractField): return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])]) elif self.index_dimensions == 2: return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] - for i in range(self.index_shape[0])]) + for i in range(self.index_shape[0])]) else: raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions") @@ -529,7 +529,7 @@ class Field(AbstractField): prefactor = -1 if neighbor not in self.staggered_stencil: raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, - self.staggered_stencil_name)) + self.staggered_stencil_name)) offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) @@ -563,7 +563,7 @@ class Field(AbstractField): return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])]) elif self.index_dimensions == 3: return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])] - for i in range(self.index_shape[1])]) + for i in range(self.index_shape[1])]) else: raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions") @@ -627,10 +627,23 @@ class Field(AbstractField): def index_to_physical(self, index_coordinates, staggered=False): if staggered: index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates]) - return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) + if hasattr(self.coordinate_transform, '__call__'): + return self.coordinate_transform(self.coordinate_origin + index_coordinates) + else: + return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) def physical_to_index(self, physical_coordinates, staggered=False): - rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin + if hasattr(self.coordinate_transform, '__call__'): + if hasattr(self.coordinate_transform, 'inv'): + return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin + else: + idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True)) + rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx) + assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}' + return rtn + + else: + rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin if staggered: rtn = sp.Matrix([i - 0.5 for i in rtn])