From cd772ba1baa75836dfc0688ad4e4cccc35d88e3e Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 23 Jan 2020 12:57:01 +0100
Subject: [PATCH] Improve implementation of Interpolator, TextureCachedField
 and add __eq__ for TextureDeclaration

---
 pystencils/interpolation_astnodes.py | 75 +++++++++++++++-------------
 1 file changed, 41 insertions(+), 34 deletions(-)

diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py
index 0d27475c9..95845ac46 100644
--- a/pystencils/interpolation_astnodes.py
+++ b/pystencils/interpolation_astnodes.py
@@ -35,6 +35,27 @@ class InterpolationMode(str, Enum):
     CUBIC_SPLINE = "cubic_spline"
 
 
+class _InterpolationSymbol(TypedSymbol):
+
+    def __new__(cls, name, field, interpolator):
+        obj = cls.__xnew_cached_(cls, name, field, interpolator)
+        return obj
+
+    def __new_stage2__(cls, name, field, interpolator):
+        obj = super().__xnew__(cls, name, 'dummy_symbol_carrying_field' + field.name)
+        obj.field = field
+        obj.interpolator = interpolator
+        return obj
+
+    def __getnewargs__(self):
+        return self.name, self.field, self.interpolator
+
+    # noinspection SpellCheckingInspection
+    __xnew__ = staticmethod(__new_stage2__)
+    # noinspection SpellCheckingInspection
+    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
+
+
 class Interpolator(object):
     """
     Implements non-integer accesses on fields using linear interpolation.
@@ -81,14 +102,11 @@ class Interpolator(object):
         self.field.field_type = pystencils.field.FieldType.CUSTOM
         self.address_mode = address_mode
         self.use_normalized_coordinates = use_normalized_coordinates
-        hash_str = hashlib.md5(
+        self.interpolation_mode = interpolation_mode
+        self.hash_str = hashlib.md5(
             f'{self.field}_{address_mode}_{self.field.dtype}_{interpolation_mode}'.encode()).hexdigest()
-        self.symbol = TypedSymbol('dummy_symbol_carrying_field' + self.field.name + hash_str,
-                                  'dummy_symbol_carrying_field' + self.field.name + self.field.name + hash_str)
-        self.symbol.field = self.field
-        self.symbol.interpolator = self
+        self.symbol = _InterpolationSymbol(str(self), parent_field, self)
         self.allow_textures = allow_textures
-        self.interpolation_mode = interpolation_mode
 
     @property
     def ndim(self):
@@ -98,9 +116,8 @@ class Interpolator(object):
     def _hashable_contents(self):
         return (str(self.address_mode),
                 str(type(self)),
-                str(self.symbol.name),
-                hash(self.symbol.field),
                 self.address_mode,
+                self.hash_str,
                 self.use_normalized_coordinates)
 
     def at(self, offset):
@@ -116,7 +133,10 @@ class Interpolator(object):
         return self.__str__()
 
     def __hash__(self):
-        return hash(self._hashable_contents())
+        return hash(self._hashable_contents)
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
 
     @property
     def reproducible_hash(self):
@@ -147,6 +167,14 @@ class NearestNeightborInterpolator(Interpolator):
                          use_normalized_coordinates)
 
 
+# def forbid_double(expr):
+    # dtype = pystencils.data_types.get_type_of_expression(expr)
+    # if dtype == create_type('double')
+        # pystencils.data_types.get_type_of_expression(
+    # else:
+        # return expr
+
+
 class InterpolatorAccess(TypedSymbol):
     def __new__(cls, field, *offsets, **kwargs):
         obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
@@ -395,6 +423,7 @@ class TextureCachedField(Interpolator):
                  use_normalized_coordinates=False,
                  read_as_integer=False
                  ):
+        super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates)
         if isinstance(address_mode, str):
             address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
 
@@ -403,21 +432,9 @@ class TextureCachedField(Interpolator):
         if filter_mode is None:
             filter_mode = pycuda.driver.filter_mode.LINEAR
 
-        # self, field_name, field_type, dtype, layout, shape, strides
-        self.field = parent_field
-        self.address_mode = address_mode
-        self.filter_mode = filter_mode
         self.read_as_integer = read_as_integer
-        self.use_normalized_coordinates = use_normalized_coordinates
-        self.interpolation_mode = interpolation_mode
-        self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype)
-        self.symbol.interpolator = self
-        self.symbol.field = self.field
         self.required_global_declarations = [TextureDeclaration(self)]
 
-        # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
-        # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
-
     @property
     def ndim(self):
         return self.field.ndim
@@ -436,19 +453,6 @@ class TextureCachedField(Interpolator):
     def __repr__(self):
         return self.__str__()
 
-    @property
-    def _hashable_contents(self):
-        return (str(self.address_mode),
-                str(type(self)),
-                str(self.symbol.name),
-                hash(self.symbol.field),
-                self.address_mode,
-                self.use_normalized_coordinates,
-                'T')
-
-    def __hash__(self):
-        return hash(self._hashable_contents)
-
     @property
     def reproducible_hash(self):
         return _hash(str(self._hashable_contents).encode()).hexdigest()
@@ -499,6 +503,9 @@ class TextureDeclaration(Node):
         from pystencils.backends.cuda_backend import CudaBackend
         return CudaBackend()(self)
 
+    def __repr__(self):
+        return str(self)
+
 
 class TextureObject(TextureDeclaration):
     """
-- 
GitLab