diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index 8bc58468901a006bc7ef8f278e3e7a544234de52..e3a6653a3fe5adca9656205ec217feb650b2b1ed 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -3,7 +3,7 @@ from os.path import dirname, join from pystencils.astnodes import Node from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt -from pystencils.interpolation_astnodes import InterpolationMode +from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: lines = f.readlines() @@ -73,7 +73,10 @@ class CudaSympyPrinter(CustomSympyPrinter): def _print_TextureAccess(self, node): dtype = node.texture.field.dtype.numpy_dtype - if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: + if type(node) == DiffInterpolatorAccess: + # cubicTex3D_1st_derivative_x(texture tex, float3 coord) + template = f"cubicTex%iD_1st_{'xyz'[node.diff_coordinate_idx]}(%s, %s)" + elif node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple(%s, %s)" else: if dtype.itemsize > 4: