diff --git a/pystencils/field.py b/pystencils/field.py index 507be0c2cad932eed64419b70e1b55783cff4273..69adfcb8acda5a4af24fd3de97aaa6940c9ea6bd 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -15,7 +15,8 @@ import pystencils from pystencils.alignedarray import aligned_empty from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol -from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction +from pystencils.stencil import ( + direction_string_to_offset, inverse_direction, offset_to_direction_string) from pystencils.sympyextensions import is_integer_sequence __all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] @@ -328,10 +329,10 @@ class Field(AbstractField): self._layout = normalize_layout(layout) self.shape = shape self.strides = strides - self.latex_name = None # type: Optional[str] - self.coordinate_origin = sp.Matrix(tuple( + self.latex_name: Optional[str] = None + self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple( 0 for _ in range(self.spatial_dimensions) - )) # type: tuple[float,sp.Symbol] + )) # type self.coordinate_transform = sp.eye(self.spatial_dimensions) if field_type == FieldType.STAGGERED: assert self.staggered_stencil @@ -432,7 +433,7 @@ class Field(AbstractField): return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])]) elif len(index_shape) == 3: return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])] - for j in range(index_shape[1])] for i in range(index_shape[0])]) + for j in range(index_shape[1])] for i in range(index_shape[0])]) else: raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions") @@ -453,7 +454,7 @@ class Field(AbstractField): return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])]) elif self.index_dimensions == 2: return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])] - for i in range(self.index_shape[0])]) + for i in range(self.index_shape[0])]) else: raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions") @@ -528,7 +529,7 @@ class Field(AbstractField): prefactor = -1 if neighbor not in self.staggered_stencil: raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, - self.staggered_stencil_name)) + self.staggered_stencil_name)) offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) @@ -562,7 +563,7 @@ class Field(AbstractField): return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])]) elif self.index_dimensions == 3: return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])] - for i in range(self.index_shape[1])]) + for i in range(self.index_shape[1])]) else: raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions") @@ -613,7 +614,10 @@ class Field(AbstractField): @property def physical_coordinates(self): - return self.coordinate_transform @ (self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions)) + if hasattr(self.coordinate_transform, '__call__'): + return self.coordinate_transform(self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions)) + else: + return self.coordinate_transform @ (self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions)) @property def physical_coordinates_staggered(self): @@ -623,10 +627,23 @@ class Field(AbstractField): def index_to_physical(self, index_coordinates, staggered=False): if staggered: index_coordinates = sp.Matrix([i + 0.5 for i in index_coordinates]) - return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) + if hasattr(self.coordinate_transform, '__call__'): + return self.coordinate_transform(self.coordinate_origin + index_coordinates) + else: + return self.coordinate_transform @ (self.coordinate_origin + index_coordinates) def physical_to_index(self, physical_coordinates, staggered=False): - rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin + if hasattr(self.coordinate_transform, '__call__'): + if hasattr(self.coordinate_transform, 'inv'): + return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin + else: + idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True)) + rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx) + assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}' + return rtn + + else: + rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin if staggered: rtn = sp.Matrix([i - 0.5 for i in rtn])