diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index e14222b1202349bc237ce086f842265e3012a691..ed9295d008dc26bd880a2bd9cf7d2d220a4ddcd1 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -143,7 +143,7 @@ class NearestNeightborInterpolator(Interpolator): class InterpolatorAccess(TypedSymbol): def __new__(cls, field, *offsets, **kwargs): - obj = TextureAccess.__xnew_cached_(cls, field, *offsets, **kwargs) + obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs) return obj def __new_stage2__(self, symbol, *offsets): @@ -201,6 +201,17 @@ class InterpolatorAccess(TypedSymbol): def interpolation_mode(self): return self.interpolator.interpolation_mode + @property + def _diff_interpolation_vec(self): + return sp.Matrix([DiffInterpolatorAccess(self.symbol, i, *self.offsets) + for i in range(len(self.offsets))]) + + def diff(self, *symbols, **kwargs): + rtn = self._diff_interpolation_vec.T * sp.Matrix(self.offsets).diff(*symbols, **kwargs) + if rtn.shape == (1, 1): + rtn = rtn[0, 0] + return rtn + def implementation_with_stencils(self): field = self.field @@ -255,7 +266,7 @@ class InterpolatorAccess(TypedSymbol): for (dim, i) in enumerate(index)] index = [cast_func(sp.Piecewise((i, i > 0), (sp.Abs(cast_func(field.shape[dim] - 1 + i, default_int_type)), - True)), default_int_type) + True)), default_int_type) for (dim, i) in enumerate(index)] sum[channel_idx] += weight * \ absolute_access(index, channel_idx if field.index_dimensions else ()) @@ -290,6 +301,46 @@ class InterpolatorAccess(TypedSymbol): def __getnewargs__(self): return tuple(self.symbol, *self.offsets) + +class DiffInterpolatorAccess(InterpolatorAccess): + def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs): + obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs) + return obj + + def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets): + assert offsets is not None + obj = super().__xnew__(self, symbol, *offsets) + obj.diff_coordinate_idx = diff_coordinate_idx + return obj + + def __hash__(self): + return hash((self.symbol, self.field, self.diff_coordinate_idx, tuple(self.offsets), self.interpolator)) + + def __str__(self): + return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx, + ','.join(str(o) for o in self.offsets)) + + @property + def args(self): + return [self.symbol, self.diff_coordinate_idx, *self.offsets] + + @property + def symbols_defined(self) -> Set[sp.Symbol]: + return {self} + + @property + def interpolation_mode(self): + return self.interpolator.interpolation_mode + + # noinspection SpellCheckingInspection + __xnew__ = staticmethod(__new_stage2__) + # noinspection SpellCheckingInspection + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + def __getnewargs__(self): + return tuple(self.symbol, self.diff_coordinate_idx, *self.offsets) + + ########################################################################################## # GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols # ########################################################################################## diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py index 4ead7a0fc2b489eb874abc7e9fde8cb043d5d770..d52cddc8d322e5cb13154a807418b221d1f2161d 100644 --- a/pystencils_tests/test_interpolation.py +++ b/pystencils_tests/test_interpolation.py @@ -234,5 +234,13 @@ def test_field_interpolated(address_mode, target): out = np.zeros_like(lenna) kernel(x=lenna, y=out) pyconrad.imshow(out, "out " + address_mode) - kernel(x=lenna, y=out) - pyconrad.imshow(out, "out " + address_mode) + + +def test_spatial_derivative(): + x, y = pystencils.fields('x, y: float32[2d]') + tx, ty = pystencils.fields('t_x, t_y: float32[2d]') + + diff = sympy.diff(x.interpolated_access((tx.center, ty.center)), tx.center) + print("diff: " + str(diff)) + diff = sympy.diff(x.interpolated_access((tx.center, 2 * ty.center)), sympy.Matrix((tx.center, ty.center))) + print("diff: " + str(diff))