Commit 115d558d authored by Martin Bauer's avatar Martin Bauer
Browse files

SerialDataHandling now also supports arbitrary tensor fields

parent f16eadea
......@@ -72,7 +72,7 @@ class BoundaryHandling:
def __init__(self, data_handling, field_name, stencil, name="boundary_handling", flag_interface=None,
target='cpu', openmp=True):
assert data_handling.has_data(field_name)
assert data_handling.dim == len(stencil[0]), "Dimension of stencil and data handling do not match"
self._data_handling = data_handling
self._field_name = field_name
self._index_array_name = name + "IndexArrays"
......
......@@ -105,10 +105,10 @@ class ParallelBlock(Block):
super(ParallelBlock, self).__init__(offset, local_slice)
self._block = block
self._gls = inner_ghost_layers
self._namePrefix = name_prefix
self._name_prefix = name_prefix
def __getitem__(self, data_name):
result = self._block[self._namePrefix + data_name]
result = self._block[self._name_prefix + data_name]
type_name = type(result).__name__
if type_name == 'GhostLayerField':
result = wlb.field.toArray(result, withGhostLayers=self._gls)
......
......@@ -119,7 +119,7 @@ class DataHandling(ABC):
"""Returns the number of ghost layers for a specific field/array."""
@abstractmethod
def values_per_cell(self, name: str) -> int:
def values_per_cell(self, name: str) -> Tuple[int, ...]:
"""Returns values_per_cell of array."""
@abstractmethod
......@@ -239,7 +239,7 @@ class DataHandling(ABC):
# ------------------------------- Data access and modification -----------------------------------------------------
def fill(self, array_name: str, val, value_idx: Optional[int] = None,
def fill(self, array_name: str, val, value_idx: Optional[Tuple[int, ...]] = None,
slice_obj=None, ghost_layers=False, inner_ghost_layers=False) -> None:
"""Sets all cells to the same value.
......@@ -257,11 +257,13 @@ class DataHandling(ABC):
ghost_layers = self.ghost_layers_of_field(array_name)
if inner_ghost_layers is True:
ghost_layers = self.ghost_layers_of_field(array_name)
if value_idx is not None and self.values_per_cell(array_name) < 2:
raise ValueError("value_idx parameter only valid for fields with values_per_cell > 1")
for b in self.iterate(slice_obj, ghost_layers=ghost_layers, inner_ghost_layers=inner_ghost_layers):
if value_idx is not None:
b[array_name][..., value_idx].fill(val)
if isinstance(value_idx, int):
value_idx = (value_idx,)
assert len(value_idx) == len(self.values_per_cell(array_name))
b[array_name][(Ellipsis, *value_idx)].fill(val)
else:
b[array_name].fill(val)
......
......@@ -104,6 +104,8 @@ class ParallelDataHandling(DataHandling):
if alignment is False or alignment is None:
alignment = 0
if hasattr(values_per_cell, '__len__'):
raise NotImplementedError("Parallel data handling does not support multiple index dimensions")
self._fieldInformation[name] = {'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
......
......@@ -86,6 +86,12 @@ class SerialDataHandling(DataHandling):
'shape': tuple(s + 2 * ghost_layers for s in self._domainSize),
'dtype': dtype,
}
if not hasattr(values_per_cell, '__len__'):
values_per_cell = (values_per_cell, )
if len(values_per_cell) == 1 and values_per_cell[0] == 1:
values_per_cell = ()
self._field_information[name] = {
'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
......@@ -94,12 +100,12 @@ class SerialDataHandling(DataHandling):
'alignment': alignment,
}
if values_per_cell > 1:
kwargs['shape'] = kwargs['shape'] + (values_per_cell,)
index_dimensions = 1
layout_tuple = layout_string_to_tuple(layout, self.dim + 1)
index_dimensions = len(values_per_cell)
kwargs['shape'] = kwargs['shape'] + values_per_cell
if index_dimensions > 0:
layout_tuple = layout_string_to_tuple(layout, self.dim + index_dimensions)
else:
index_dimensions = 0
layout_tuple = spatial_layout_string_to_tuple(layout, self.dim)
# cpu_arr is always created - since there is no create_pycuda_array_with_layout()
......@@ -274,6 +280,14 @@ class SerialDataHandling(DataHandling):
result = []
for name in names:
gls = self._field_information[name]['ghost_layers']
values_per_cell = self._field_information[name]['values_per_cell']
if values_per_cell == ():
values_per_cell = (1, )
if len(values_per_cell) == 1:
values_per_cell = values_per_cell[0]
else:
raise NotImplementedError("Synchronization of this field is not supported: " + name)
if len(filtered_stencil) > 0:
if target == 'cpu':
from pystencils.slicing import get_periodic_boundary_functor
......@@ -282,7 +296,7 @@ class SerialDataHandling(DataHandling):
from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as boundary_func
result.append(boundary_func(filtered_stencil, self._domainSize,
index_dimensions=self.fields[name].index_dimensions,
index_dim_shape=self._field_information[name]['values_per_cell'],
index_dim_shape=values_per_cell,
dtype=self.fields[name].dtype.numpy_dtype,
ghost_layers=gls))
......@@ -334,7 +348,8 @@ class SerialDataHandling(DataHandling):
for i in range(values_per_cell):
cell_data["%s[%d]" % (name, i)] = np.ascontiguousarray(field[..., i])
else:
assert False
raise NotImplementedError("VTK export for fields with more than one index "
"coordinate not implemented")
image_to_vtk(full_file_name, cell_data=cell_data)
return writer
......@@ -358,7 +373,8 @@ class SerialDataHandling(DataHandling):
ghost_layers = actual_ghost_layers
gl_to_remove = actual_ghost_layers - ghost_layers
ind_dims = 1 if self._field_information[name]['values_per_cell'] > 1 else 0
assert len(self._field_information[name]['values_per_cell']) == 1
ind_dims = 1 if self._field_information[name]['values_per_cell'][0] > 1 else 0
return remove_ghost_layers(self.cpu_arrays[name], ind_dims, gl_to_remove)
def log(self, *args, level='INFO'):
......
from types import MappingProxyType
import sympy as sp
from pystencils.field import Field
from pystencils.assignment import Assignment
from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAssignment
from pystencils.cpu.vectorization import vectorize
......@@ -157,6 +158,21 @@ def create_indexed_kernel(assignments, index_fields, target='cpu', data_type="do
else:
raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
def create_staggered_kernel_from_assignments(assignments, **kwargs):
assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
lhs_fields = {a.lhs.atoms(Field.Access) for a in assignments}
assert len(lhs_fields) == 1
staggered_field = lhs_fields.pop()
dim = staggered_field.spatial_dimensions
counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
conditions = [counters[i] < staggered_field.shape[i] - 1 for i in range(dim)]
guarded_assignments = []
for d in range(dim):
cond = sp.And(*[conditions[i] for i in range(dim) if d != i])
guarded_assignments.append(Conditional(cond, Block(assignments)))
def create_staggered_kernel(staggered_field, expressions, subexpressions=(), target='cpu', **kwargs):
"""Kernel that updates a staggered field.
......@@ -165,11 +181,11 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
Args:
staggered_field: field where the first index coordinate defines the location of the staggered value
can have 1 or 2 index coordinates, in case of of two index coordinates at every staggered location
a vector is stored, expressions has to be a sequence of sequences then
can have 1 or 2 index coordinates, in case of two index coordinates at every staggered location
a vector is stored, expressions parameter has to be a sequence of sequences then
where e.g. ``f[0,0](0)`` is interpreted as value at the left cell boundary, ``f[1,0](0)`` the right cell
boundary and ``f[0,0](1)`` the southern cell boundary etc.
expressions: sequence of expressions of length dim, defining how the east, southern, (bottom) cell boundary
expressions: sequence of expressions of length dim, defining how the west, southern, (bottom) cell boundary
should be updated.
subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
target: 'cpu' or 'gpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment