Commit 58985c02 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add AssignmentCollection.atoms

parent 06afa33d
......@@ -3,7 +3,8 @@ from os.path import dirname, join
from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode
from pystencils.interpolation_astnodes import (
DiffInterpolatorAccess, InterpolationMode, TextureCachedField)
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
......@@ -70,13 +71,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
super(CudaSympyPrinter, self).__init__()
self.known_functions.update(CUDA_KNOWN_FUNCTIONS)
def _print_TextureAccess(self, node):
dtype = node.texture.field.dtype.numpy_dtype
def _print_InterpolatorAccess(self, node):
dtype = node.interpolator.field.dtype.numpy_dtype
if type(node) == DiffInterpolatorAccess:
if isinstance(node, DiffInterpolatorAccess):
# cubicTex3D_1st_derivative_x(texture tex, float3 coord)
template = f"cubicTex%iD_1st_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
elif node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = f"cubicTex%iD_1st_derivative_{'zyx'[node.diff_coordinate_idx]}(%s, %s)"
elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple(%s, %s)"
else:
if dtype.itemsize > 4:
......@@ -87,8 +88,8 @@ class CudaSympyPrinter(CustomSympyPrinter):
template = "tex%iD(%s, %s)"
code = template % (
node.texture.field.spatial_dimensions,
str(node.texture),
node.interpolator.field.spatial_dimensions,
str(node.interpolator),
# + 0.5 comes from Nvidia's staggered indexing
', '.join(self._print(o + 0.5) for o in reversed(node.offsets))
)
......
......@@ -156,6 +156,9 @@ class AssignmentCollection:
"""See :func:`count_operations` """
return count_operations(self.all_assignments, only_type=None)
def atoms(self, *args):
return set().union(*[a.atoms(*args) for a in self.all_assignments])
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment