Commit cd772ba1 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Improve implementation of Interpolator, TextureCachedField and add __eq__ for TextureDeclaration

parent 24f52f21
......@@ -35,6 +35,27 @@ class InterpolationMode(str, Enum):
CUBIC_SPLINE = "cubic_spline"
class _InterpolationSymbol(TypedSymbol):
def __new__(cls, name, field, interpolator):
obj = cls.__xnew_cached_(cls, name, field, interpolator)
return obj
def __new_stage2__(cls, name, field, interpolator):
obj = super().__xnew__(cls, name, 'dummy_symbol_carrying_field' + field.name)
obj.field = field
obj.interpolator = interpolator
return obj
def __getnewargs__(self):
return self.name, self.field, self.interpolator
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class Interpolator(object):
"""
Implements non-integer accesses on fields using linear interpolation.
......@@ -81,14 +102,11 @@ class Interpolator(object):
self.field.field_type = pystencils.field.FieldType.CUSTOM
self.address_mode = address_mode
self.use_normalized_coordinates = use_normalized_coordinates
hash_str = hashlib.md5(
self.interpolation_mode = interpolation_mode
self.hash_str = hashlib.md5(
f'{self.field}_{address_mode}_{self.field.dtype}_{interpolation_mode}'.encode()).hexdigest()
self.symbol = TypedSymbol('dummy_symbol_carrying_field' + self.field.name + hash_str,
'dummy_symbol_carrying_field' + self.field.name + self.field.name + hash_str)
self.symbol.field = self.field
self.symbol.interpolator = self
self.symbol = _InterpolationSymbol(str(self), parent_field, self)
self.allow_textures = allow_textures
self.interpolation_mode = interpolation_mode
@property
def ndim(self):
......@@ -98,9 +116,8 @@ class Interpolator(object):
def _hashable_contents(self):
return (str(self.address_mode),
str(type(self)),
str(self.symbol.name),
hash(self.symbol.field),
self.address_mode,
self.hash_str,
self.use_normalized_coordinates)
def at(self, offset):
......@@ -116,7 +133,10 @@ class Interpolator(object):
return self.__str__()
def __hash__(self):
return hash(self._hashable_contents())
return hash(self._hashable_contents)
def __eq__(self, other):
return hash(self) == hash(other)
@property
def reproducible_hash(self):
......@@ -147,6 +167,14 @@ class NearestNeightborInterpolator(Interpolator):
use_normalized_coordinates)
# def forbid_double(expr):
# dtype = pystencils.data_types.get_type_of_expression(expr)
# if dtype == create_type('double')
# pystencils.data_types.get_type_of_expression(
# else:
# return expr
class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, *offsets, **kwargs):
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
......@@ -395,6 +423,7 @@ class TextureCachedField(Interpolator):
use_normalized_coordinates=False,
read_as_integer=False
):
super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates)
if isinstance(address_mode, str):
address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
......@@ -403,21 +432,9 @@ class TextureCachedField(Interpolator):
if filter_mode is None:
filter_mode = pycuda.driver.filter_mode.LINEAR
# self, field_name, field_type, dtype, layout, shape, strides
self.field = parent_field
self.address_mode = address_mode
self.filter_mode = filter_mode
self.read_as_integer = read_as_integer
self.use_normalized_coordinates = use_normalized_coordinates
self.interpolation_mode = interpolation_mode
self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype)
self.symbol.interpolator = self
self.symbol.field = self.field
self.required_global_declarations = [TextureDeclaration(self)]
# assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
# assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
@property
def ndim(self):
return self.field.ndim
......@@ -436,19 +453,6 @@ class TextureCachedField(Interpolator):
def __repr__(self):
return self.__str__()
@property
def _hashable_contents(self):
return (str(self.address_mode),
str(type(self)),
str(self.symbol.name),
hash(self.symbol.field),
self.address_mode,
self.use_normalized_coordinates,
'T')
def __hash__(self):
return hash(self._hashable_contents)
@property
def reproducible_hash(self):
return _hash(str(self._hashable_contents).encode()).hexdigest()
......@@ -499,6 +503,9 @@ class TextureDeclaration(Node):
from pystencils.backends.cuda_backend import CudaBackend
return CudaBackend()(self)
def __repr__(self):
return str(self)
class TextureObject(TextureDeclaration):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment