interpolation_astnodes.py 19.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.

"""

"""

11
import hashlib
12
13
14
15
import itertools
from enum import Enum
from typing import Set

16
import sympy as sp
17
18
19
from sympy.core.cache import cacheit

import pystencils
20
21
22
23
24
25
26
27
from pystencils.astnodes import Node
from pystencils.data_types import TypedSymbol, cast_func, create_type

try:
    import pycuda.driver
except Exception:
    pass

28
29
# Hash constructor used below to derive the hexadecimal suffixes of interpolator and
# access names (see the ``reproducible_hash`` properties and ``InterpolatorAccess``).
_hash = hashlib.md5

30
31
32
33
34
35
36
37

class InterpolationMode(str, Enum):
    """Interpolation schemes an interpolator in this module can be configured with.

    Because of the ``str`` mixin every member is also a plain string.
    """
    NEAREST_NEIGHBOR = "nearest_neighbour"
    # Short alias: it reuses the value above, so ``NN`` is the same member as ``NEAREST_NEIGHBOR``.
    NN = NEAREST_NEIGHBOR
    LINEAR = "linear"
    CUBIC_SPLINE = "cubic_spline"


38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class _InterpolationSymbol(TypedSymbol):
    """Symbol that carries a field and the interpolator operating on it.

    Instances are created through a two-stage constructor: ``__new__`` forwards to a
    ``cacheit``-wrapped version of ``__new_stage2__``, which builds the symbol and attaches
    the ``field`` and ``interpolator`` attributes.
    """

    def __new__(cls, name, field, interpolator):
        # Route construction through the cached second stage.
        obj = cls.__xnew_cached_(cls, name, field, interpolator)
        return obj

    def __new_stage2__(cls, name, field, interpolator):
        # The second argument is a placeholder string built from the field name.
        # NOTE(review): presumably this is the slot TypedSymbol uses for the dtype - confirm.
        obj = super().__xnew__(cls, name, 'dummy_symbol_carrying_field' + field.name)
        obj.field = field
        obj.interpolator = interpolator
        return obj

    def __getnewargs__(self):
        # Arguments needed to rebuild the symbol, in the order ``__new__`` expects them.
        return self.name, self.field, self.interpolator

    # Uncached and cached entry points to the second construction stage. They must be
    # assigned after ``__new_stage2__`` has been defined in the class body.
    # noinspection SpellCheckingInspection
    __xnew__ = staticmethod(__new_stage2__)
    # noinspection SpellCheckingInspection
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
class Interpolator(object):
    """
    Implements non-integer accesses on fields using interpolation.

    The interpolation scheme is selected with ``interpolation_mode``
    (see :class:`InterpolationMode`).

    On GPU, this interpolator can be implemented by a :class:`.TextureCachedField` for hardware acceleration.

    Address modes are different boundary handlings; possible choices are like for CUDA textures

        **CLAMP**

        The signal c[k] is continued outside k=0,...,M-1 so that c[k] = c[0] for k < 0, and c[k] = c[M-1] for k >= M.

        **BORDER**

        The signal c[k] is continued outside k=0,...,M-1 so that c[k] = 0 for k < 0 and for k >= M.

        Now, to describe the last two address modes, we are forced to consider normalized coordinates,
        so that the 1D input signal samples are assumed to be c[k / M], with k=0,...,M-1.

        **WRAP**

        The signal c[k / M] is continued outside k=0,...,M-1 so that it is periodic with period equal to M.
        In other words, c[(k + p * M) / M] = c[k / M] for any (positive, negative or vanishing) integer p.

        **MIRROR**

        The signal c[k / M] is continued outside k=0,...,M-1 so that it is periodic with period equal to 2 * M - 2.
        In other words, c[l / M] = c[k / M] for any l and k such that (l + k)mod(2 * M - 2) = 0.

    Explanations from https://stackoverflow.com/questions/19020963/the-different-addressing-modes-of-cuda-textures
    """

    # Class-level default: a plain interpolator needs no global declarations.
    # Subclasses may replace this with a per-instance list.
    required_global_declarations = []

    def __init__(self,
                 parent_field,
                 interpolation_mode: InterpolationMode,
                 address_mode='BORDER',
                 use_normalized_coordinates=False,
                 allow_textures=True):
        """Create an interpolator for ``parent_field``.

        Note that this sets ``parent_field.field_type`` to ``FieldType.CUSTOM``.
        """
        super().__init__()

        self.field = parent_field
        self.field.field_type = pystencils.field.FieldType.CUSTOM
        self.address_mode = address_mode
        self.use_normalized_coordinates = use_normalized_coordinates
        self.interpolation_mode = interpolation_mode
        self.hash_str = _hash(
            f'{self.field}_{address_mode}_{self.field.dtype}_{interpolation_mode}'.encode()).hexdigest()
        # ``str(self)`` reads the attributes assigned above, so the symbol must be created after them.
        self.symbol = _InterpolationSymbol(str(self), parent_field, self)
        self.allow_textures = allow_textures

    @property
    def ndim(self):
        """Number of dimensions of the underlying field."""
        return self.field.ndim

    @property
    def _hashable_contents(self):
        """Tuple of the values that define the identity of this interpolator."""
        return (str(self.address_mode),
                str(type(self)),
                self.hash_str,
                self.use_normalized_coordinates)

    def at(self, offset):
        """Return an :class:`InterpolatorAccess` at the coordinates given by the iterable ``offset``."""
        return InterpolatorAccess(self.symbol, *[sp.S(o) for o in offset])

    def __getitem__(self, offset):
        """Subscript form of :meth:`at`."""
        return self.at(offset)

    def __str__(self):
        return '%s_interpolator_%s' % (self.field.name, self.reproducible_hash)

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash(self._hashable_contents)

    def __eq__(self, other):
        # Compare the identifying contents directly instead of their hashes: this cannot
        # be fooled by a hash collision and does not raise for unhashable operands.
        if not isinstance(other, Interpolator):
            return NotImplemented
        return self._hashable_contents == other._hashable_contents

    @property
    def reproducible_hash(self):
        """Hex digest of the identifying contents, used as a name suffix."""
        return _hash(str(self._hashable_contents).encode()).hexdigest()
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169


class LinearInterpolator(Interpolator):
    """Interpolator whose interpolation mode is fixed to ``InterpolationMode.LINEAR``."""

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        super().__init__(parent_field,
                         interpolation_mode=InterpolationMode.LINEAR,
                         address_mode=address_mode,
                         use_normalized_coordinates=use_normalized_coordinates)


class NearestNeightborInterpolator(Interpolator):
    """Interpolator whose interpolation mode is fixed to ``InterpolationMode.NN``.

    The class name is misspelled ("Neightbor"); it is kept as is because it is the public name.
    """

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        super().__init__(parent_field,
                         interpolation_mode=InterpolationMode.NN,
                         address_mode=address_mode,
                         use_normalized_coordinates=use_normalized_coordinates)


class InterpolatorAccess(TypedSymbol):
    """Symbol representing an interpolated read of a field at given coordinates.

    An access is built from an interpolation symbol (carrying ``field`` and ``interpolator``)
    and one coordinate expression per dimension. These are stored as ``symbol`` and ``offsets``.
    """

    def __new__(cls, field, *offsets, **kwargs):
        obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
        return obj

    def __new_stage2__(cls, symbol, *offsets):
        assert offsets is not None
        # The symbol name embeds a hash of the offsets so different accesses get different names.
        obj = super().__xnew__(cls, '%s_interpolator_%s' %
                               (symbol.field.name, _hash(str(tuple(offsets)).encode()).hexdigest()),
                               symbol.field.dtype)
        obj.offsets = offsets
        obj.symbol = symbol
        obj.field = symbol.field
        obj.interpolator = symbol.interpolator
        return obj

    # Uncached and cached entry points to the second construction stage. They must be
    # assigned after ``__new_stage2__`` has been defined in the class body.
    # noinspection SpellCheckingInspection
    __xnew__ = staticmethod(__new_stage2__)
    # noinspection SpellCheckingInspection
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        # Build the tuple with a display: ``tuple()`` accepts at most one argument, so
        # ``tuple(self.symbol, *self.offsets)`` raised TypeError.
        return (self.symbol, *self.offsets)

    def _hashable_contents(self):
        # NOTE(review): this method is named ``_hashable_contents`` while the parent hook it
        # calls is ``_hashable_content``; presumably it is therefore not used for hashing - confirm.
        return super()._hashable_content() + ((self.symbol, self.field, tuple(self.offsets), self.symbol.interpolator))

    def __str__(self):
        return '%s_interpolator(%s)' % (self.field.name, ', '.join(str(o) for o in self.offsets))

    def __repr__(self):
        return self.__str__()

    def _latex(self, printer, *_):
        field_name = self.field.latex_name or self.field.name
        rendered_offsets = ", ".join(str(printer.doprint(o)) for o in self.offsets)
        return f'{field_name}_{{interpolator}}\\left({rendered_offsets}\\right)'

    @property
    def ndim(self):
        """Number of coordinates of this access."""
        return len(self.offsets)

    @property
    def is_texture(self):
        """Whether the interpolator behind this access is a :class:`TextureCachedField`."""
        return isinstance(self.interpolator, TextureCachedField)

    def atoms(self, *types):
        """Collect atoms of the requested ``types`` from the offsets (and ``self`` if it matches).

        With no ``types`` given, the result is the union of ``o.atoms()`` over all offsets.
        """
        if not self.offsets:
            return set()

        found = set()
        if types:
            found.update(o for o in self.offsets if isinstance(o, types))
            # ``types`` is passed as a tuple: unpacking it into ``isinstance`` raised
            # TypeError for anything but exactly one type.
            if isinstance(self, types):
                found.add(self)
        for o in self.offsets:
            if hasattr(o, 'atoms'):
                found.update(o.atoms(*types))
        return found

    def neighbor(self, coord_id, offset):
        """Return the access whose coordinate ``coord_id`` is shifted by ``offset``."""
        offset_list = list(self.offsets)
        offset_list[coord_id] += offset
        return self.interpolator.at(tuple(offset_list))

    @property
    def free_symbols(self):
        """Union of the free symbols of all offset expressions."""
        symbols = set()
        if self.offsets is not None:
            for o in self.offsets:
                if hasattr(o, 'free_symbols'):
                    symbols.update(set(o.free_symbols))

        return symbols

    @property
    def required_global_declarations(self):
        """Global declarations of the interpolator.

        As a side effect, this access is added to the ``_symbols_defined`` set of the first declaration.
        """
        required_global_declarations = self.symbol.interpolator.required_global_declarations
        if required_global_declarations:
            required_global_declarations[0]._symbols_defined.add(self)
        return required_global_declarations

    @property
    def args(self):
        return [self.symbol, *self.offsets]

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        return {self}

    @property
    def interpolation_mode(self):
        return self.interpolator.interpolation_mode

    @property
    def _diff_interpolation_vec(self):
        """Column vector with one :class:`DiffInterpolatorAccess` per coordinate."""
        return sp.Matrix([DiffInterpolatorAccess(self.symbol, i, *self.offsets)
                          for i in range(len(self.offsets))])

    def diff(self, *symbols, **kwargs):
        """Differentiate via the chain rule: coordinate derivatives times derivatives of the offsets."""
        if symbols == (self,):
            return 1
        rtn = self._diff_interpolation_vec.T * sp.Matrix(self.offsets).diff(*symbols, **kwargs)
        if rtn.shape == (1, 1):
            rtn = rtn[0, 0]
        return rtn

    def implementation_with_stencils(self):
        """Expand this access into an explicit expression over field (or texture) accesses.

        Returns a ``sp.Matrix`` if the field has index dimensions, otherwise a single expression.

        Raises:
            NotImplementedError: for an unknown address mode, or for ``CUBIC_SPLINE`` interpolation.
        """
        field = self.field

        default_int_type = create_type('int64')
        use_textures = isinstance(self.interpolator, TextureCachedField)
        if use_textures:
            def absolute_access(x, _):
                return self.symbol.interpolator.at(tuple(x))
        else:
            absolute_access = field.absolute_access

        def triangle_fun(x, half_period):
            saw_tooth = cast_func(sp.Abs(cast_func(x, 'int32')), 'int32') % (
                cast_func(2 * half_period, create_type('int32')))
            return sp.Piecewise((saw_tooth, saw_tooth < half_period),
                                (2 * half_period - 1 - saw_tooth, True))

        # NOTE(review): the channel count is read from ``field.shape[0]``; confirm that this
        # (rather than an index-dimension extent) is the intended number of channels.
        num_channels = field.shape[0] if field.index_dimensions else 1
        channel_sums = [0, ] * num_channels

        offsets = self.offsets
        # Lower and upper integer neighbor of a coordinate.
        rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1)
        address_mode = str(self.interpolator.address_mode).lower()

        for channel_idx in range(num_channels):
            channel_arg = channel_idx if field.index_dimensions else ()

            if self.interpolation_mode == InterpolationMode.NN:
                if use_textures:
                    channel_sums[channel_idx] = self
                else:
                    channel_sums[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], channel_idx)

            elif self.interpolation_mode == InterpolationMode.LINEAR:
                # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
                for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
                    weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
                    index = [f(offset) for (f, offset) in zip(c, offsets)]
                    # Hardware boundary handling on GPU
                    if use_textures:
                        channel_sums[channel_idx] += weight * absolute_access(index, channel_arg)
                    # else boundary handling using software
                    elif address_mode == 'border':
                        is_inside_field = sp.And(
                            *itertools.chain([i >= 0 for i in index],
                                             [idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
                        index = [cast_func(i, default_int_type) for i in index]
                        channel_sums[channel_idx] += sp.Piecewise(
                            (weight * absolute_access(index, channel_arg), is_inside_field),
                            (sp.simplify(0), True)
                        )
                    elif address_mode == 'clamp':
                        index = [sp.Min(sp.Max(0, cast_func(i, default_int_type)), field.spatial_shape[dim] - 1)
                                 for (dim, i) in enumerate(index)]
                        channel_sums[channel_idx] += weight * absolute_access(index, channel_arg)
                    elif address_mode == 'wrap':
                        # NOTE(review): the modulus is ``shape - 1``; confirm against the intended period.
                        index = [sp.Mod(cast_func(i, default_int_type), field.shape[dim] - 1)
                                 for (dim, i) in enumerate(index)]
                        index = [cast_func(sp.Piecewise((i, i > 0),
                                                        (sp.Abs(cast_func(field.shape[dim] - 1 + i, default_int_type)),
                                                         True)), default_int_type)
                                 for (dim, i) in enumerate(index)]
                        channel_sums[channel_idx] += weight * absolute_access(index, channel_arg)
                    elif address_mode == 'mirror':
                        index = [cast_func(triangle_fun(i, field.shape[dim]),
                                           default_int_type) for (dim, i) in enumerate(index)]
                        channel_sums[channel_idx] += weight * absolute_access(index, channel_arg)
                    else:
                        raise NotImplementedError()
            elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
                raise NotImplementedError("only works with HW interpolation for float32")

            # NOTE(review): this return sits inside the channel loop, so the method returns
            # after the first channel has been processed. Kept as is - confirm whether it
            # should run after the loop instead.
            channel_sums = [sp.factor(s) for s in channel_sums]

            if field.index_dimensions:
                return sp.Matrix(channel_sums)
            else:
                return channel_sums[0]

357
358
359

class DiffInterpolatorAccess(InterpolatorAccess):
    """Derivative of an interpolated access with respect to coordinate ``diff_coordinate_idx``.

    For ``InterpolationMode.LINEAR`` no instance of this class is created: ``__new__`` returns
    a second-order finite-difference discretization of the plain access instead.
    """

    def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
        if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
            from pystencils.fd import Diff, Discretization2ndOrder
            return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
        obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
        return obj

    def __new_stage2__(cls, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
        assert offsets is not None
        obj = super().__xnew__(cls, symbol, *offsets)
        obj.diff_coordinate_idx = diff_coordinate_idx
        return obj

    def __hash__(self):
        return hash((self.symbol, self.field, self.diff_coordinate_idx, tuple(self.offsets), self.interpolator))

    def __str__(self):
        return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx,
                                               ', '.join(str(o) for o in self.offsets))

    def __repr__(self):
        return str(self)

    @property
    def args(self):
        return [self.symbol, self.diff_coordinate_idx, *self.offsets]

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        return {self}

    @property
    def interpolation_mode(self):
        return self.interpolator.interpolation_mode

    # Uncached and cached entry points to the second construction stage. They must be
    # assigned after ``__new_stage2__`` has been defined in the class body.
    # noinspection SpellCheckingInspection
    __xnew__ = staticmethod(__new_stage2__)
    # noinspection SpellCheckingInspection
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        # Build the tuple with a display: ``tuple()`` accepts at most one argument, so the
        # former ``tuple(a, b, *c)`` call raised TypeError.
        return (self.symbol, self.diff_coordinate_idx, *self.offsets)


403
404
405
406
407
##########################################################################################
# GPU-specific fast specializations                                                      #
# (for precision, GPUs can also use the nodes/symbols above)                             #
##########################################################################################


408
class TextureCachedField(Interpolator):
    """Interpolator that is backed by a texture and declares it globally.

    On construction a :class:`TextureDeclaration` for this texture is stored in
    ``required_global_declarations``.
    """

    def __init__(self, parent_field,
                 address_mode=None,
                 filter_mode=None,
                 interpolation_mode: InterpolationMode = InterpolationMode.LINEAR,
                 use_normalized_coordinates=False,
                 read_as_integer=False
                 ):
        # Resolve the default before calling the base constructor. Otherwise ``None`` is
        # stored as the address mode and ends up in the hash and name of the texture.
        if address_mode is None:
            address_mode = 'border'

        super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates)

        if filter_mode is None:
            try:
                filter_mode = pycuda.driver.filter_mode.LINEAR
            except NameError:
                # The module-level pycuda import is best-effort; without pycuda there is
                # no default filter mode to look up.
                filter_mode = None

        self.filter_mode = filter_mode
        self.read_as_integer = read_as_integer
        self.required_global_declarations = [TextureDeclaration(self)]

    @property
    def ndim(self):
        """Number of dimensions of the underlying field."""
        return self.field.ndim

    @classmethod
    def from_interpolator(cls, interpolator: LinearInterpolator):
        """Return a texture equivalent to ``interpolator``.

        The interpolator itself is returned if it already is an instance of ``cls`` or has
        ``allow_textures`` set to a false value.
        """
        if (isinstance(interpolator, cls)
                or (hasattr(interpolator, 'allow_textures') and not interpolator.allow_textures)):
            return interpolator
        obj = cls(interpolator.field,
                  interpolator.address_mode,
                  interpolation_mode=interpolator.interpolation_mode,
                  # Carry the coordinate convention over; it used to be reset to the default.
                  use_normalized_coordinates=getattr(interpolator, 'use_normalized_coordinates', False))
        return obj

    def __str__(self):
        return '%s_texture_%s' % (self.field.name, self.reproducible_hash)

    def __repr__(self):
        return self.__str__()

    @property
    def reproducible_hash(self):
        """Hex digest of the identifying contents, used as a name suffix."""
        return _hash(str(self._hashable_contents).encode()).hexdigest()
448

449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485

class TextureDeclaration(Node):
    """
    A global declaration of a texture. Visible both for device and host code.

    .. code:: cpp

        // This Node represents the following global declaration
        texture<float, cudaTextureType2D, cudaReadModeElementType> x_texture_5acc9fced7b0dc3e;

        __device__ kernel(...) {
            // kernel accesses x_texture_5acc9fced7b0dc3e with tex2d(...)
        }

        __host__ launch_kernel(...) {
            // Host needs to bind the texture
            cudaBindTexture(0, x_texture_5acc9fced7b0dc3e, buffer, N*sizeof(float));
        }

    This has been deprecated by CUDA in favor of :class:`.TextureObject`.
    But texture objects are not yet supported by PyCUDA (https://github.com/inducer/pycuda/pull/174)
    """

    def __init__(self, parent_texture):
        self.texture = parent_texture
        # Starts with the texture's own symbol; other code may add further symbols to this set.
        self._symbols_defined = {self.texture.symbol}

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        return self._symbols_defined

    @property
    def args(self) -> Set[sp.Symbol]:
        return set()

    @property
    def headers(self):
        """Header names required by this declaration (one more for cubic-spline textures)."""
        headers = ['"pycuda-helpers.hpp"']
        if self.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
            headers.append('"cubicTex%iD.cu"' % self.texture.ndim)
        return headers

    def __str__(self):
        # Imported here rather than at module level; the string form is whatever the CUDA backend prints.
        from pystencils.backends.cuda_backend import CudaBackend
        return CudaBackend()(self)

    def __repr__(self):
        return str(self)

498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526

class TextureObject(TextureDeclaration):
    """
    CUDA texture object.

    In contrast to :class:`.TextureDeclaration`, it is not declared globally but handed to
    the kernel call as a function argument.

    It defines :class:`.TextureAccess` symbols just like :class:`.TextureDeclaration`;
    only the printed representation differs slightly.
    """


def dtype_supports_textures(dtype):
    """
    Tell whether CUDA can natively fetch values of this numpy dtype from a texture.

    A texture fetch handles words of at most four bytes.

    Larger dtypes can still be fetched with the trick shown in
    https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp

    """
    # Unwrap objects that expose their numpy dtype through ``numpy_dtype``.
    candidate = getattr(dtype, 'numpy_dtype', dtype)

    # A class is instantiated first so that ``itemsize`` is read from an instance.
    if isinstance(candidate, type):
        candidate = candidate()

    return candidate.itemsize <= 4