Commit 995cc9ef authored by Stephan Seitz
Browse files

Big fat interpolation mode refactoring

parent f645c192
......@@ -77,13 +77,14 @@ class Interpolator(object):
allow_textures=True):
super().__init__()
self.field = parent_field.new_field_with_different_name(parent_field.name)
self.field = parent_field
self.field.field_type = pystencils.field.FieldType.CUSTOM
self.address_mode = address_mode
self.use_normalized_coordinates = use_normalized_coordinates
hash_str = hashlib.md5(f'{self.field}_{address_mode}_{interpolation_mode}'.encode()).hexdigest()
hash_str = hashlib.md5(
f'{self.field}_{address_mode}_{self.field.dtype}_{interpolation_mode}'.encode()).hexdigest()
self.symbol = TypedSymbol('dummy_symbol_carrying_field' + self.field.name + hash_str,
'dummy_symbol_carrying_field' + self.field.name + hash_str)
'dummy_symbol_carrying_field' + self.field.name + self.field.name + hash_str)
self.symbol.field = self.field
self.symbol.interpolator = self
self.allow_textures = allow_textures
......@@ -97,7 +98,8 @@ class Interpolator(object):
def _hashable_contents(self):
return (str(self.address_mode),
str(type(self)),
self.symbol,
str(self.symbol.name),
hash(self.symbol.field),
self.address_mode,
self.use_normalized_coordinates)
......@@ -114,7 +116,7 @@ class Interpolator(object):
return self.__str__()
def __hash__(self):
return hash(self._hashable_contents)
return hash(self._hashable_contents())
@property
def reproducible_hash(self):
......@@ -161,8 +163,8 @@ class InterpolatorAccess(TypedSymbol):
obj.interpolator = symbol.interpolator
return obj
def __hash__(self):
return hash((self.symbol, self.field, tuple(self.offsets), self.interpolator))
def _hashable_contents(self):
return super()._hashable_content() + ((self.symbol, self.field, tuple(self.offsets), self.symbol.interpolator))
def __str__(self):
return '%s_interpolator(%s)' % (self.field.name, ', '.join(str(o) for o in self.offsets))
......@@ -431,12 +433,13 @@ class TextureCachedField(Interpolator):
@property
def _hashable_contents(self):
return (type(self),
return (str(self.address_mode),
str(type(self)),
str(self.symbol.name),
hash(self.symbol.field),
self.address_mode,
self.filter_mode,
self.read_as_integer,
self.interpolation_mode,
self.use_normalized_coordinates)
self.use_normalized_coordinates,
'T')
def __hash__(self):
return hash(self._hashable_contents)
......
......@@ -1327,8 +1327,9 @@ def implement_interpolations(ast_node: ast.Node,
def can_use_hw_interpolation(i):
return (use_hardware_interpolation_for_f32
and implement_by_texture_accesses
and i.dtype == FLOAT32_T
and isinstance(i.interpolator, TextureCachedField))
and isinstance(i.symbol.interpolator, TextureCachedField))
if implement_by_texture_accesses:
......@@ -1337,27 +1338,35 @@ def implement_interpolations(ast_node: ast.Node,
try:
import pycuda.driver as cuda
texture = TextureCachedField.from_interpolator(i.interpolator)
i.interpolator = texture
i.symbol.interpolator = texture
if can_use_hw_interpolation(i):
texture.filter_mode = cuda.filter_mode.LINEAR
i.symbol.interpolator.filter_mode = cuda.filter_mode.LINEAR
else:
texture.filter_mode = cuda.filter_mode.POINT
texture.read_as_integer = True
i.symbol.interpolator.filter_mode = cuda.filter_mode.POINT
i.symbol.interpolator.read_as_integer = True
except Exception:
pass
ast_node.subs({old_i: i})
from pystencils.math_optimizations import ReplaceOptim, optimize_ast
# from pystencils.math_optimizations import ReplaceOptim, optimize_ast
ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
and not can_use_hw_interpolation(i),
lambda e: e.implementation_with_stencils()
)
# ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
# and not can_use_hw_interpolation(i),
# lambda e: e.implementation_with_stencils()
# )
RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
lambda e: e.args[0]
)
optimize_ast(ast_node, [RemoveConjugate, ImplementInterpolationByStencils])
# RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
# lambda e: e.args[0]
# )
if vectorize:
# TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
raise NotImplementedError()
else:
substitutions = {i: i.implementation_with_stencils()
for i in interpolation_accesses if not can_use_hw_interpolation(i)}
if isinstance(ast_node, AssignmentCollection):
ast_node = ast_node.subs(substitutions)
else:
ast_node.subs(substitutions)
return ast_node
......@@ -112,51 +112,43 @@ def test_rotate_interpolation(address_mode):
@pytest.mark.parametrize('dtype', (np.int32, np.float32, np.float64))
@pytest.mark.parametrize('address_mode', ('border', 'wrap', 'clamp', 'mirror'))
def test_rotate_interpolation_gpu(dtype, address_mode):
@pytest.mark.parametrize('use_textures', ('use_textures', False))
def test_rotate_interpolation_gpu(dtype, address_mode, use_textures):
rotation_angle = sympy.pi / 5
scale = 1
previous_result = None
for dtype in [np.int32, np.float32, np.float64]:
if dtype == np.int32:
lenna_gpu = gpuarray.to_gpu(
np.ascontiguousarray(lenna * 255, dtype))
else:
lenna_gpu = gpuarray.to_gpu(
np.ascontiguousarray(lenna, dtype))
for use_textures in [True, False]:
x_f, y_f = pystencils.fields('x,y: %s [2d]' % type_map[dtype], ghost_layers=0)
if dtype == np.int32:
lenna_gpu = gpuarray.to_gpu(
np.ascontiguousarray(lenna * 255, dtype))
else:
lenna_gpu = gpuarray.to_gpu(
np.ascontiguousarray(lenna, dtype))
x_f, y_f = pystencils.fields('x,y: %s [2d]' % type_map[dtype], ghost_layers=0)
transformed = scale * sympy.rot_axis3(rotation_angle)[:2, :2] * \
sympy.Matrix((x_, y_)) - sympy.Matrix([2, 2])
assignments = pystencils.AssignmentCollection({
y_f.center(): LinearInterpolator(x_f, address_mode=address_mode).at(transformed)
})
print(assignments)
ast = pystencils.create_kernel(assignments, target='gpu', use_textures_for_interpolation=use_textures)
print(ast)
pystencils.show_code(ast)
kernel = ast.compile()
transformed = scale * \
sympy.rot_axis3(rotation_angle)[:2, :2] * sympy.Matrix((x_, y_)) - sympy.Matrix([2, 2])
assignments = pystencils.AssignmentCollection({
y_f.center(): LinearInterpolator(x_f, address_mode=address_mode).at(transformed)
})
print(assignments)
ast = pystencils.create_kernel(assignments, target='gpu', use_textures_for_interpolation=use_textures)
print(ast)
print(pystencils.show_code(ast))
kernel = ast.compile()
out = gpuarray.zeros_like(lenna_gpu)
kernel(x=lenna_gpu, y=out)
pyconrad.imshow(out,
f"out {address_mode} texture:{use_textures} {type_map[dtype]}")
skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif",
np.ascontiguousarray(out.get(), np.float32))
if previous_result is not None:
try:
assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3)
except AssertionError as e: # NOQA
print("Max error: %f" % np.max(previous_result - out.get()))
# pyconrad.imshow(previous_result - out.get(), "Difference image")
# raise e
previous_result = out.get()
out = gpuarray.zeros_like(lenna_gpu)
kernel(x=lenna_gpu, y=out)
pyconrad.imshow(out,
f"out {address_mode} texture:{use_textures} {type_map[dtype]}")
skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif",
np.ascontiguousarray(out.get(), np.float32))
@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror'])
def test_shift_interpolation_gpu(address_mode):
@pytest.mark.parametrize('dtype', [np.float64, np.float32, np.int32])
@pytest.mark.parametrize('use_textures', ('use_textures', False,))
def test_shift_interpolation_gpu(address_mode, dtype, use_textures):
rotation_angle = 0 # sympy.pi / 5
scale = 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment