Commit ee9587df authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow differentation of InterpolatorAccess

parent 07c87bc7
......@@ -143,7 +143,7 @@ class NearestNeightborInterpolator(Interpolator):
class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, *offsets, **kwargs):
obj = TextureAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
return obj
def __new_stage2__(self, symbol, *offsets):
......@@ -201,6 +201,17 @@ class InterpolatorAccess(TypedSymbol):
def interpolation_mode(self):
return self.interpolator.interpolation_mode
@property
def _diff_interpolation_vec(self):
return sp.Matrix([DiffInterpolatorAccess(self.symbol, i, *self.offsets)
for i in range(len(self.offsets))])
def diff(self, *symbols, **kwargs):
rtn = self._diff_interpolation_vec.T * sp.Matrix(self.offsets).diff(*symbols, **kwargs)
if rtn.shape == (1, 1):
rtn = rtn[0, 0]
return rtn
def implementation_with_stencils(self):
field = self.field
......@@ -255,7 +266,7 @@ class InterpolatorAccess(TypedSymbol):
for (dim, i) in enumerate(index)]
index = [cast_func(sp.Piecewise((i, i > 0),
(sp.Abs(cast_func(field.shape[dim] - 1 + i, default_int_type)),
True)), default_int_type)
True)), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
......@@ -290,6 +301,46 @@ class InterpolatorAccess(TypedSymbol):
def __getnewargs__(self):
return tuple(self.symbol, *self.offsets)
class DiffInterpolatorAccess(InterpolatorAccess):
def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
return obj
def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
assert offsets is not None
obj = super().__xnew__(self, symbol, *offsets)
obj.diff_coordinate_idx = diff_coordinate_idx
return obj
def __hash__(self):
return hash((self.symbol, self.field, self.diff_coordinate_idx, tuple(self.offsets), self.interpolator))
def __str__(self):
return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx,
','.join(str(o) for o in self.offsets))
@property
def args(self):
return [self.symbol, self.diff_coordinate_idx, *self.offsets]
@property
def symbols_defined(self) -> Set[sp.Symbol]:
return {self}
@property
def interpolation_mode(self):
return self.interpolator.interpolation_mode
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return tuple(self.symbol, self.diff_coordinate_idx, *self.offsets)
##########################################################################################
# GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols #
##########################################################################################
......
......@@ -234,5 +234,13 @@ def test_field_interpolated(address_mode, target):
out = np.zeros_like(lenna)
kernel(x=lenna, y=out)
pyconrad.imshow(out, "out " + address_mode)
kernel(x=lenna, y=out)
pyconrad.imshow(out, "out " + address_mode)
def test_spatial_derivative():
x, y = pystencils.fields('x, y: float32[2d]')
tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
diff = sympy.diff(x.interpolated_access((tx.center, ty.center)), tx.center)
print("diff: " + str(diff))
diff = sympy.diff(x.interpolated_access((tx.center, 2 * ty.center)), sympy.Matrix((tx.center, ty.center)))
print("diff: " + str(diff))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment