Commit 8b7bdb96 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make all fields == not time_constant_fields default

parent 1e141dfb
Pipeline #20179 failed with stage
in 4 minutes and 6 seconds
......@@ -85,7 +85,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index)
if forward_read_field in self._time_constant_fields:
if forward_read_field in self._time_constant_fields and self.time_constant_fields is not None:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
......@@ -202,7 +202,7 @@ Backward:
forward_assignments: List[ps.Assignment],
op_name: str = "autodiffop",
boundary_handling: AutoDiffBoundaryHandling = None,
time_constant_fields: List[ps.Field] = [],
time_constant_fields: List[ps.Field] = None,
constant_fields: List[ps.Field] = [],
diff_fields_prefix='diff', # TODO: remove!
do_common_subexpression_elimination=True,
......@@ -241,6 +241,7 @@ Backward:
self._backward_kernel_gpu = None
self._do_common_subexpression_elimination = do_common_subexpression_elimination
self._boundary_handling = boundary_handling
if backward_assignments:
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
......@@ -367,7 +368,7 @@ Backward:
rhs = rhs[0, 0]
# if field is constant over we time we can accumulate in assignment
if read_access.field in self._time_constant_fields:
if read_access.field in self._time_constant_fields and self.time_constant_fields is not None:
backward_assignments.append(ps.Assignment(lhs, lhs + rhs))
else:
backward_assignments.append(ps.Assignment(lhs, rhs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment