From 479cc5ad73fd5094898dbec0e79ee073e942af2a Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 14 Jan 2020 18:46:35 +0100 Subject: [PATCH] Refactor interpolation --- pystencils/backends/cbackend.py | 7 ++-- pystencils/backends/cuda_backend.py | 7 ++-- pystencils/gpucuda/cudajit.py | 12 +++++-- pystencils/interpolation_astnodes.py | 34 ++++++++++++++---- pystencils/math_optimizations.py | 10 ++++++ pystencils/transformations.py | 48 +++++++++++++------------- pystencils_tests/test_interpolation.py | 40 ++++++++++----------- 7 files changed, 98 insertions(+), 60 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 6f203837d..3f016810a 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -10,11 +10,12 @@ from sympy.printing.ccode import C89CodePrinter from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, - vector_memory_access) + PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, + reinterpret_cast_func, vector_memory_access) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( - bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) + bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, + int_div, int_power_of_2, modulo_ceil) try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index 0766b9415..9797bc7da 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -3,8 +3,7 @@ from os.path import dirname, join from pystencils.astnodes import Node from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt -from pystencils.interpolation_astnodes import ( - DiffInterpolatorAccess, InterpolationMode, TextureCachedField) +from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: lines = f.readlines() @@ -74,9 +73,9 @@ class CudaSympyPrinter(CustomSympyPrinter): def _print_InterpolatorAccess(self, node): dtype = node.interpolator.field.dtype.numpy_dtype - if isinstance(node, DiffInterpolatorAccess): + if type(node) == DiffInterpolatorAccess: # cubicTex3D_1st_derivative_x(texture tex, float3 coord) - template = f"cubicTex%iD_1st_derivative_{'zyx'[node.diff_coordinate_idx]}(%s, %s)" + template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)" elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple(%s, %s)" else: diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index eb2492f61..a77767dce 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -5,13 +5,20 @@ from pystencils.data_types import StructType from pystencils.field import FieldType from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.include import get_pycuda_include_path, get_pystencils_include_path -from pystencils.interpolation_astnodes import TextureAccess +from pystencils.interpolation_astnodes import InterpolatorAccess, TextureCachedField from pystencils.kernel_wrapper import KernelWrapper from pystencils.kernelparameters import FieldPointerSymbol USE_FAST_MATH = True +def get_cubic_interpolation_include_paths(): + from os.path import join, dirname + + return [join(dirname(__file__), "CubicInterpolationCUDA", "code"), + join(dirname(__file__), "CubicInterpolationCUDA", "code", "internal")] + + def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None): """ Creates a kernel function from an abstract syntax tree which @@ -39,7 +46,8 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen code += "#define FUNC_PREFIX __global__\n" code += "#define RESTRICT __restrict__\n\n" code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend)) - textures = set(d.texture for d in kernel_function_node.atoms(TextureAccess)) + textures = set(d.interpolator for d in kernel_function_node.atoms( + InterpolatorAccess) if isinstance(d.interpolator, TextureCachedField)) nvcc_options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"] if USE_FAST_MATH: diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index b67a06e71..b14e07be1 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -89,6 +89,10 @@ class Interpolator(object): self.allow_textures = allow_textures self.interpolation_mode = interpolation_mode + @property + def ndim(self): + return self.field.ndim + @property def _hashable_contents(self): return (str(self.address_mode), @@ -146,10 +150,11 @@ class InterpolatorAccess(TypedSymbol): obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs) return obj - def __new_stage2__(self, symbol, *offsets): + def __new_stage2__(cls, symbol, *offsets): assert offsets is not None - obj = super().__xnew__(self, '%s_interpolator_%x' % - (symbol.field.name, abs(hash(tuple(offsets)))), symbol.field.dtype) + obj = super().__xnew__(cls, '%s_interpolator_%s' % + (symbol.field.name, _hash(str(tuple(offsets)).encode()).hexdigest()), + symbol.field.dtype) obj.offsets = offsets obj.symbol = symbol obj.field = symbol.field @@ -160,7 +165,7 @@ class InterpolatorAccess(TypedSymbol): return hash((self.symbol, self.field, tuple(self.offsets), self.interpolator)) def __str__(self): - return '%s_interpolator(%s)' % (self.field.name, ','.join(str(o) for o in self.offsets)) + return '%s_interpolator(%s)' % (self.field.name, ', '.join(str(o) for o in self.offsets)) def __repr__(self): return self.__str__() @@ -189,6 +194,13 @@ class InterpolatorAccess(TypedSymbol): return symbols + @property + def required_global_declarations(self): + required_global_declarations = self.symbol.interpolator.required_global_declarations + if required_global_declarations: + required_global_declarations[0]._symbols_defined.add(self) + return required_global_declarations + @property def args(self): return [self.symbol, *self.offsets] @@ -320,7 +332,7 @@ class DiffInterpolatorAccess(InterpolatorAccess): def __str__(self): return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx, - ','.join(str(o) for o in self.offsets)) + ', '.join(str(o) for o in self.offsets)) def __repr__(self): return str(self) @@ -383,6 +395,10 @@ class TextureCachedField: # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!" # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less" + @property + def ndim(self): + return self.field.ndim + @classmethod def from_interpolator(cls, interpolator: LinearInterpolator): if (isinstance(interpolator, cls) @@ -432,7 +448,7 @@ class TextureAccess(InterpolatorAccess): return obj def __str__(self): - return '%s_texture(%s)' % (self.interpolator.field.name, ','.join(str(o) for o in self.offsets)) + return '%s_texture(%s)' % (self.interpolator.field.name, ', '.join(str(o) for o in self.offsets)) @property def texture(self): @@ -480,7 +496,10 @@ class TextureDeclaration(Node): @property def headers(self): - return ['"pycuda-helpers.hpp"'] + headers = ['"pycuda-helpers.hpp"'] + if self.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: + headers.append('"cubicTex%iD.cu"' % self.texture.ndim) + return headers def __str__(self): from pystencils.backends.cuda_backend import CudaBackend @@ -515,3 +534,4 @@ def dtype_supports_textures(dtype): return dtype().itemsize <= 4 return dtype.itemsize <= 4 + diff --git a/pystencils/math_optimizations.py b/pystencils/math_optimizations.py index ad0114782..b9420318f 100644 --- a/pystencils/math_optimizations.py +++ b/pystencils/math_optimizations.py @@ -44,3 +44,13 @@ def optimize_assignments(assignments, optimizations): a.optimize(optimizations) return assignments + + +def optimize_ast(ast, optimizations): + + if HAS_REWRITING: + assignments_nodes = ast.atoms(SympyAssignment) + for a in assignments_nodes: + a.optimize(optimizations) + + return ast diff --git a/pystencils/transformations.py b/pystencils/transformations.py index bb089ccb9..eb68d53ef 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -14,8 +14,8 @@ import pystencils.astnodes as ast import pystencils.integer_functions from pystencils.assignment import Assignment from pystencils.data_types import ( - PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, get_base_type, - get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) + PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, + get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) from pystencils.field import AbstractField, Field, FieldType from pystencils.kernelparameters import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection @@ -1314,7 +1314,7 @@ def implement_interpolations(ast_node: ast.Node, implement_by_texture_accesses: bool = False, vectorize: bool = False, use_hardware_interpolation_for_f32=True): - from pystencils.interpolation_astnodes import InterpolatorAccess, TextureAccess, TextureCachedField + from pystencils.interpolation_astnodes import (InterpolatorAccess, TextureCachedField) # TODO: perform this function on assignments, when unify_shape_symbols allows differently sized fields assert not(implement_by_texture_accesses and vectorize), \ @@ -1324,34 +1324,27 @@ def implement_interpolations(ast_node: ast.Node, interpolation_accesses = ast_node.atoms(InterpolatorAccess) def can_use_hw_interpolation(i): - return use_hardware_interpolation_for_f32 and i.dtype == FLOAT32_T and isinstance(i, TextureAccess) + return (use_hardware_interpolation_for_f32 + and i.dtype == FLOAT32_T + and isinstance(i.interpolator, TextureCachedField)) if implement_by_texture_accesses: - interpolators = {a.symbol.interpolator for a in interpolation_accesses} - to_texture_map = {i: TextureCachedField.from_interpolator(i) for i in interpolators} - - substitutions = {i: to_texture_map[i.symbol.interpolator].at( - [o for o in i.offsets]) for i in interpolation_accesses} - - try: - import pycuda.driver as cuda - for texture in substitutions.values(): - if can_use_hw_interpolation(texture): + for i in interpolation_accesses: + old_i = i + try: + import pycuda.driver as cuda + texture = TextureCachedField.from_interpolator(i.interpolator) + i.interpolator = texture + i.symbol.interpolator = texture + if can_use_hw_interpolation(i): texture.filter_mode = cuda.filter_mode.LINEAR else: texture.filter_mode = cuda.filter_mode.POINT texture.read_as_integer = True - except Exception: - pass - - if isinstance(ast_node, AssignmentCollection): - ast_node = ast_node.subs(substitutions) - else: - ast_node.subs(substitutions) - - # Update after replacements - interpolation_accesses = ast_node.atoms(InterpolatorAccess) + except Exception: + pass + ast_node.subs({old_i: i}) if vectorize: # TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather @@ -1364,4 +1357,11 @@ def implement_interpolations(ast_node: ast.Node, else: ast_node.subs(substitutions) + # from pystencils.math_optimizations import ReplaceOptim, optimize_ast + + # RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate), + # lambda e: e.args[0] + # ) + # optimize_ast(ast_node, [RemoveConjugate]) + return ast_node diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py index 04cf41071..63cda967b 100644 --- a/pystencils_tests/test_interpolation.py +++ b/pystencils_tests/test_interpolation.py @@ -110,8 +110,9 @@ def test_rotate_interpolation(address_mode): pyconrad.imshow(out, "out " + address_mode) -@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror']) -def test_rotate_interpolation_gpu(address_mode): +@pytest.mark.parametrize('dtype', (np.int32, np.float32, np.float64)) +@pytest.mark.parametrize('address_mode', ('border', 'wrap', 'clamp', 'mirror')) +def test_rotate_interpolation_gpu(dtype, address_mode): rotation_angle = sympy.pi / 5 scale = 1 @@ -138,24 +139,23 @@ def test_rotate_interpolation_gpu(address_mode): pystencils.show_code(ast) kernel = ast.compile() - out = gpuarray.zeros_like(lenna_gpu) - kernel(x=lenna_gpu, y=out) - pyconrad.imshow(out, - f"out {address_mode} texture:{use_textures} {type_map[dtype]}") - skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif", - np.ascontiguousarray(out.get(), np.float32)) - if previous_result is not None: - try: - assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3) - except AssertionError: # NOQA - print("Max error: %f" % np.max(previous_result - out.get())) - # pyconrad.imshow(previous_result - out.get(), "Difference image") - # raise e - previous_result = out.get() - - -@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror']) -def test_shift_interpolation_gpu(address_mode): + out = gpuarray.zeros_like(lenna_gpu) + kernel(x=lenna_gpu, y=out) + pyconrad.imshow(out, + f"out {address_mode} texture:{use_textures} {type_map[dtype]}") + skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif", + np.ascontiguousarray(out.get(), np.float32)) + if previous_result is not None: + try: + assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3) + except AssertionError as e: # NOQA + print("Max error: %f" % np.max(previous_result - out.get())) + # pyconrad.imshow(previous_result - out.get(), "Difference image") + # raise e + previous_result = out.get() + + +def test_shift_interpolation_gpu(): rotation_angle = 0 # sympy.pi / 5 scale = 1 -- GitLab