Commit 32e36a69 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow injecting class definitions into Python bindings

parent 32bb1d31
Pipeline #27122 failed with stage
in 8 minutes and 20 seconds
......@@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile):
def backend(self):
return 'gpucuda' if self.is_cuda else 'c'
def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False):
def __init__(self,
module_name,
kernel_asts,
with_python_bindings=True,
wrap_wrapper_functions=False,
class_definitions=[]):
"""Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
......@@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile):
'module_name': module_name,
'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
[self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
for a in wrapper_functions])
for a in wrapper_functions] + class_definitions)
if with_python_bindings else ''
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment