Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
pycodegen
pystencils_autodiff
Commits
3b21e27a
Commit
3b21e27a
authored
Nov 29, 2019
by
Stephan Seitz
Browse files
Fix gradient calculation for Tensorflow
parent
9ed3556c
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/pystencils_autodiff/backends/_tensorflow.py
View file @
3b21e27a
...
...
@@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi
backward_func
=
getattr
(
compiled_op
,
stringcase
.
snakecase
(
stringcase
.
pascalcase
(
"call_"
+
backward_ast
.
function_name
)))
grad_fields
=
[
f
for
f
in
autodiff_obj
.
backward_input_fields
if
f
not
in
autodiff_obj
.
forward_input_fields
]
def
gradient_calculation
(
op
,
grad
):
if
isinstance
(
grad
,
Iterable
):
def
gradient_calculation
(
op
,
*
grad
):
if
not
isinstance
(
grad
,
Iterable
):
grad
=
[
grad
]
return
backward_func
(
**
{
autodiff_obj
.
backward_input_fields
[
i
].
name
:
g
for
i
,
g
in
enumerate
(
grad
)},
return
backward_func
(
**
{
grad_fields
[
i
].
name
:
g
for
i
,
g
in
enumerate
(
grad
)},
**
{
autodiff_obj
.
forward_input_fields
[
i
].
name
:
inp
for
i
,
inp
in
enumerate
(
op
.
inputs
)
if
autodiff_obj
.
forward_input_fields
[
i
]
in
backward_ast
.
fields_accessed
})
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment