Skip to content
Snippets Groups Projects
Commit 77087e9e authored by Martin Bauer's avatar Martin Bauer
Browse files

Added convenience function to create pystencils fields

parent 362b4611
Branches
Tags
No related merge requests found
......@@ -10,6 +10,73 @@ from pystencils.data_types import TypedSymbol, create_type, create_composite_typ
from pystencils.sympyextensions import is_integer_sequence
def fields(description=None, index_dimensions=0, layout=None, **kwargs):
"""Creates pystencils fields from a string description.
Examples:
Create a 2D scalar and vector field
>>> s, v = fields("s, v(2): double[2D]")
>>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
>>> assert v.spatial_dimensions == 2 and v.index_dimensions == 1 and v.index_shape == (2,)
Create an integer field of shape (10, 20)
>>> f = fields("f : int32[10, 20]")
>>> f.has_fixed_shape, f.shape
(True, (10, 20))
Numpy arrays can be used as template for shape and data type of field
>>> arr_s, arr_v = np.zeros([20, 20]), np.zeros([20, 20, 2])
>>> s, v = fields("s, v(2)", s=arr_s, v=arr_v)
>>> assert s.index_dimensions == 0 and v.index_shape == (2,) and s.dtype.numpy_dtype == arr_s.dtype
Format string can be left out, field names are taken from keyword arguments.
>>> fields(f1=arr_s, f2=arr_s)
[f1, f2]
The keyword names 'index_dimension' and 'layout' have special meaning and thus can not be used to pass
numpy arrays:
>>> f = fields(f=arr_v, index_dimensions=1)
>>> assert f.index_dimensions == 1
>>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
>>> f.layout
(2, 1, 0)
"""
result = []
if description:
field_descriptions, dtype, shape = _parse_description(description)
layout = 'numpy' if layout is None else layout
for field_name, idx_shape in field_descriptions:
if field_name in kwargs:
arr = kwargs[field_name]
idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
assert idx_shape_of_arr == idx_shape
f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape))
elif isinstance(shape, tuple):
f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
index_dimensions=len(idx_shape), layout=layout)
elif isinstance(shape, int):
f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
index_shape=idx_shape, layout=layout)
elif shape is None:
f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
index_shape=idx_shape, layout=layout)
else:
assert False
result.append(f)
else:
assert layout is None, "Layout can not be specified when creating Field from numpy array"
for field_name, arr in kwargs.items():
result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions))
if len(result) == 0:
return None
elif len(result) == 1:
return result[0]
else:
return result
class FieldType(Enum):
# generic fields
GENERIC = 0
......@@ -418,6 +485,9 @@ class Field:
def get_shifted(self, *shift)-> 'Field.Access':
return Field.Access(self.field, tuple(a + b for a, b in zip(shift, self.offsets)), self.index)
def at_index(self, *idx_tuple):
return Field.Access(self.field, self.offsets, idx_tuple)
def _hashable_content(self):
super_class_contents = list(super(Field.Access, self)._hashable_content())
t = tuple(super_class_contents + [hash(self._field), self._index] + self._offsets)
......@@ -654,3 +724,55 @@ def direction_string_to_offset(direction: str, dim: int = 3):
offset += factor * cur_offset
direction = direction[1:]
return offset[:dim]
def _parse_type_description(type_description):
if not type_description:
return np.float64, None
elif '[' in type_description:
assert type_description[-1] == ']'
splitted = type_description[:-1].split("[", )
type_part, size_part = type_description[:-1].split("[", )
if not type_part:
type_part = "float64"
if size_part.lower()[-1] == 'd':
size_part = int(size_part[:-1])
else:
size_part = tuple(int(i) for i in size_part.split(','))
else:
type_part, size_part = type_description, None
dtype = np.dtype(type_part).type
return dtype, size_part
def _parse_field_description(description):
if '(' not in description:
return description, ()
assert description[-1] == ')'
name, index_shape = description[:-1].split('(')
index_shape = tuple(int(i) for i in index_shape.split(','))
return name, index_shape
def _parse_description(description):
description = description.replace(' ', '')
if ':' in description:
name_descr, type_descr = description.split(':')
else:
name_descr, type_descr = description, ''
# correct ',' splits inside brackets
field_names = name_descr.split(',')
cleaned_field_names = []
prefix = ''
for field_name in field_names:
full_field_name = prefix + field_name
if '(' in full_field_name and ')' not in full_field_name:
prefix += field_name + ','
else:
prefix = ''
cleaned_field_names.append(full_field_name)
dtype, size = _parse_type_description(type_descr)
fields_info = tuple(_parse_field_description(fd) for fd in cleaned_field_names)
return fields_info, dtype, size
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment