diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index 9797bc7dad69533a0185d6efadcde309e27f4016..d590d65b4082e72745658f0b06eb152d64872944 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -75,7 +75,7 @@ class CudaSympyPrinter(CustomSympyPrinter): if type(node) == DiffInterpolatorAccess: # cubicTex3D_1st_derivative_x(texture tex, float3 coord) - template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)" + template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)" # noqa elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple(%s, %s)" else: diff --git a/pystencils/fd/finitedifferences.py b/pystencils/fd/finitedifferences.py index d5bce66e96d8aa5d296df68906c8871d40c08bba..5b6b15f95bebec40939d54a8ff2ad6c58c169f58 100644 --- a/pystencils/fd/finitedifferences.py +++ b/pystencils/fd/finitedifferences.py @@ -109,7 +109,9 @@ class Discretization2ndOrder: return self._discretize_advection(e) elif isinstance(e, Diff): arg, *indices = diff_args(e) - if not isinstance(arg, Field.Access): + from pystencils.interpolation_astnodes import InterpolatorAccess + + if not isinstance(arg, (Field.Access, InterpolatorAccess)): raise ValueError("Only derivatives with field or field accesses as arguments can be discretized") return self.spatial_stencil(indices, self.dx, arg) else: diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index 28d45c2dc4e0def15d4a63842ebec73cc3e56359..3ecc2a70a09d46e859d1aa1934545d21caaa8664 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -170,6 +170,14 @@ class InterpolatorAccess(TypedSymbol): def __repr__(self): return self.__str__() + @property + def ndim(self): + return len(self.offsets) + + @property + def is_texture(self): + return isinstance(self.interpolator, TextureCachedField) + def atoms(self, *types): if self.offsets: offsets = set(o for o in self.offsets if isinstance(o, types)) @@ -182,6 +190,11 @@ class InterpolatorAccess(TypedSymbol): else: return set() + def neighbor(self, coord_id, offset): + offset_list = list(self.offsets) + offset_list[coord_id] += offset + return self.interpolator.at(tuple(offset_list)) + @property def free_symbols(self): symbols = set() @@ -318,6 +331,9 @@ class InterpolatorAccess(TypedSymbol): class DiffInterpolatorAccess(InterpolatorAccess): def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs): + if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR: + from pystencils.fd import Diff, Discretization2ndOrder + return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx)) obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs) return obj @@ -363,7 +379,7 @@ class DiffInterpolatorAccess(InterpolatorAccess): ########################################################################################## -class TextureCachedField: +class TextureCachedField(Interpolator): def __init__(self, parent_field, address_mode=None,