Commit 08641367 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make pystencils_reco.AssignmentCollection.kernel a property

parent ad5ebbe8
Pipeline #21671 failed with stage
in 5 minutes and 26 seconds
......@@ -105,7 +105,7 @@ class AssignmentCollection(pystencils.AssignmentCollection):
self.args = []
self.kwargs = {}
self._autodiff = None
self.kernel = None
self._kernel = None
# @property
# def reproducible_hash(self):
......@@ -118,6 +118,15 @@ class AssignmentCollection(pystencils.AssignmentCollection):
# def __getstate__(self):
# return self.reproducible_hash
@property
def kernel(self):
if not self._kernel:
self.compile()
return self._kernel
def __call__(self, *args, **kwargs):
return self.kernel(*args, **kwargs)
def compile(self, target=None, *args, **kwargs):
"""Convenience wrapper for pystencils.create_kernel(...).compile()
See :func: ~pystencils.create_kernel
......@@ -157,6 +166,7 @@ class AssignmentCollection(pystencils.AssignmentCollection):
else:
kernel.__call__ = partial(kernel, **self.kwargs)
self._kernel = kernel
return kernel
def backward(self):
......@@ -177,18 +187,18 @@ class AssignmentCollection(pystencils.AssignmentCollection):
def _create_ml_op(self, backend, target, **kwargs):
if not target:
target = 'gpu'
constant_field_names = [f for f, t in kwargs.items()
if hasattr(t, 'requires_grad') and not t.requires_grad]
constant_fields = {f for f in self.free_fields if f.name in constant_field_names}
# constant_field_names = [f for f, t in kwargs.items()
# if hasattr(t, 'requires_grad') and not t.requires_grad]
# constant_fields = {f for f in self.free_fields if f.name in constant_field_names}
for n in [f for f, t in kwargs.items() if hasattr(t, 'requires_grad')]:
kwargs.pop(n)
if not self._autodiff:
if hasattr(self, '_create_autodiff'):
self._create_autodiff(constant_fields, **kwargs)
self._create_autodiff(**kwargs)
else:
self._autodiff = _create_autodiff(self, constant_fields, **kwargs)
self._autodiff = _create_autodiff(self, **kwargs)
op = self._autodiff.create_tensorflow_op(backend=backend, use_cuda=(target == 'gpu'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment