communication.py 6.3 KB
Newer Older
Frederik Hennig's avatar
Frederik Hennig committed
1
from pystencils import Field
2
from pystencils.slicing import shift_slice, get_slice_before_ghost_layer
3
4
5
from lbmpy.advanced_streaming.utility import is_inplace, get_accessor, numeric_index, numeric_offsets
from pystencils.datahandling import SerialDataHandling
from itertools import chain
Frederik Hennig's avatar
Frederik Hennig committed
6

Frederik Hennig's avatar
Frederik Hennig committed
7

8
def trim_slice_in_direction(slices, direction):
    """Remove one cell from each slice on the side indicated by ``direction``.

    For a direction component of -1 the first cell of the corresponding slice is dropped,
    for +1 the last cell is dropped. Components of 0 and integer (single-index) entries
    are passed through unchanged.

    :param slices: Sequence of ``slice`` objects and/or integer indices, one per dimension.
    :param direction: Sequence of direction components (-1, 0 or 1) of the same length.
    :return: Tuple of the trimmed slices/indices.
    :raises ValueError: If ``slices`` and ``direction`` differ in length.
    """
    if len(slices) != len(direction):
        raise ValueError("slices and direction must have the same length, got {} and {}".format(
            len(slices), len(direction)))

    result = []
    for s, d in zip(slices, direction):
        if isinstance(s, int):
            result.append(s)
            continue

        start, stop = s.start, s.stop
        if d == -1:
            # An open start means index 0, so dropping the first cell starts at 1.
            start = 1 if start is None else start + 1
        elif d == 1:
            # An open stop means "up to the end", so dropping the last cell stops at -1.
            stop = -1 if stop is None else stop - 1
        result.append(slice(start, stop, s.step))

    return tuple(result)


Frederik Hennig's avatar
Frederik Hennig committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def extend_dir(direction):
    """Yield all directions obtained by expanding the zero components of ``direction``.

    Every zero component is replaced in turn by -1, 0 and 1; non-zero components are kept
    as they are. Each result is yielded as a tuple.
    """
    if len(direction) == 0:
        yield ()
        return

    first, remainder = direction[0], direction[1:]
    candidates = (-1, 0, 1) if first == 0 else (first, )
    for value in candidates:
        for tail in extend_dir(remainder):
            yield (value, ) + tail


def _get_neighbour_transform(direction, ghost_layers):
    return tuple(d * (ghost_layers + 1) for d in direction)


Frederik Hennig's avatar
Frederik Hennig committed
39
def _fix_length_one_slices(slices):
40
41
42
43
44
45
46
    """Slices of length one are replaced by their start value for correct periodic shifting"""
    if isinstance(slices, int):
        return slices
    elif isinstance(slices, slice):
        if slices.stop is not None and abs(slices.start - slices.stop) == 1:
            return slices.start
        elif slices.stop is None and slices.start == -1:
Frederik Hennig's avatar
Frederik Hennig committed
47
            return -1  # [-1:] also has length one
Frederik Hennig's avatar
Frederik Hennig committed
48
        else:
49
50
51
            return slices
    else:
        return tuple(_fix_length_one_slices(s) for s in slices)
Frederik Hennig's avatar
Frederik Hennig committed
52
53


Frederik Hennig's avatar
Frederik Hennig committed
54
55
def get_communication_slices(
        stencil, comm_stencil=None, streaming_pattern='pull', after_timestep='both', ghost_layers=1):
    """
    Return the source and destination slices for periodicity handling or communication between blocks.

    :param stencil: The stencil used by the LB method.
    :param comm_stencil: The stencil defining the communication directions. If None, it will be set to stencil.
    :param streaming_pattern: The streaming pattern.
    :param after_timestep: Timestep after which communication is run; either 'even', 'odd' or 'both'.
    :param ghost_layers: Number of ghost layers in each direction.
    :return: Dictionary mapping each non-zero communication direction to a list of
             ``(src_slice, dst_slice)`` pairs. Each slice is a tuple of spatial slices/indices
             followed by the PDF index written by the streaming pattern. Within a list, pairs
             are ordered by the position of their streaming direction in ``stencil``.
    """
    if comm_stencil is None:
        comm_stencil = stencil

    pdfs = Field.create_generic('pdfs', spatial_dimensions=len(stencil[0]), index_shape=(len(stencil),))
    write_accesses = get_accessor(streaming_pattern, after_timestep).write(pdfs, stencil)
    slices_per_comm_direction = dict()

    for comm_dir in comm_stencil:
        # The zero direction is never communicated.
        if all(d == 0 for d in comm_dir):
            continue

        # These depend only on the communication direction, so compute them once per direction.
        origin_slice = _fix_length_one_slices(
            get_slice_before_ghost_layer(comm_dir, ghost_layers=ghost_layers, thickness=1))
        neighbour_transform = _get_neighbour_transform(comm_dir, ghost_layers)

        # Directions that agree with comm_dir in every non-zero component of comm_dir.
        matching_dirs = set(extend_dir(comm_dir))

        slices_for_dir = []
        for d, streaming_dir in enumerate(stencil):
            if streaming_dir not in matching_dirs:
                continue

            write_offsets = numeric_offsets(write_accesses[d])
            write_index = numeric_index(write_accesses[d])[0]

            # streaming_dir with the components of comm_dir removed
            tangential_dir = tuple(s - c for s, c in zip(streaming_dir, comm_dir))
            src_slice = shift_slice(trim_slice_in_direction(origin_slice, tangential_dir), write_offsets)
            dst_slice = shift_slice(src_slice, neighbour_transform)

            slices_for_dir.append((src_slice + (write_index, ), dst_slice + (write_index, )))

        slices_per_comm_direction[comm_dir] = slices_for_dir
    return slices_per_comm_direction
