Commit 59b6c5c4 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Form backward_assignments with spatial derivatives

parent db815251
Pipeline #20944 failed with stage
in 60 minutes and 1 second
......@@ -121,7 +121,8 @@ def resample(input_field, output_field, interpolation_mode='linear'):
def translate(input_field: pystencils.Field,
output_field: pystencils.Field,
translation,
interpolation_mode='linear'):
interpolation_mode='linear',
allow_spatial_derivatives=True):
def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
......@@ -136,7 +137,9 @@ def translate(input_field: pystencils.Field,
output_field.center: input_field.interpolated_access(
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
})
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
if not allow_spatial_derivatives:
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
......
......@@ -93,7 +93,7 @@ def test_polar_transform2():
class PolarTransform(sympy.Function):
def eval(args):
return sympy.Matrix(
(args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2))
(args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2))
x.set_coordinate_origin_to_field_center()
y.coordinate_transform = PolarTransform
......@@ -117,11 +117,11 @@ def test_polar_inverted_transform():
class PolarTransform(sympy.Function):
def eval(args):
return sympy.Matrix(
(args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2))
(args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2))
def inv():
return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1]*2) * l[0],
sympy.sin(l[1] * sympy.pi / x.shape[1]*2) * l[0]))
return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1] * 2) * l[0],
sympy.sin(l[1] * sympy.pi / x.shape[1] * 2) * l[0]))
+ sympy.Matrix(x.shape) * 0.5)
lenna_file = join(dirname(__file__), "test_data", "lenna.png")
......@@ -251,3 +251,16 @@ def test_motion_model2():
# while True:
# sleep(100)
def test_spatial_derivative():
x, y = pystencils.fields('x, y: float32[2d]')
tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
assignments = pystencils.AssignmentCollection({
y.center: x.interpolated_access((tx.center, 2 * ty.center))
})
backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
print("backward_assignments: " + str(backward_assignments))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment