From cd772ba1baa75836dfc0688ad4e4cccc35d88e3e Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 23 Jan 2020 12:57:01 +0100 Subject: [PATCH] Improve implementation of Interpolator, TextureCachedField and add __eq__ for TextureDeclaration --- pystencils/interpolation_astnodes.py | 75 +++++++++++++++------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index 0d27475c9..95845ac46 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -35,6 +35,27 @@ class InterpolationMode(str, Enum): CUBIC_SPLINE = "cubic_spline" +class _InterpolationSymbol(TypedSymbol): + + def __new__(cls, name, field, interpolator): + obj = cls.__xnew_cached_(cls, name, field, interpolator) + return obj + + def __new_stage2__(cls, name, field, interpolator): + obj = super().__xnew__(cls, name, 'dummy_symbol_carrying_field' + field.name) + obj.field = field + obj.interpolator = interpolator + return obj + + def __getnewargs__(self): + return self.name, self.field, self.interpolator + + # noinspection SpellCheckingInspection + __xnew__ = staticmethod(__new_stage2__) + # noinspection SpellCheckingInspection + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + class Interpolator(object): """ Implements non-integer accesses on fields using linear interpolation. @@ -81,14 +102,11 @@ class Interpolator(object): self.field.field_type = pystencils.field.FieldType.CUSTOM self.address_mode = address_mode self.use_normalized_coordinates = use_normalized_coordinates - hash_str = hashlib.md5( + self.interpolation_mode = interpolation_mode + self.hash_str = hashlib.md5( f'{self.field}_{address_mode}_{self.field.dtype}_{interpolation_mode}'.encode()).hexdigest() - self.symbol = TypedSymbol('dummy_symbol_carrying_field' + self.field.name + hash_str, - 'dummy_symbol_carrying_field' + self.field.name + self.field.name + hash_str) - self.symbol.field = self.field - self.symbol.interpolator = self + self.symbol = _InterpolationSymbol(str(self), parent_field, self) self.allow_textures = allow_textures - self.interpolation_mode = interpolation_mode @property def ndim(self): @@ -98,9 +116,8 @@ class Interpolator(object): def _hashable_contents(self): return (str(self.address_mode), str(type(self)), - str(self.symbol.name), - hash(self.symbol.field), self.address_mode, + self.hash_str, self.use_normalized_coordinates) def at(self, offset): @@ -116,7 +133,10 @@ class Interpolator(object): return self.__str__() def __hash__(self): - return hash(self._hashable_contents()) + return hash(self._hashable_contents) + + def __eq__(self, other): + return hash(self) == hash(other) @property def reproducible_hash(self): @@ -147,6 +167,14 @@ class NearestNeightborInterpolator(Interpolator): use_normalized_coordinates) +# def forbid_double(expr): + # dtype = pystencils.data_types.get_type_of_expression(expr) + # if dtype == create_type('double') + # pystencils.data_types.get_type_of_expression( + # else: + # return expr + + class InterpolatorAccess(TypedSymbol): def __new__(cls, field, *offsets, **kwargs): obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs) @@ -395,6 +423,7 @@ class TextureCachedField(Interpolator): use_normalized_coordinates=False, read_as_integer=False ): + super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates) if isinstance(address_mode, str): address_mode = getattr(pycuda.driver.address_mode, address_mode.upper()) @@ -403,21 +432,9 @@ class TextureCachedField(Interpolator): if filter_mode is None: filter_mode = pycuda.driver.filter_mode.LINEAR - # self, field_name, field_type, dtype, layout, shape, strides - self.field = parent_field - self.address_mode = address_mode - self.filter_mode = filter_mode self.read_as_integer = read_as_integer - self.use_normalized_coordinates = use_normalized_coordinates - self.interpolation_mode = interpolation_mode - self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype) - self.symbol.interpolator = self - self.symbol.field = self.field self.required_global_declarations = [TextureDeclaration(self)] - # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!" - # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less" - @property def ndim(self): return self.field.ndim @@ -436,19 +453,6 @@ class TextureCachedField(Interpolator): def __repr__(self): return self.__str__() - @property - def _hashable_contents(self): - return (str(self.address_mode), - str(type(self)), - str(self.symbol.name), - hash(self.symbol.field), - self.address_mode, - self.use_normalized_coordinates, - 'T') - - def __hash__(self): - return hash(self._hashable_contents) - @property def reproducible_hash(self): return _hash(str(self._hashable_contents).encode()).hexdigest() @@ -499,6 +503,9 @@ class TextureDeclaration(Node): from pystencils.backends.cuda_backend import CudaBackend return CudaBackend()(self) + def __repr__(self): + return str(self) + class TextureObject(TextureDeclaration): """ -- GitLab