99
100
101
102


class PeriodicityHandling:
    """Copy PDF values across periodic boundaries of a pdf field on a single node.

    Source/destination slice pairs are precomputed with ``get_communication_slices``. They
    cover every non-zero stencil direction whose non-zero components all lie along periodic
    axes of ``data_handling``. For in-place streaming patterns, separate slice lists are kept
    for the 'even' and 'odd' timesteps.

    :param stencil: The stencil used by the LB method.
    :param data_handling: A ``SerialDataHandling``; any other data handling raises ``ValueError``.
    :param pdf_field_name: Name of the pdf array in ``data_handling``.
    :param streaming_pattern: The streaming pattern.
    :param zeroth_timestep: Accepted for interface compatibility; currently not used by this class.
    :param ghost_layers: Number of ghost layers in each direction.
    :param gpu: If True, calling the object raises ``NotImplementedError``.
    """

    def __init__(self, stencil, data_handling, pdf_field_name,
                 streaming_pattern='pull', zeroth_timestep='both',
                 ghost_layers=1, gpu=False):
        if not isinstance(data_handling, SerialDataHandling):
            raise ValueError('Only single node data handling is supported!')

        self.stencil = stencil
        self.dh = data_handling
        self.pdf_field_name = pdf_field_name
        periodicity = data_handling.periodicity
        self.inplace_pattern = is_inplace(streaming_pattern)
        self.gpu = gpu

        def is_copy_direction(direction):
            # Non-zero directions whose non-zero components all point along periodic axes.
            # Checking for non-zero explicitly avoids assuming the rest direction is stencil[0].
            return (any(d != 0 for d in direction)
                    and all(p or d == 0 for d, p in zip(direction, periodicity)))

        copy_directions = tuple(filter(is_copy_direction, stencil))

        def collect_slices(after_timestep):
            # Flatten the per-direction slice lists into one list of (src, dst) pairs.
            slices_per_comm_dir = get_communication_slices(stencil=stencil,
                                                           comm_stencil=copy_directions,
                                                           streaming_pattern=streaming_pattern,
                                                           after_timestep=after_timestep,
                                                           ghost_layers=ghost_layers)
            return list(chain.from_iterable(slices_per_comm_dir.values()))

        if self.inplace_pattern:
            self.comm_slices = {timestep: collect_slices(timestep) for timestep in ('even', 'odd')}
        else:
            self.comm_slices = collect_slices('both')

    def __call__(self, timestep_modulus='both'):
        """Run periodicity handling.

        :param timestep_modulus: For in-place streaming patterns, 'even' or 'odd' (the timestep
                                 after which communication runs). Ignored for other patterns.
        """
        if self.gpu:
            self._periodicity_handling_gpu(timestep_modulus)
        else:
            self._periodicity_handling_cpu(timestep_modulus)

    def _get_comm_slices(self, timestep):
        """Return the (src, dst) slice pairs applicable after ``timestep``."""
        if not self.inplace_pattern:
            # Non-in-place patterns use the same slices for every timestep.
            return self.comm_slices
        if timestep not in self.comm_slices:
            raise ValueError("In-place streaming patterns require timestep 'even' or 'odd', got {!r}".format(
                timestep))
        return self.comm_slices[timestep]

    def _periodicity_handling_cpu(self, timestep):
        arr = self.dh.cpu_arrays[self.pdf_field_name]
        for src, dst in self._get_comm_slices(timestep):
            arr[dst] = arr[src]

    def _periodicity_handling_gpu(self, timestep):
        raise NotImplementedError()