Commit 2ac634ac authored by Stephan Seitz

Give torch ops nice names
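
Forward a user-chosen op_name through to the torch_native backend: when the
name differs from the default "autodiffop", create_autograd_function now uses
it verbatim instead of deriving a hash-suffixed name from the generated
kernel code.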

parent 1302edec
Pipeline #27964 failed in 8 minutes and 1 second
@@ -15,6 +15,7 @@ from pystencils_autodiff.backends import AVAILABLE_BACKENDS
 from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling

 REMOVE_CASTS = ReplaceOptim(lambda x: isinstance(x, pystencils.data_types.cast_func), lambda x: x.args[0])
+DEFAULT_OP_NAME = "autodiffop"


 @pystencils.cache.disk_cache_no_fallback
@@ -220,7 +221,7 @@ Backward:
     def __init__(self,
                  forward_assignments: List[ps.Assignment],
-                 op_name: str = "autodiffop",
+                 op_name: str = DEFAULT_OP_NAME,
                  boundary_handling: AutoDiffBoundaryHandling = None,
                  time_constant_fields: List[ps.Field] = None,
                  constant_fields: List[ps.Field] = [],
@@ -604,8 +605,8 @@ Backward:
     def time_constant_fields(self):
         return self._time_constant_fields

-    def create_torch_op(self, *args, **kwags):
-        return self.create_tensorflow_op(*args, backend='torch_native', **kwags)
+    def create_torch_op(self, *args, **kwargs):
+        return self.create_tensorflow_op(*args, backend='torch_native', **kwargs)

     def create_tensorflow_op(self,
                              inputfield_tensor_dict={},
@@ -685,7 +686,8 @@ Backward:
                 self, inputfield_tensor_dict, forward_loop, backward_loop)
         elif backend == 'torch_native':
             import pystencils_autodiff.backends._torch_native
-            op = pystencils_autodiff.backends._torch_native.create_autograd_function(self, use_cuda)
+            op = pystencils_autodiff.backends._torch_native.create_autograd_function(
+                self, use_cuda, op_name=self.op_name if self.op_name != DEFAULT_OP_NAME else None)
         elif backend == 'tensorflow':
             import pystencils_autodiff.backends._tensorflow
             op = pystencils_autodiff.backends._tensorflow.tensorflowop_from_autodiffop(
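
In the hunk above, the default name is treated as "no name chosen": only a
non-default op_name is forwarded, and None tells the backend to fall back to
its hash-based naming. A minimal sketch of that sentinel pattern
(pick_op_name is a hypothetical helper, not part of the library):

DEFAULT_OP_NAME = "autodiffop"

def pick_op_name(configured_name):
    # Map the default to None so the backend derives a unique,
    # hash-suffixed name; any other value is used verbatim.
    return configured_name if configured_name != DEFAULT_OP_NAME else None

assert pick_op_name('autodiffop') is None        # default -> auto-generated name
assert pick_op_name('double_it') == 'double_it'  # custom name passes through

One consequence of comparing against the default string: a caller who
explicitly passes "autodiffop" is indistinguishable from one who kept the
default and still gets a hash-suffixed name.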
@@ -7,7 +7,7 @@ from pystencils_autodiff.backends.astnodes import TorchModule
 from pystencils_autodiff.tensorflow_jit import _hash


-def create_autograd_function(autodiff_obj, use_cuda):
+def create_autograd_function(autodiff_obj, use_cuda, op_name=None):
     import torch
     field_to_tensor_dict = dict()
     # Allocate output tensor for forward and backward pass
@@ -24,10 +24,11 @@ def create_autograd_function(autodiff_obj, use_cuda):
         forward_ast = autodiff_obj.forward_ast_cpu
         backward_ast = autodiff_obj.backward_ast_cpu if autodiff_obj.backward_output_fields else None

-    op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.show_code(forward_ast)) + str(autodiff_obj)+str(autodiff_obj.constant_fields)).encode()).hexdigest()}'  # noqa
-    forward_ast.function_name = f'{op_name}_{forward_ast.function_name}'
-    if backward_ast:
-        backward_ast.function_name = f'{op_name}_{backward_ast.function_name}'
+    if not op_name:
+        op_name = f'{autodiff_obj.op_name}_{_hash((str(pystencils.get_code_str(forward_ast)) + str(autodiff_obj)+str(autodiff_obj.constant_fields)).encode()).hexdigest()}'  # noqa
+        forward_ast.function_name = f'{op_name}_{forward_ast.function_name}'
+        if backward_ast:
+            backward_ast.function_name = f'{op_name}_{backward_ast.function_name}'

     module = TorchModule(op_name, [forward_ast, backward_ast] if backward_ast else [forward_ast])
     compiled_op = module.compile()
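
Taken together, a hedged usage sketch of the net effect (the field setup and
the exact AutoDiffOp call are assumptions inferred from the hunks above, not
a verified snippet from the repository):

import pystencils as ps
from pystencils_autodiff import AutoDiffOp  # assumed public import path

x, y = ps.fields('x, y: float32[2D]')
forward = [ps.Assignment(y.center, 2 * x.center)]

# Custom name: create_autograd_function receives op_name='double_it',
# so TorchModule is created under the readable name and the hash-based
# renaming of the kernel functions is skipped.
op = AutoDiffOp(forward, op_name='double_it')
torch_op = op.create_torch_op()  # forwards to backend='torch_native'

# Default name: None is forwarded, and the backend falls back to a
# unique, hash-suffixed name derived from the generated kernel code.
anonymous = AutoDiffOp(forward)
anonymous_torch_op = anonymous.create_torch_op()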