From 479cc5ad73fd5094898dbec0e79ee073e942af2a Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 14 Jan 2020 18:46:35 +0100
Subject: [PATCH] Refactor interpolation

Detect texture usage via InterpolatorAccess and TextureCachedField instead of
the TextureAccess subclass, add ndim helpers plus the include paths and headers
needed for cubic spline texture interpolation, convert interpolator accesses to
texture-backed accesses in place in implement_interpolations, add optimize_ast,
and parametrize the GPU rotation interpolation test over dtype.

---
 pystencils/backends/cbackend.py        |  7 ++--
 pystencils/backends/cuda_backend.py    |  7 ++--
 pystencils/gpucuda/cudajit.py          | 12 +++++--
 pystencils/interpolation_astnodes.py   | 34 ++++++++++++++----
 pystencils/math_optimizations.py       | 10 ++++++
 pystencils/transformations.py          | 48 +++++++++++++-------------
 pystencils_tests/test_interpolation.py | 40 ++++++++++-----------
 7 files changed, 98 insertions(+), 60 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 6f203837d..3f016810a 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -10,11 +10,12 @@ from sympy.printing.ccode import C89CodePrinter
 from pystencils.astnodes import KernelFunction, Node
 from pystencils.cpu.vectorization import vec_all, vec_any
 from pystencils.data_types import (
-    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
-    vector_memory_access)
+    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
+    reinterpret_cast_func, vector_memory_access)
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 from pystencils.integer_functions import (
-    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
+    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
+    int_div, int_power_of_2, modulo_ceil)
 
 try:
     from sympy.printing.ccode import C99CodePrinter as CCodePrinter
diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
index 0766b9415..9797bc7da 100644
--- a/pystencils/backends/cuda_backend.py
+++ b/pystencils/backends/cuda_backend.py
@@ -3,8 +3,7 @@ from os.path import dirname, join
 from pystencils.astnodes import Node
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
-from pystencils.interpolation_astnodes import (
-    DiffInterpolatorAccess, InterpolationMode, TextureCachedField)
+from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode
 
 with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
     lines = f.readlines()
@@ -74,9 +73,9 @@ class CudaSympyPrinter(CustomSympyPrinter):
     def _print_InterpolatorAccess(self, node):
         dtype = node.interpolator.field.dtype.numpy_dtype
 
-        if isinstance(node, DiffInterpolatorAccess):
+        if type(node) == DiffInterpolatorAccess:
             # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
-            template = f"cubicTex%iD_1st_derivative_{'zyx'[node.diff_coordinate_idx]}(%s, %s)"
+            template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
         elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
             template = "cubicTex%iDSimple(%s, %s)"
         else:
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index eb2492f61..a77767dce 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -5,13 +5,20 @@ from pystencils.data_types import StructType
 from pystencils.field import FieldType
 from pystencils.gpucuda.texture_utils import ndarray_to_tex
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
-from pystencils.interpolation_astnodes import TextureAccess
+from pystencils.interpolation_astnodes import InterpolatorAccess, TextureCachedField
 from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.kernelparameters import FieldPointerSymbol
 
 USE_FAST_MATH = True
 
 
+def get_cubic_interpolation_include_paths():
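+    # Include directories of the CubicInterpolationCUDA sources shipped in pystencils/gpucuda,
+    # needed to compile kernels that use cubic spline texture interpolation.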
+    from os.path import join, dirname
+
+    return [join(dirname(__file__), "CubicInterpolationCUDA", "code"),
+            join(dirname(__file__), "CubicInterpolationCUDA", "code", "internal")]
+
+
 def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
     """
     Creates a kernel function from an abstract syntax tree which
@@ -39,7 +46,8 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
     code += "#define FUNC_PREFIX __global__\n"
     code += "#define RESTRICT __restrict__\n\n"
     code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))
-    textures = set(d.texture for d in kernel_function_node.atoms(TextureAccess))
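+    # Only interpolator accesses backed by a TextureCachedField correspond to CUDA textures.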
+    textures = set(d.interpolator for d in kernel_function_node.atoms(
+        InterpolatorAccess) if isinstance(d.interpolator, TextureCachedField))
 
     nvcc_options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
     if USE_FAST_MATH:
diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py
index b67a06e71..b14e07be1 100644
--- a/pystencils/interpolation_astnodes.py
+++ b/pystencils/interpolation_astnodes.py
@@ -89,6 +89,10 @@ class Interpolator(object):
         self.allow_textures = allow_textures
         self.interpolation_mode = interpolation_mode
 
+    @property
+    def ndim(self):
+        return self.field.ndim
+
     @property
     def _hashable_contents(self):
         return (str(self.address_mode),
@@ -146,10 +150,11 @@ class InterpolatorAccess(TypedSymbol):
         obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
         return obj
 
-    def __new_stage2__(self, symbol, *offsets):
+    def __new_stage2__(cls, symbol, *offsets):
         assert offsets is not None
-        obj = super().__xnew__(self, '%s_interpolator_%x' %
-                               (symbol.field.name, abs(hash(tuple(offsets)))), symbol.field.dtype)
+        obj = super().__xnew__(cls, '%s_interpolator_%s' %
+                               (symbol.field.name, _hash(str(tuple(offsets)).encode()).hexdigest()),
+                               symbol.field.dtype)
         obj.offsets = offsets
         obj.symbol = symbol
         obj.field = symbol.field
@@ -160,7 +165,7 @@ class InterpolatorAccess(TypedSymbol):
         return hash((self.symbol, self.field, tuple(self.offsets), self.interpolator))
 
     def __str__(self):
-        return '%s_interpolator(%s)' % (self.field.name, ','.join(str(o) for o in self.offsets))
+        return '%s_interpolator(%s)' % (self.field.name, ', '.join(str(o) for o in self.offsets))
 
     def __repr__(self):
         return self.__str__()
@@ -189,6 +194,13 @@ class InterpolatorAccess(TypedSymbol):
 
         return symbols
 
+    @property
+    def required_global_declarations(self):
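+        # Forward the interpolator's global declarations (e.g. its texture declaration) and
+        # register this access as one of the symbols defined by that declaration.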
+        required_global_declarations = self.symbol.interpolator.required_global_declarations
+        if required_global_declarations:
+            required_global_declarations[0]._symbols_defined.add(self)
+        return required_global_declarations
+
     @property
     def args(self):
         return [self.symbol, *self.offsets]
@@ -320,7 +332,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
 
     def __str__(self):
         return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx,
-                                               ','.join(str(o) for o in self.offsets))
+                                               ', '.join(str(o) for o in self.offsets))
 
     def __repr__(self):
         return str(self)
@@ -383,6 +395,10 @@ class TextureCachedField:
         # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
         # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
 
+    @property
+    def ndim(self):
+        return self.field.ndim
+
     @classmethod
     def from_interpolator(cls, interpolator: LinearInterpolator):
         if (isinstance(interpolator, cls)
@@ -432,7 +448,7 @@ class TextureAccess(InterpolatorAccess):
         return obj
 
     def __str__(self):
-        return '%s_texture(%s)' % (self.interpolator.field.name, ','.join(str(o) for o in self.offsets))
+        return '%s_texture(%s)' % (self.interpolator.field.name, ', '.join(str(o) for o in self.offsets))
 
     @property
     def texture(self):
@@ -480,7 +496,10 @@ class TextureDeclaration(Node):
 
     @property
     def headers(self):
-        return ['"pycuda-helpers.hpp"']
+        headers = ['"pycuda-helpers.hpp"']
+        if self.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
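+            # Cubic spline interpolation is provided by the CubicInterpolationCUDA sources,
+            # so the dimension-specific header is included as well.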
+            headers.append('"cubicTex%iD.cu"' % self.texture.ndim)
+        return headers
 
     def __str__(self):
         from pystencils.backends.cuda_backend import CudaBackend
@@ -515,3 +534,4 @@ def dtype_supports_textures(dtype):
         return dtype().itemsize <= 4
 
     return dtype.itemsize <= 4
+
diff --git a/pystencils/math_optimizations.py b/pystencils/math_optimizations.py
index ad0114782..b9420318f 100644
--- a/pystencils/math_optimizations.py
+++ b/pystencils/math_optimizations.py
@@ -44,3 +44,13 @@ def optimize_assignments(assignments, optimizations):
             a.optimize(optimizations)
 
     return assignments
+
+
+def optimize_ast(ast, optimizations):
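+    """Apply the given optimizations in place to all SympyAssignment nodes of the AST."""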
+    if HAS_REWRITING:
+        assignments_nodes = ast.atoms(SympyAssignment)
+        for a in assignments_nodes:
+            a.optimize(optimizations)
+
+    return ast
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index bb089ccb9..eb68d53ef 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -14,8 +14,8 @@ import pystencils.astnodes as ast
 import pystencils.integer_functions
 from pystencils.assignment import Assignment
 from pystencils.data_types import (
-    PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, get_base_type,
-    get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
+    PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
+    get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
 from pystencils.field import AbstractField, Field, FieldType
 from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
@@ -1314,7 +1314,7 @@ def implement_interpolations(ast_node: ast.Node,
                              implement_by_texture_accesses: bool = False,
                              vectorize: bool = False,
                              use_hardware_interpolation_for_f32=True):
-    from pystencils.interpolation_astnodes import InterpolatorAccess, TextureAccess, TextureCachedField
+    from pystencils.interpolation_astnodes import InterpolatorAccess, TextureCachedField
     # TODO: perform this function on assignments, when unify_shape_symbols allows differently sized fields
 
     assert not(implement_by_texture_accesses and vectorize), \
@@ -1324,34 +1324,27 @@ def implement_interpolations(ast_node: ast.Node,
     interpolation_accesses = ast_node.atoms(InterpolatorAccess)
 
     def can_use_hw_interpolation(i):
-        return use_hardware_interpolation_for_f32 and i.dtype == FLOAT32_T and isinstance(i, TextureAccess)
+        return (use_hardware_interpolation_for_f32
+                and i.dtype == FLOAT32_T
+                and isinstance(i.interpolator, TextureCachedField))
 
     if implement_by_texture_accesses:
 
-        interpolators = {a.symbol.interpolator for a in interpolation_accesses}
-        to_texture_map = {i: TextureCachedField.from_interpolator(i) for i in interpolators}
-
-        substitutions = {i: to_texture_map[i.symbol.interpolator].at(
-            [o for o in i.offsets]) for i in interpolation_accesses}
-
-        try:
-            import pycuda.driver as cuda
-            for texture in substitutions.values():
-                if can_use_hw_interpolation(texture):
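+        # Convert every interpolator access into a texture-backed access. Hardware (LINEAR)
+        # filtering is only used for float32 data; otherwise the texture falls back to point
+        # sampling with integer reads. If pycuda is unavailable the access is left unchanged.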
+        for i in interpolation_accesses:
+            old_i = i
+            try:
+                import pycuda.driver as cuda
+                texture = TextureCachedField.from_interpolator(i.interpolator)
+                i.interpolator = texture
+                i.symbol.interpolator = texture
+                if can_use_hw_interpolation(i):
                     texture.filter_mode = cuda.filter_mode.LINEAR
                 else:
                     texture.filter_mode = cuda.filter_mode.POINT
                     texture.read_as_integer = True
-        except Exception:
-            pass
-
-        if isinstance(ast_node, AssignmentCollection):
-            ast_node = ast_node.subs(substitutions)
-        else:
-            ast_node.subs(substitutions)
-
-        # Update after replacements
-        interpolation_accesses = ast_node.atoms(InterpolatorAccess)
+            except Exception:
+                pass
+            ast_node.subs({old_i: i})
 
     if vectorize:
         # TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
@@ -1364,4 +1357,11 @@ def implement_interpolations(ast_node: ast.Node,
         else:
             ast_node.subs(substitutions)
 
+    # from pystencils.math_optimizations import ReplaceOptim, optimize_ast
+
+    # RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
+    #                                lambda e: e.args[0]
+    #                                )
+    # optimize_ast(ast_node, [RemoveConjugate])
+
     return ast_node
diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py
index 04cf41071..63cda967b 100644
--- a/pystencils_tests/test_interpolation.py
+++ b/pystencils_tests/test_interpolation.py
@@ -110,8 +110,9 @@ def test_rotate_interpolation(address_mode):
     pyconrad.imshow(out, "out " + address_mode)
 
 
-@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror'])
-def test_rotate_interpolation_gpu(address_mode):
+@pytest.mark.parametrize('dtype', (np.int32, np.float32, np.float64))
+@pytest.mark.parametrize('address_mode', ('border', 'wrap', 'clamp', 'mirror'))
+def test_rotate_interpolation_gpu(dtype, address_mode):
 
     rotation_angle = sympy.pi / 5
     scale = 1
@@ -138,24 +139,23 @@ def test_rotate_interpolation_gpu(address_mode):
             pystencils.show_code(ast)
             kernel = ast.compile()
 
-            out = gpuarray.zeros_like(lenna_gpu)
-            kernel(x=lenna_gpu, y=out)
-            pyconrad.imshow(out,
-                            f"out {address_mode} texture:{use_textures} {type_map[dtype]}")
-            skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif",
-                              np.ascontiguousarray(out.get(), np.float32))
-            if previous_result is not None:
-                try:
-                    assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3)
-                except AssertionError:  # NOQA
-                    print("Max error: %f" % np.max(previous_result - out.get()))
-                    # pyconrad.imshow(previous_result - out.get(), "Difference image")
-                    # raise e
-            previous_result = out.get()
-
-
-@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'clamp', 'mirror'])
-def test_shift_interpolation_gpu(address_mode):
+        out = gpuarray.zeros_like(lenna_gpu)
+        kernel(x=lenna_gpu, y=out)
+        pyconrad.imshow(out,
+                        f"out {address_mode} texture:{use_textures} {type_map[dtype]}")
+        skimage.io.imsave(f"/tmp/out {address_mode} texture:{use_textures} {type_map[dtype]}.tif",
+                          np.ascontiguousarray(out.get(), np.float32))
+        if previous_result is not None:
+            try:
+                assert np.allclose(previous_result[4:-4, 4:-4], out.get()[4:-4, 4:-4], rtol=100, atol=1e-3)
+            except AssertionError as e:  # NOQA
+                print("Max error: %f" % np.max(previous_result - out.get()))
+                # pyconrad.imshow(previous_result - out.get(), "Difference image")
+                # raise e
+        previous_result = out.get()
+
+
+def test_shift_interpolation_gpu():
 
     rotation_angle = 0  # sympy.pi / 5
     scale = 1
-- 
GitLab