Commit 98eb75fe authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Automatically don't differentiate for indexVector

parent 41167d51
Pipeline #27620 failed with stage
in 1 minute and 24 seconds
......@@ -60,7 +60,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
# for every field create a corresponding diff field
diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields if f not in self._constant_fields}
for f in read_fields if (f not in self._constant_fields
and f.name not in self._constant_fields)}
diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields}
......@@ -77,7 +78,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
# TODO: simplify implementation. use matrix notation like in 'transposed' mode
for forward_read_field in self._forward_input_fields:
if forward_read_field in self._constant_fields:
if forward_read_field in self._constant_fields or forward_read_field.name in self._constant_fields:
continue
diff_read_field = diff_read_fields[forward_read_field]
......@@ -247,6 +248,7 @@ Backward:
self._forward_assignments = forward_assignments
self._backward_assignments = None
self._constant_fields = constant_fields
self._constant_fields += ['indexVector']
self._time_constant_fields = time_constant_fields
self._kwargs = kwargs
self.op_name = op_name
......@@ -393,7 +395,7 @@ Backward:
backward_assignments = []
for lhs, read_access in zip(diff_read_field_accesses, read_field_accesses):
# don't differentiate for constant fields
if read_access.field in self._constant_fields:
if read_access.field in self._constant_fields or read_access.field.name in self._constant_fields:
continue
rhs = sp.Matrix(sp.Matrix([e.rhs for e in forward_assignments])).diff(
......
......@@ -171,5 +171,5 @@ class PybindFunctionWrapping(JinjaCppFile):
super().__init__({'python_name': function_node.function_name,
'cpp_name': function_node.function_name,
'parameters': [p.symbol.name for p in function_node.get_parameters()
if hasattr(p.symbol, 'dtype') and not 'meshFunctor' in p.symbol.name]
if hasattr(p.symbol, 'dtype') and 'meshFunctor' not in p.symbol.name]
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment