Commit eba2ea13 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

WIP: add interpolation_mode (NEAREST_NEIGHBOR vs LINEAR)

parent 5fa29433
Pipeline #17026 failed with stage
in 3 minutes and 48 seconds
......@@ -8,6 +8,7 @@ from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.data_types import cast_func, create_type, get_type_of_expression
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import InterpolationMode
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
......@@ -77,7 +78,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_TextureAccess(self, node):
dtype = node.texture.field.dtype.numpy_dtype
if node.texture.cubic_bspline_interpolation:
if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple<%s>(%s, %s)"
else:
if dtype.itemsize > 4:
......
......@@ -7,7 +7,7 @@ from pystencils.data_types import StructType
from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess
from pystencils.interpolation_astnodes import InterpolationMode, TextureAccess
from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True
......@@ -46,7 +46,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
if USE_FAST_MATH:
nvcc_options.append("-use_fast_math")
if any(t.cubic_bspline_interpolation for t in textures):
if any(t.interpolation_mode == InterpolationMode.CUBIC_SPLINE for t in textures):
assert isdir(join(dirname(__file__), "CubicInterpolationCUDA", "code")), \
"Submodule CubicInterpolationCUDA does not exist"
nvcc_options += ["-I" + join(dirname(__file__), "CubicInterpolationCUDA", "code")]
......
......@@ -9,6 +9,7 @@
"""
import itertools
from enum import Enum
from typing import Set
import sympy as sp
......@@ -24,7 +25,14 @@ except Exception:
pass
class LinearInterpolator(object):
class InterpolationMode(str, Enum):
NEAREST_NEIGHBOR = "nearest_neighbour"
NN = NEAREST_NEIGHBOR
LINEAR = "linear"
CUBIC_SPLINE = "cubic_spline"
class Interpolator(object):
"""
Implements non-integer accesses on fields using linear interpolation.
......@@ -58,7 +66,9 @@ class LinearInterpolator(object):
required_global_declarations = []
def __init__(self, parent_field: pystencils.Field,
def __init__(self,
parent_field: pystencils.Field,
interpolation_mode: InterpolationMode,
address_mode='BORDER',
use_normalized_coordinates=False):
super().__init__()
......@@ -72,6 +82,7 @@ class LinearInterpolator(object):
'dummy_symbol_carrying_field' + self.field.name + hash_str)
self.symbol.field = self.field
self.symbol.interpolator = self
self.interpolation_mode = interpolation_mode
def at(self, offset):
return InterpolatorAccess(self.symbol, *offset)
......@@ -93,6 +104,30 @@ class LinearInterpolator(object):
self.use_normalized_coordinates))
class LinearInterpolator(Interpolator):
def __init__(self,
parent_field: pystencils.Field,
address_mode='BORDER',
use_normalized_coordinates=False):
super().__init__(parent_field,
InterpolationMode.LINEAR,
address_mode,
use_normalized_coordinates)
class NearestNeightborInterpolator(Interpolator):
def __init__(self,
parent_field: pystencils.Field,
address_mode='BORDER',
use_normalized_coordinates=False):
super().__init__(parent_field,
InterpolationMode.NN,
address_mode,
use_normalized_coordinates)
class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, offsets, *args, **kwargs):
obj = TextureAccess.__xnew_cached_(cls, field, offsets, *args, **kwargs)
......@@ -149,6 +184,10 @@ class InterpolatorAccess(TypedSymbol):
def symbols_defined(self) -> Set[sp.Symbol]:
return {self}
@property
def interpolation_mode(self):
return self.interpolator.interpolation_mode
def implementation_with_stencils(self):
field = self.field
......@@ -165,56 +204,66 @@ class InterpolatorAccess(TypedSymbol):
offsets = self.offsets
rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1)
# TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
index = [f(offset) for (f, offset) in zip(c, offsets)]
for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
# Hardware boundary handling on GPU
for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
if self.interpolation_mode == InterpolationMode.NN:
if use_textures:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
sum[channel_idx] += \
weight * absolute_access(index, channel_idx if field.index_dimensions else ())
# else boundary handling using software
elif str(self.interpolator.address_mode).lower() == 'border':
is_inside_field = sp.And(
*itertools.chain([i >= 0 for i in index],
[idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
index = [cast_func(i, default_int_type) for i in index]
sum[channel_idx] += sp.Piecewise(
(weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
is_inside_field),
(sp.simplify(0), True)
)
elif str(self.interpolator.address_mode).lower() == 'clamp':
index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'wrap':
index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
(field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'mirror':
def triangle_fun(x, half_period):
saw_tooth = sp.Abs(x) % (2 * half_period)
return sp.Piecewise((saw_tooth, saw_tooth < half_period),
(2 * half_period - 1 - saw_tooth, True))
index = [cast_func(triangle_fun(i, field.shape[dim]),
default_int_type) for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * absolute_access(index, channel_idx if field.index_dimensions else ())
sum[channel_idx] = self
else:
raise NotImplementedError()
sum = [sp.factor(s) for s in sum]
sum[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], channel_idx)
if field.index_dimensions:
return sp.Matrix(sum)
else:
return sum[0]
elif self.interpolation_mode == InterpolationMode.LINEAR:
# TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
index = [f(offset) for (f, offset) in zip(c, offsets)]
# Hardware boundary handling on GPU
if use_textures:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
sum[channel_idx] += \
weight * absolute_access(index, channel_idx if field.index_dimensions else ())
# else boundary handling using software
elif str(self.interpolator.address_mode).lower() == 'border':
is_inside_field = sp.And(
*itertools.chain([i >= 0 for i in index],
[idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
index = [cast_func(i, default_int_type) for i in index]
sum[channel_idx] += sp.Piecewise(
(weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
is_inside_field),
(sp.simplify(0), True)
)
elif str(self.interpolator.address_mode).lower() == 'clamp':
index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'wrap':
index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
(field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'mirror':
def triangle_fun(x, half_period):
saw_tooth = sp.Abs(x) % (2 * half_period)
return sp.Piecewise((saw_tooth, saw_tooth < half_period),
(2 * half_period - 1 - saw_tooth, True))
index = [cast_func(triangle_fun(i, field.shape[dim]),
default_int_type) for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
else:
raise NotImplementedError()
elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
raise NotImplementedError("only works with HW interpolation for float32")
sum = [sp.factor(s) for s in sum]
if field.index_dimensions:
return sp.Matrix(sum)
else:
return sum[0]
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
......@@ -231,9 +280,10 @@ class TextureCachedField:
def __init__(self, parent_field,
address_mode=None,
filter_mode=None,
interpolation_mode: InterpolationMode = InterpolationMode.LINEAR,
use_normalized_coordinates=False,
read_as_integer=False,
cubic_bspline_interpolation=False):
read_as_integer=False
):
if isinstance(address_mode, str):
address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
......@@ -252,14 +302,14 @@ class TextureCachedField:
self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype)
self.symbol.interpolator = self
self.symbol.field = self.field
self.cubic_bspline_interpolation = cubic_bspline_interpolation
self.interpolation_mode = interpolation_mode
# assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
# assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
@classmethod
def from_interpolator(cls, interpolator: LinearInterpolator):
obj = cls(interpolator.field, interpolator.address_mode)
obj = cls(interpolator.field, interpolator.address_mode, interpolation_mode=interpolator.interpolation_mode)
return obj
def at(self, offset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment