diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index 51fa20f40f1a357960ed72b09526c7103c72c841..e7d4d2333eff562e2ab83ed4a8d0567985ed2488 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -12,12 +12,11 @@ import itertools from enum import Enum from typing import Set -import sympy as sp -from sympy.core.cache import cacheit - import pystencils +import sympy as sp from pystencils.astnodes import Node from pystencils.data_types import TypedSymbol, cast_func, create_type +from sympy.core.cache import cacheit try: import pycuda.driver @@ -304,18 +303,19 @@ class TextureCachedField: self.filter_mode = filter_mode self.read_as_integer = read_as_integer self.use_normalized_coordinates = use_normalized_coordinates + self.interpolation_mode = interpolation_mode self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype) self.symbol.interpolator = self self.symbol.field = self.field self.required_global_declarations = [TextureDeclaration(self)] - self.interpolation_mode = interpolation_mode # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!" # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less" @classmethod def from_interpolator(cls, interpolator: LinearInterpolator): - if hasattr(interpolator, 'allow_textures') and not interpolator.allow_textures: + if (isinstance(interpolator, cls) + or (hasattr(interpolator, 'allow_textures') and not interpolator.allow_textures)): return interpolator obj = cls(interpolator.field, interpolator.address_mode, interpolation_mode=interpolator.interpolation_mode) return obj @@ -327,11 +327,19 @@ class TextureCachedField: return TextureAccess(self.symbol, *offset) def __str__(self): - return '%s_texture_%x' % (self.field.name, abs(hash(self.field) + hash(str(self.address_mode)))) + return '%s_texture_%x' % (self.field.name, abs(hash(self))) def __repr__(self): return self.__str__() + def __hash__(self): + return hash((str(type(self)), + self.address_mode, + self.filter_mode, + self.read_as_integer, + self.interpolation_mode, + self.use_normalized_coordinates)) + class TextureAccess(InterpolatorAccess): def __new__(cls, texture_symbol, offsets, *args, **kwargs):