Commit 97032f3d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

communication slices WIP

parent a6c4c41d
Pipeline #27022 failed with stage
in 16 minutes and 19 seconds
import numpy as np
import pystencils as ps
from pystencils.slicing import make_slice, shift_slice, get_slice_before_ghost_layer
from lbmpy.stencils import get_stencil
from lbmpy.advanced_streaming.utility import get_accessor
def cut_slice_in_direction(slices, direction):
assert len(slices) == len(dir)
result = []
for s, d in zip(slices, direction):
start = s.start + 1 if d == -1 else s.start
stop = s.stop - 1 if d == 1 else s.stop
result.append(slice(start, stop, s.step))
return tuple(result)
class CommunicationSlices:
def __init__(self, lb_method, kernel_type='pull', after_timestep='both',
pdf_field=None, ghost_layers=1):
self.stencil = lb_method.stencil
self.dim = lb_method.dim
self.q = len(self.stencil)
self.ghost_layers = ghost_layers
if pdf_field is None:
pdf_field = ps.Field.create_generic(
'pdfs', spatial_dimensions=self.dim, index_shape=(self.q, ))
self.pdf_field = pdf_field
self.write_accesses = get_accessor(kernel_type, after_timestep).write()
def copy_slices(self, copy_dir, streaming_dir):
d = self.stencil.index(streaming_dir)
write_offsets = self.write_accesses[d].offsets
write_index = self.write_accesses[d].index[0]
tangential_dir = tuple(s - c for s, c in zip(streaming_dir, copy_dir))
origin_slice = get_slice_before_ghost_layer(copy_dir, ghost_layers=self.ghost_layers, thickness=1)
src_slice = shift_slice(cut_slice_in_direction(origin_slice, tangential_dir), write_offsets)
# TODO: Calculate neighbour anchor point and transform src_slice to obtain dst_slice
raise NotImplementedError()
def copy_slices_for_comm_direction(self, copy_dir):
dim = len(copy_dir[0])
raise NotImplementedError()
from lbmpy.fieldaccess import StreamPullTwoFieldsAccessor, \
from lbmpy.fieldaccess import PdfFieldAccessor, \
StreamPullTwoFieldsAccessor, \
StreamPushTwoFieldsAccessor, \
AAEvenTimeStepAccessor, \
AAOddTimeStepAccessor, \
......@@ -24,7 +25,7 @@ odd_accessors = {
'esotwist': EsoTwistOddTimeStepAccessor
}
def get_accessor(kernel_type, timestep):
def get_accessor(kernel_type, timestep) -> PdfFieldAccessor:
if kernel_type not in supported_kernels:
raise ValueError(
"Invalid value of parameter 'kernel_type'.", kernel_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment