diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index 71bb325b2dc6365c4d9c24b34dbe32873ad8dd4d..b67a06e71c92e642d2e7dc68970c2dcf04391285 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -207,6 +207,8 @@ class InterpolatorAccess(TypedSymbol): for i in range(len(self.offsets))]) def diff(self, *symbols, **kwargs): + if symbols == (self,): + return 1 rtn = self._diff_interpolation_vec.T * sp.Matrix(self.offsets).diff(*symbols, **kwargs) if rtn.shape == (1, 1): rtn = rtn[0, 0]