From b20974fe1395c295b3b13774cebe33139a655c6d Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 23 Jan 2020 14:10:12 +0100
Subject: [PATCH] interpolation: Stuff is working let's commit quickly

---
 pystencils/gpucuda/texture_utils.py  |  3 +++
 pystencils/interpolation_astnodes.py |  5 +----
 pystencils/transformations.py        | 12 ++++++------
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/pystencils/gpucuda/texture_utils.py b/pystencils/gpucuda/texture_utils.py
index d29d86292..ff3db430f 100644
--- a/pystencils/gpucuda/texture_utils.py
+++ b/pystencils/gpucuda/texture_utils.py
@@ -15,6 +15,7 @@ import numpy as np
 try:
     import pycuda.driver as cuda
     from pycuda import gpuarray
+    import pycuda
 except Exception:
     pass
 
@@ -35,6 +36,8 @@ def ndarray_to_tex(tex_ref,
                    use_normalized_coordinates=False,
                    read_as_integer=False):
 
+    if isinstance(address_mode, str):
+        address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
     if address_mode is None:
         address_mode = cuda.address_mode.BORDER
     if filter_mode is None:
diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py
index 9db3cf77a..76bd340b7 100644
--- a/pystencils/interpolation_astnodes.py
+++ b/pystencils/interpolation_astnodes.py
@@ -116,7 +116,6 @@ class Interpolator(object):
     def _hashable_contents(self):
         return (str(self.address_mode),
                 str(type(self)),
-                self.address_mode,
                 self.hash_str,
                 self.use_normalized_coordinates)
 
@@ -416,11 +415,9 @@ class TextureCachedField(Interpolator):
                  read_as_integer=False
                  ):
         super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates)
-        if isinstance(address_mode, str):
-            address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
 
         if address_mode is None:
-            address_mode = pycuda.driver.address_mode.BORDER
+            address_mode = 'border'
         if filter_mode is None:
             filter_mode = pycuda.driver.filter_mode.LINEAR
 
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index e48d4f386..448991a84 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -1334,19 +1334,19 @@ def implement_interpolations(ast_node: ast.Node,
     if implement_by_texture_accesses:
 
         for i in interpolation_accesses:
-            old_i = i
+            from pystencils.interpolation_astnodes import _InterpolationSymbol
+
             try:
                 import pycuda.driver as cuda
                 texture = TextureCachedField.from_interpolator(i.interpolator)
-                i.symbol.interpolator = texture
                 if can_use_hw_interpolation(i):
-                    i.symbol.interpolator.filter_mode = cuda.filter_mode.LINEAR
+                    texture.filter_mode = cuda.filter_mode.LINEAR
                 else:
-                    i.symbol.interpolator.filter_mode = cuda.filter_mode.POINT
-                    i.symbol.interpolator.read_as_integer = True
+                    texture.filter_mode = cuda.filter_mode.POINT
+                    texture.read_as_integer = True
             except Exception as e:
                 raise e
-            ast_node.subs({old_i: i})
+            i.symbol = _InterpolationSymbol(str(texture), i.symbol.field, texture)
 
     # from pystencils.math_optimizations import ReplaceOptim, optimize_ast
 
-- 
GitLab