Commit 3b21e27a authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix gradient calculation for Tensorflow

parent 9ed3556c
...@@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi ...@@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi
backward_func = getattr(compiled_op, stringcase.snakecase( backward_func = getattr(compiled_op, stringcase.snakecase(
stringcase.pascalcase("call_" + backward_ast.function_name))) stringcase.pascalcase("call_" + backward_ast.function_name)))
grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
def gradient_calculation(op, grad): def gradient_calculation(op, *grad):
if isinstance(grad, Iterable): if not isinstance(grad, Iterable):
grad = [grad] grad = [grad]
return backward_func(**{autodiff_obj.backward_input_fields[i].name: g for i, g in enumerate(grad)},
return backward_func(**{grad_fields[i].name: g for i, g in enumerate(grad)},
**{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs) **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs)
if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed}) if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment