Skip to content
Snippets Groups Projects
Commit ee9587df authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow differentation of InterpolatorAccess

parent 07c87bc7
Branches
Tags
No related merge requests found
......@@ -143,7 +143,7 @@ class NearestNeightborInterpolator(Interpolator):
class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, *offsets, **kwargs):
obj = TextureAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
obj = InterpolatorAccess.__xnew_cached_(cls, field, *offsets, **kwargs)
return obj
def __new_stage2__(self, symbol, *offsets):
......@@ -201,6 +201,17 @@ class InterpolatorAccess(TypedSymbol):
def interpolation_mode(self):
return self.interpolator.interpolation_mode
@property
def _diff_interpolation_vec(self):
return sp.Matrix([DiffInterpolatorAccess(self.symbol, i, *self.offsets)
for i in range(len(self.offsets))])
def diff(self, *symbols, **kwargs):
rtn = self._diff_interpolation_vec.T * sp.Matrix(self.offsets).diff(*symbols, **kwargs)
if rtn.shape == (1, 1):
rtn = rtn[0, 0]
return rtn
def implementation_with_stencils(self):
field = self.field
......@@ -255,7 +266,7 @@ class InterpolatorAccess(TypedSymbol):
for (dim, i) in enumerate(index)]
index = [cast_func(sp.Piecewise((i, i > 0),
(sp.Abs(cast_func(field.shape[dim] - 1 + i, default_int_type)),
True)), default_int_type)
True)), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
......@@ -290,6 +301,46 @@ class InterpolatorAccess(TypedSymbol):
def __getnewargs__(self):
return tuple(self.symbol, *self.offsets)
class DiffInterpolatorAccess(InterpolatorAccess):
def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
return obj
def __new_stage2__(self, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
assert offsets is not None
obj = super().__xnew__(self, symbol, *offsets)
obj.diff_coordinate_idx = diff_coordinate_idx
return obj
def __hash__(self):
return hash((self.symbol, self.field, self.diff_coordinate_idx, tuple(self.offsets), self.interpolator))
def __str__(self):
return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx,
','.join(str(o) for o in self.offsets))
@property
def args(self):
return [self.symbol, self.diff_coordinate_idx, *self.offsets]
@property
def symbols_defined(self) -> Set[sp.Symbol]:
return {self}
@property
def interpolation_mode(self):
return self.interpolator.interpolation_mode
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
# noinspection SpellCheckingInspection
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return tuple(self.symbol, self.diff_coordinate_idx, *self.offsets)
##########################################################################################
# GPU-specific fast specializations (for precision GPUs can also use above nodes/symbols #
##########################################################################################
......
......@@ -234,5 +234,13 @@ def test_field_interpolated(address_mode, target):
out = np.zeros_like(lenna)
kernel(x=lenna, y=out)
pyconrad.imshow(out, "out " + address_mode)
kernel(x=lenna, y=out)
pyconrad.imshow(out, "out " + address_mode)
def test_spatial_derivative():
x, y = pystencils.fields('x, y: float32[2d]')
tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
diff = sympy.diff(x.interpolated_access((tx.center, ty.center)), tx.center)
print("diff: " + str(diff))
diff = sympy.diff(x.interpolated_access((tx.center, 2 * ty.center)), sympy.Matrix((tx.center, ty.center)))
print("diff: " + str(diff))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment