Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
pycodegen
pystencils_autodiff
Commits
2ac634ac
Commit
2ac634ac
authored
Nov 13, 2020
by
Stephan Seitz
Browse files
Give torch ops nice names
parent
1302edec
Pipeline
#27964
failed with stage
in 8 minutes and 1 second
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/pystencils_autodiff/_autodiff.py
View file @
2ac634ac
...
...
@@ -15,6 +15,7 @@ from pystencils_autodiff.backends import AVAILABLE_BACKENDS
from
pystencils_autodiff.transformations
import
add_fixed_constant_boundary_handling
REMOVE_CASTS
=
ReplaceOptim
(
lambda
x
:
isinstance
(
x
,
pystencils
.
data_types
.
cast_func
),
lambda
x
:
x
.
args
[
0
])
DEFAULT_OP_NAME
=
"autodiffop"
@
pystencils
.
cache
.
disk_cache_no_fallback
...
...
@@ -220,7 +221,7 @@ Backward:
def
__init__
(
self
,
forward_assignments
:
List
[
ps
.
Assignment
],
op_name
:
str
=
"autodiffop"
,
op_name
:
str
=
DEFAULT_OP_NAME
,
boundary_handling
:
AutoDiffBoundaryHandling
=
None
,
time_constant_fields
:
List
[
ps
.
Field
]
=
None
,
constant_fields
:
List
[
ps
.
Field
]
=
[],
...
...
@@ -604,8 +605,8 @@ Backward:
def
time_constant_fields
(
self
):
return
self
.
_time_constant_fields
def
create_torch_op
(
self
,
*
args
,
**
kwags
):
return
self
.
create_tensorflow_op
(
*
args
,
backend
=
'torch_native'
,
**
kwags
)
def
create_torch_op
(
self
,
*
args
,
**
kwa
r
gs
):
return
self
.
create_tensorflow_op
(
*
args
,
backend
=
'torch_native'
,
**
kwa
r
gs
)
def
create_tensorflow_op
(
self
,
inputfield_tensor_dict
=
{},
...
...
@@ -685,7 +686,8 @@ Backward:
self
,
inputfield_tensor_dict
,
forward_loop
,
backward_loop
)
elif
backend
==
'torch_native'
:
import
pystencils_autodiff.backends._torch_native
op
=
pystencils_autodiff
.
backends
.
_torch_native
.
create_autograd_function
(
self
,
use_cuda
)
op
=
pystencils_autodiff
.
backends
.
_torch_native
.
create_autograd_function
(
self
,
use_cuda
,
op_name
=
self
.
op_name
if
self
.
op_name
!=
DEFAULT_OP_NAME
else
None
)
elif
backend
==
'tensorflow'
:
import
pystencils_autodiff.backends._tensorflow
op
=
pystencils_autodiff
.
backends
.
_tensorflow
.
tensorflowop_from_autodiffop
(
...
...
src/pystencils_autodiff/backends/_torch_native.py
View file @
2ac634ac
...
...
@@ -7,7 +7,7 @@ from pystencils_autodiff.backends.astnodes import TorchModule
from
pystencils_autodiff.tensorflow_jit
import
_hash
def
create_autograd_function
(
autodiff_obj
,
use_cuda
):
def
create_autograd_function
(
autodiff_obj
,
use_cuda
,
op_name
=
None
):
import
torch
field_to_tensor_dict
=
dict
()
# Allocate output tensor for forward and backward pass
...
...
@@ -24,10 +24,11 @@ def create_autograd_function(autodiff_obj, use_cuda):
forward_ast
=
autodiff_obj
.
forward_ast_cpu
backward_ast
=
autodiff_obj
.
backward_ast_cpu
if
autodiff_obj
.
backward_output_fields
else
None
op_name
=
f
'
{
autodiff_obj
.
op_name
}
_
{
_hash
((
str
(
pystencils
.
show_code
(
forward_ast
))
+
str
(
autodiff_obj
)
+
str
(
autodiff_obj
.
constant_fields
)).
encode
()).
hexdigest
()
}
'
# noqa
forward_ast
.
function_name
=
f
'
{
op_name
}
_
{
forward_ast
.
function_name
}
'
if
backward_ast
:
backward_ast
.
function_name
=
f
'
{
op_name
}
_
{
backward_ast
.
function_name
}
'
if
not
op_name
:
op_name
=
f
'
{
autodiff_obj
.
op_name
}
_
{
_hash
((
str
(
pystencils
.
get_code_str
(
forward_ast
))
+
str
(
autodiff_obj
)
+
str
(
autodiff_obj
.
constant_fields
)).
encode
()).
hexdigest
()
}
'
# noqa
forward_ast
.
function_name
=
f
'
{
op_name
}
_
{
forward_ast
.
function_name
}
'
if
backward_ast
:
backward_ast
.
function_name
=
f
'
{
op_name
}
_
{
backward_ast
.
function_name
}
'
module
=
TorchModule
(
op_name
,
[
forward_ast
,
backward_ast
]
if
backward_ast
else
[
forward_ast
])
compiled_op
=
module
.
compile
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment