Commit 479cc5ad authored by Stephan Seitz
Browse files

Refactor interpolation

parent be54b20e
......@@ -10,11 +10,12 @@ from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
vector_memory_access)
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
......
......@@ -3,8 +3,7 @@ from os.path import dirname, join
from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import (
DiffInterpolatorAccess, InterpolationMode, TextureCachedField)
from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
......@@ -74,9 +73,9 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_InterpolatorAccess(self, node):
dtype = node.interpolator.field.dtype.numpy_dtype
if isinstance(node, DiffInterpolatorAccess):
if type(node) == DiffInterpolatorAccess:
# cubicTex3D_1st_derivative_x(texture tex, float3 coord)
template = f"cubicTex%iD_1st_derivative_{'zyx'[node.diff_coordinate_idx]}(%s, %s)"
template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple(%s, %s)"
else:
......
......@@ -5,13 +5,20 @@ from pystencils.data_types import StructType
from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess
from pystencils.interpolation_astnodes import InterpolatorAccess, TextureCachedField
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True
def get_cubic_interpolation_include_paths():
    """Return the include directories of the bundled CubicInterpolationCUDA code.

    Both paths are built relative to the directory containing this module:
    ``<module dir>/CubicInterpolationCUDA/code`` and its ``internal``
    sub-directory, in that order.
    """
    from os.path import dirname, join

    code_dir = join(dirname(__file__), "CubicInterpolationCUDA", "code")
    return [code_dir, join(code_dir, "internal")]
def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
"""
Creates a kernel function from an abstract syntax tree which
......@@ -39,7 +46,8 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
code += "#define FUNC_PREFIX __global__\n"
code += "#define RESTRICT __restrict__\n\n"
code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))
textures = set(d.texture for d in kernel_function_node.atoms(TextureAccess))
textures = set(d.interpolator for d in kernel_function_node.atoms(
InterpolatorAccess) if isinstance(d.interpolator, TextureCachedField))
nvcc_options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
if USE_FAST_MATH:
......
......@@ -89,6 +89,10 @@ class Interpolator(object):
self.allow_textures = allow_textures
self.interpolation_mode = interpolation_mode
@property
def ndim(self):
    """Number of dimensions of the wrapped field (``self.field.ndim``)."""
    field = self.field
    return field.ndim
@property
def _hashable_contents(self):
return (str(self.address_mode),
......@@ -146,10 +150,11 @@ class InterpolatorAccess(TypedSymbol):
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
return obj
def __new_stage2__(self, symbol, *offsets):
def __new_stage2__(cls, symbol, *offsets):
assert offsets is not None
obj = super().__xnew__(self, '%s_interpolator_%x' %
(symbol.field.name, abs(hash(tuple(offsets)))), symbol.field.dtype)
obj = super().__xnew__(cls, '%s_interpolator_%s' %
(symbol.field.name, _hash(str(tuple(offsets)).encode()).hexdigest()),
symbol.field.dtype)
obj.offsets = offsets
obj.symbol = symbol
obj.field = symbol.field
......@@ -160,7 +165,7 @@ class InterpolatorAccess(TypedSymbol):
return hash((self.symbol, self.field, tuple(self.offsets), self.interpolator))
def __str__(self):
return '%s_interpolator(%s)' % (self.field.name, ','.join(str(o) for o in self.offsets))
return '%s_interpolator(%s)' % (self.field.name, ', '.join(str(o) for o in self.offsets))
def __repr__(self):
return self.__str__()
......@@ -189,6 +194,13 @@ class InterpolatorAccess(TypedSymbol):
return symbols
@property
def required_global_declarations(self):
    """Global declarations required by the interpolator of ``self.symbol``.

    The value is taken from ``self.symbol.interpolator`` and returned
    unchanged. When it is non-empty, this access is also added to the
    ``_symbols_defined`` set of its first entry.
    """
    declarations = self.symbol.interpolator.required_global_declarations
    if not declarations:
        return declarations
    # Register this access on the first declaration's set of defined symbols.
    declarations[0]._symbols_defined.add(self)
    return declarations
@property
def args(self):
return [self.symbol, *self.offsets]
......@@ -320,7 +332,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
def __str__(self):
return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx,
','.join(str(o) for o in self.offsets))
', '.join(str(o) for o in self.offsets))
def __repr__(self):
return str(self)
......@@ -383,6 +395,10 @@ class TextureCachedField:
# assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
# assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
@property
def ndim(self):
    """Number of dimensions of the cached field (``self.field.ndim``)."""
    field = self.field
    return field.ndim
@classmethod
def from_interpolator(cls, interpolator: LinearInterpolator):
if (isinstance(interpolator, cls)
......@@ -432,7 +448,7 @@ class TextureAccess(InterpolatorAccess):
return obj
def __str__(self):
return '%s_texture(%s)' % (self.interpolator.field.name, ','.join(str(o) for o in self.offsets))
return '%s_texture(%s)' % (self.interpolator.field.name, ', '.join(str(o) for o in self.offsets))
@property
def texture(self):
......@@ -480,7 +496,10 @@ class TextureDeclaration(Node):
@property
def headers(self):
return ['"pycuda-helpers.hpp"']
headers = ['"pycuda-helpers.hpp"']
if self.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
headers.append('"cubicTex%iD.cu"' % self.texture.ndim)
return headers
def __str__(self):
from pystencils.backends.cuda_backend import CudaBackend
......@@ -515,3 +534,4 @@ def dtype_supports_textures(dtype):
return dtype().itemsize <= 4
return dtype.itemsize <= 4
......@@ -44,3 +44,13 @@ def optimize_assignments(assignments, optimizations):
a.optimize(optimizations)
return assignments
def optimize_ast(ast, optimizations):
    """Apply ``optimizations`` to every ``SympyAssignment`` node of ``ast``.

    Nothing is optimized when ``HAS_REWRITING`` is falsy. The same ``ast``
    object that was passed in is returned in either case.

    NOTE(review): this view lost its indentation; the final ``return ast`` is
    assumed to sit outside the ``if``, like ``return assignments`` in the
    preceding ``optimize_assignments`` -- confirm against the original file.
    """
    if not HAS_REWRITING:
        return ast

    for assignment_node in ast.atoms(SympyAssignment):
        assignment_node.optimize(optimizations)
    return ast
......@@ -14,8 +14,8 @@ import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
......@@ -1314,7 +1314,7 @@ def implement_interpolations(ast_node: ast.Node,
implement_by_texture_accesses: bool = False,
vectorize: bool = False,
use_hardware_interpolation_for_f32=True):
from pystencils.interpolation_astnodes import InterpolatorAccess, TextureAccess, TextureCachedField
from pystencils.interpolation_astnodes import (InterpolatorAccess, TextureCachedField)
# TODO: perform this function on assignments, when unify_shape_symbols allows differently sized fields
assert not(implement_by_texture_accesses and vectorize), \
......@@ -1324,34 +1324,27 @@ def implement_interpolations(ast_node: ast.Node,
interpolation_accesses = ast_node.atoms(InterpolatorAccess)
def can_use_hw_interpolation(i):
return use_hardware_interpolation_for_f32 and i.dtype == FLOAT32_T and isinstance(i, TextureAccess)
return (use_hardware_interpolation_for_f32
and i.dtype == FLOAT32_T
and isinstance(i.interpolator, TextureCachedField))
if implement_by_texture_accesses:
interpolators = {a.symbol.interpolator for a in interpolation_accesses}
to_texture_map = {i: TextureCachedField.from_interpolator(i) for i in interpolators}
substitutions = {i: to_texture_map[i.symbol.interpolator].at(
[o for o in i.offsets]) for i in interpolation_accesses}
try:
import pycuda.driver as cuda
for texture in substitutions.values():
if can_use_hw_interpolation(texture):
for i in interpolation_accesses:
old_i = i
try:
import pycuda.driver as cuda
texture = TextureCachedField.from_interpolator(i.interpolator)
i.interpolator = texture
i.symbol.interpolator = texture
if can_use_hw_interpolation(i):
texture.filter_mode = cuda.filter_mode.LINEAR
else:
texture.filter_mode = cuda.filter_mode.POINT
texture.read_as_integer = True
except Exception:
pass
if isinstance(ast_node, AssignmentCollection):
ast_node = ast_node.subs(substitutions)
else:
ast_node.subs(substitutions)
# Update after replacements
interpolation_accesses = ast_node.atoms(InterpolatorAccess)
except Exception:
pass
ast_node.subs({old_i: i})
if vectorize:
# TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
......@@ -1364,4 +1357,11 @@ def implement_interpolations(ast_node: ast.Node,
else:
ast_node.subs(substitutions)
# from pystencils.math_optimizations import ReplaceOptim, optimize_ast
# RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
# lambda e: e.args[0]
# )
# optimize_ast(ast_node, [RemoveConjugate])
return ast_node
......@@ -110,8 +110,9 @@ def test_rotate_interpolation(address_mode):
pyconrad.imshow(out, "out " + address_mode)
@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror'])
def test_rotate_interpolation_gpu(address_mode):
@pytest.mark.parametrize('dtype', (np.int32, np.float32, np.float64))
@pytest.mark.parametrize('address_mode', ('border', 'wrap', 'clamp', 'mirror'))
def test_rotate_interpolation_gpu(dtype, address_mode):
rotation_angle = sympy.pi / 5
scale = 1
......@@ -138,24 +139,23 @@ def test_rotate_interpolation_gpu(address_mode):
pystencils.show_code(ast)
kernel = ast.compile()
out = gpuarray.zeros_like(lenna_gpu)
kernel(x=lenna_gpu, y=out)
pyconrad.imshow(out,
f"out {address_mode} texture:{use_textures} {type_map[dtype]}")
skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif",
np.ascontiguousarray(out.get(), np.float32))
if previous_result is not None:
try:
assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3)
except AssertionError: # NOQA
print("Max error: %f" % np.max(previous_result - out.get()))
# pyconrad.imshow(previous_result - out.get(), "Difference image")
# raise e
previous_result = out.get()
@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror'])
def test_shift_interpolation_gpu(address_mode):
out = gpuarray.zeros_like(lenna_gpu)
kernel(x=lenna_gpu, y=out)
pyconrad.imshow(out,
f"out {address_mode} texture:{use_textures} {type_map[dtype]}")
skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif",
np.ascontiguousarray(out.get(), np.float32))
if previous_result is not None:
try:
assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3)
except AssertionError as e: # NOQA
print("Max error: %f" % np.max(previous_result - out.get()))
# pyconrad.imshow(previous_result - out.get(), "Difference image")
# raise e
previous_result = out.get()
def test_shift_interpolation_gpu():
rotation_angle = 0 # sympy.pi / 5
scale = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment