From dd6b920e734f6591a2568fb397591a6e9f3eba96 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 10 Jan 2020 17:56:44 +0100
Subject: [PATCH] Attempt to implement printing of DiffInterpolatorAccess

---
 pystencils/backends/cuda_backend.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
index 8bc584689..e3a6653a3 100644
--- a/pystencils/backends/cuda_backend.py
+++ b/pystencils/backends/cuda_backend.py
@@ -3,7 +3,7 @@ from os.path import dirname, join
 from pystencils.astnodes import Node
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
-from pystencils.interpolation_astnodes import InterpolationMode
+from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode
 
 with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
     lines = f.readlines()
@@ -73,7 +73,10 @@ class CudaSympyPrinter(CustomSympyPrinter):
     def _print_TextureAccess(self, node):
         dtype = node.texture.field.dtype.numpy_dtype
 
-        if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
+        if type(node) == DiffInterpolatorAccess:
+            # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
+            template = f"cubicTex%iD_1st_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
+        elif node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
             template = "cubicTex%iDSimple(%s, %s)"
         else:
             if dtype.itemsize > 4:
-- 
GitLab