diff --git a/field.py b/field.py index b9b6648cf7a4de325314ab397828a91ae7d6378f..3d646077643067e317422489a2b1e98f9e024030 100644 --- a/field.py +++ b/field.py @@ -10,6 +10,73 @@ from pystencils.data_types import TypedSymbol, create_type, create_composite_typ from pystencils.sympyextensions import is_integer_sequence +def fields(description=None, index_dimensions=0, layout=None, **kwargs): + """Creates pystencils fields from a string description. + + Examples: + Create a 2D scalar and vector field + >>> s, v = fields("s, v(2): double[2D]") + >>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0 + >>> assert v.spatial_dimensions == 2 and v.index_dimensions == 1 and v.index_shape == (2,) + + Create an integer field of shape (10, 20) + >>> f = fields("f : int32[10, 20]") + >>> f.has_fixed_shape, f.shape + (True, (10, 20)) + + Numpy arrays can be used as template for shape and data type of field + >>> arr_s, arr_v = np.zeros([20, 20]), np.zeros([20, 20, 2]) + >>> s, v = fields("s, v(2)", s=arr_s, v=arr_v) + >>> assert s.index_dimensions == 0 and v.index_shape == (2,) and s.dtype.numpy_dtype == arr_s.dtype + + Format string can be left out, field names are taken from keyword arguments. + >>> fields(f1=arr_s, f2=arr_s) + [f1, f2] + + The keyword names 'index_dimension' and 'layout' have special meaning and thus can not be used to pass + numpy arrays: + >>> f = fields(f=arr_v, index_dimensions=1) + >>> assert f.index_dimensions == 1 + + >>> f = fields("pdfs(19) : float32[3D]", layout='fzyx') + >>> f.layout + (2, 1, 0) + """ + result = [] + if description: + field_descriptions, dtype, shape = _parse_description(description) + layout = 'numpy' if layout is None else layout + for field_name, idx_shape in field_descriptions: + if field_name in kwargs: + arr = kwargs[field_name] + idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):] + assert idx_shape_of_arr == idx_shape + f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape)) + elif isinstance(shape, tuple): + f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype, + index_dimensions=len(idx_shape), layout=layout) + elif isinstance(shape, int): + f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype, + index_shape=idx_shape, layout=layout) + elif shape is None: + f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype, + index_shape=idx_shape, layout=layout) + else: + assert False + result.append(f) + else: + assert layout is None, "Layout can not be specified when creating Field from numpy array" + for field_name, arr in kwargs.items(): + result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions)) + + if len(result) == 0: + return None + elif len(result) == 1: + return result[0] + else: + return result + + class FieldType(Enum): # generic fields GENERIC = 0 @@ -418,6 +485,9 @@ class Field: def get_shifted(self, *shift)-> 'Field.Access': return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index) + def at_index(self, *idx_tuple): + return Field.Access(self.field, self.offsets, idx_tuple) + def _hashable_content(self): super_class_contents = list(super(Field.Access, self)._hashable_content()) t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets) @@ -654,3 +724,55 @@ def direction_string_to_offset(direction: str, dim: int = 3): offset += factor * cur_offset direction = direction[1:] return offset[:dim] + +def _parse_type_description(type_description): + if not type_description: + return np.float64, None + elif '[' in type_description: + assert type_description[-1] == ']' + splitted = type_description[:-1].split("[", ) + type_part, size_part = type_description[:-1].split("[", ) + if not type_part: + type_part = "float64" + if size_part.lower()[-1] == 'd': + size_part = int(size_part[:-1]) + else: + size_part = tuple(int(i) for i in size_part.split(',')) + else: + type_part, size_part = type_description, None + + dtype = np.dtype(type_part).type + return dtype, size_part + + +def _parse_field_description(description): + if '(' not in description: + return description, () + assert description[-1] == ')' + name, index_shape = description[:-1].split('(') + index_shape = tuple(int(i) for i in index_shape.split(',')) + return name, index_shape + + +def _parse_description(description): + description = description.replace(' ', '') + if ':' in description: + name_descr, type_descr = description.split(':') + else: + name_descr, type_descr = description, '' + + # correct ',' splits inside brackets + field_names = name_descr.split(',') + cleaned_field_names = [] + prefix = '' + for field_name in field_names: + full_field_name = prefix + field_name + if '(' in full_field_name and ')' not in full_field_name: + prefix += field_name + ',' + else: + prefix = '' + cleaned_field_names.append(full_field_name) + + dtype, size = _parse_type_description(type_descr) + fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names) + return fields_info, dtype, size