diff --git a/pystencils/gpucuda/texture_utils.py b/pystencils/gpucuda/texture_utils.py index d29d862929530b27f840da445467a2bae44b1bfb..ff3db430fc841f4121b9602b57a1c1aa08ef86f9 100644 --- a/pystencils/gpucuda/texture_utils.py +++ b/pystencils/gpucuda/texture_utils.py @@ -15,6 +15,7 @@ import numpy as np try: import pycuda.driver as cuda from pycuda import gpuarray + import pycuda except Exception: pass @@ -35,6 +36,8 @@ def ndarray_to_tex(tex_ref, use_normalized_coordinates=False, read_as_integer=False): + if isinstance(address_mode, str): + address_mode = getattr(pycuda.driver.address_mode, address_mode.upper()) if address_mode is None: address_mode = cuda.address_mode.BORDER if filter_mode is None: diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index 9db3cf77a08811b45be073ff4872c148a71f906d..76bd340b7e2c9bd66d9c82f77bc0399208017e52 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -116,7 +116,6 @@ class Interpolator(object): def _hashable_contents(self): return (str(self.address_mode), str(type(self)), - self.address_mode, self.hash_str, self.use_normalized_coordinates) @@ -416,11 +415,9 @@ class TextureCachedField(Interpolator): read_as_integer=False ): super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates) - if isinstance(address_mode, str): - address_mode = getattr(pycuda.driver.address_mode, address_mode.upper()) if address_mode is None: - address_mode = pycuda.driver.address_mode.BORDER + address_mode = 'border' if filter_mode is None: filter_mode = pycuda.driver.filter_mode.LINEAR diff --git a/pystencils/transformations.py b/pystencils/transformations.py index e48d4f386d2614bc86272d0f63c1cb5fcb3a6f41..448991a842f2bd9cee6a47fc559fe16bbf8996ed 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1334,19 +1334,19 @@ def implement_interpolations(ast_node: ast.Node, if implement_by_texture_accesses: for i in interpolation_accesses: - old_i = i + from pystencils.interpolation_astnodes import _InterpolationSymbol + try: import pycuda.driver as cuda texture = TextureCachedField.from_interpolator(i.interpolator) - i.symbol.interpolator = texture if can_use_hw_interpolation(i): - i.symbol.interpolator.filter_mode = cuda.filter_mode.LINEAR + texture.filter_mode = cuda.filter_mode.LINEAR else: - i.symbol.interpolator.filter_mode = cuda.filter_mode.POINT - i.symbol.interpolator.read_as_integer = True + texture.filter_mode = cuda.filter_mode.POINT + texture.read_as_integer = True except Exception as e: raise e - ast_node.subs({old_i: i}) + i.symbol = _InterpolationSymbol(str(texture), i.symbol.field, texture) # from pystencils.math_optimizations import ReplaceOptim, optimize_ast