diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index e3a6653a3fe5adca9656205ec217feb650b2b1ed..0766b941591e32f29db9149d52080c171e8e7eac 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -3,7 +3,8 @@ from os.path import dirname, join from pystencils.astnodes import Node from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt -from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode +from pystencils.interpolation_astnodes import ( + DiffInterpolatorAccess, InterpolationMode, TextureCachedField) with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: lines = f.readlines() @@ -70,13 +71,13 @@ class CudaSympyPrinter(CustomSympyPrinter): super(CudaSympyPrinter, self).__init__() self.known_functions.update(CUDA_KNOWN_FUNCTIONS) - def _print_TextureAccess(self, node): - dtype = node.texture.field.dtype.numpy_dtype + def _print_InterpolatorAccess(self, node): + dtype = node.interpolator.field.dtype.numpy_dtype - if type(node) == DiffInterpolatorAccess: + if isinstance(node, DiffInterpolatorAccess): # cubicTex3D_1st_derivative_x(texture tex, float3 coord) - template = f"cubicTex%iD_1st_{'xyz'[node.diff_coordinate_idx]}(%s, %s)" - elif node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: + template = f"cubicTex%iD_1st_derivative_{'zyx'[node.diff_coordinate_idx]}(%s, %s)" + elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple(%s, %s)" else: if dtype.itemsize > 4: @@ -87,8 +88,8 @@ class CudaSympyPrinter(CustomSympyPrinter): template = "tex%iD(%s, %s)" code = template % ( - node.texture.field.spatial_dimensions, - str(node.texture), + node.interpolator.field.spatial_dimensions, + str(node.interpolator), # + 0.5 comes from Nvidia's staggered indexing ', '.join(self._print(o + 0.5) for o in reversed(node.offsets)) ) diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 22968eb72361b15dd24a0ebc72fd284ddeddd2b7..706008c859bd0cdc9f2f3b892557242a5aa01184 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -156,6 +156,9 @@ class AssignmentCollection: """See :func:`count_operations` """ return count_operations(self.all_assignments, only_type=None) + def atoms(self, *args): + return set().union(*[a.atoms(*args) for a in self.all_assignments]) + def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: """Returns all symbols that depend on one of the passed symbols.