Commit 77087e9e authored by Martin Bauer's avatar Martin Bauer
Browse files

Added convenience function to create pystencils fields

parent 362b4611
......@@ -10,6 +10,73 @@ from pystencils.data_types import TypedSymbol, create_type, create_composite_typ
from pystencils.sympyextensions import is_integer_sequence
def fields(description=None, index_dimensions=0, layout=None, **kwargs):
"""Creates pystencils fields from a string description.
Examples:
Create a 2D scalar and vector field
>>> s, v = fields("s, v(2): double[2D]")
>>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
>>> assert v.spatial_dimensions == 2 and v.index_dimensions == 1 and v.index_shape == (2,)
Create an integer field of shape (10, 20)
>>> f = fields("f : int32[10, 20]")
>>> f.has_fixed_shape, f.shape
(True, (10, 20))
Numpy arrays can be used as template for shape and data type of field
>>> arr_s, arr_v = np.zeros([20, 20]), np.zeros([20, 20, 2])
>>> s, v = fields("s, v(2)", s=arr_s, v=arr_v)
>>> assert s.index_dimensions == 0 and v.index_shape == (2,) and s.dtype.numpy_dtype == arr_s.dtype
Format string can be left out, field names are taken from keyword arguments.
>>> fields(f1=arr_s, f2=arr_s)
[f1, f2]
The keyword names 'index_dimension' and 'layout' have special meaning and thus can not be used to pass
numpy arrays:
>>> f = fields(f=arr_v, index_dimensions=1)
>>> assert f.index_dimensions == 1
>>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
>>> f.layout
(2, 1, 0)
"""
result = []
if description:
field_descriptions, dtype, shape = _parse_description(description)
layout = 'numpy' if layout is None else layout
for field_name, idx_shape in field_descriptions:
if field_name in kwargs:
arr = kwargs[field_name]
idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
assert idx_shape_of_arr == idx_shape
f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape))
elif isinstance(shape, tuple):
f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
index_dimensions=len(idx_shape), layout=layout)
elif isinstance(shape, int):
f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
index_shape=idx_shape, layout=layout)
elif shape is None:
f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
index_shape=idx_shape, layout=layout)
else:
assert False
result.append(f)
else:
assert layout is None, "Layout can not be specified when creating Field from numpy array"
for field_name, arr in kwargs.items():
result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions))
if len(result) == 0:
return None
elif len(result) == 1:
return result[0]
else:
return result
class FieldType(Enum):
# generic fields
GENERIC = 0
......@@ -418,6 +485,9 @@ class Field:
def get_shifted(self, *shift)-> 'Field.Access':
return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index)
def at_index(self, *idx_tuple):
return Field.Access(self.field, self.offsets, idx_tuple)
def _hashable_content(self):
super_class_contents = list(super(Field.Access, self)._hashable_content())
t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets)
......@@ -654,3 +724,55 @@ def direction_string_to_offset(direction: str, dim: int = 3):
offset += factor * cur_offset
direction = direction[1:]
return offset[:dim]
def _parse_type_description(type_description):
if not type_description:
return np.float64, None
elif '[' in type_description:
assert type_description[-1] == ']'
splitted = type_description[:-1].split("[", )
type_part, size_part = type_description[:-1].split("[", )
if not type_part:
type_part = "float64"
if size_part.lower()[-1] == 'd':
size_part = int(size_part[:-1])
else:
size_part = tuple(int(i) for i in size_part.split(','))
else:
type_part, size_part = type_description, None
dtype = np.dtype(type_part).type
return dtype, size_part
def _parse_field_description(description):
if '(' not in description:
return description, ()
assert description[-1] == ')'
name, index_shape = description[:-1].split('(')
index_shape = tuple(int(i) for i in index_shape.split(','))
return name, index_shape
def _parse_description(description):
description = description.replace(' ', '')
if ':' in description:
name_descr, type_descr = description.split(':')
else:
name_descr, type_descr = description, ''
# correct ',' splits inside brackets
field_names = name_descr.split(',')
cleaned_field_names = []
prefix = ''
for field_name in field_names:
full_field_name = prefix + field_name
if '(' in full_field_name and ')' not in full_field_name:
prefix += field_name + ','
else:
prefix = ''
cleaned_field_names.append(full_field_name)
dtype, size = _parse_type_description(type_descr)
fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names)
return fields_info, dtype, size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment