From dd6b920e734f6591a2568fb397591a6e9f3eba96 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 10 Jan 2020 17:56:44 +0100 Subject: [PATCH] Attempt to implement printing of DiffInterpolatorAccess --- pystencils/backends/cuda_backend.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index 8bc584689..e3a6653a3 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -3,7 +3,7 @@ from os.path import dirname, join from pystencils.astnodes import Node from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt -from pystencils.interpolation_astnodes import InterpolationMode +from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: lines = f.readlines() @@ -73,7 +73,10 @@ class CudaSympyPrinter(CustomSympyPrinter): def _print_TextureAccess(self, node): dtype = node.texture.field.dtype.numpy_dtype - if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: + if type(node) == DiffInterpolatorAccess: + # cubicTex3D_1st_derivative_x(texture tex, float3 coord) + template = f"cubicTex%iD_1st_{'xyz'[node.diff_coordinate_idx]}(%s, %s)" + elif node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple(%s, %s)" else: if dtype.itemsize > 4: -- GitLab