Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
pycodegen
pystencils_autodiff
Commits
98eb75fe
Commit
98eb75fe
authored
Oct 26, 2020
by
Stephan Seitz
Browse files
Automatically don't differentiate for indexVector
parent
41167d51
Pipeline
#27620
failed with stage
in 1 minute and 24 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/pystencils_autodiff/_autodiff.py
View file @
98eb75fe
...
...
@@ -60,7 +60,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
# for every field create a corresponding diff field
diff_read_fields
=
{
f
:
pystencils_autodiff
.
AdjointField
(
f
,
diff_fields_prefix
)
for
f
in
read_fields
if
f
not
in
self
.
_constant_fields
}
for
f
in
read_fields
if
(
f
not
in
self
.
_constant_fields
and
f
.
name
not
in
self
.
_constant_fields
)}
diff_write_fields
=
{
f
:
pystencils_autodiff
.
AdjointField
(
f
,
diff_fields_prefix
)
for
f
in
write_fields
}
...
...
@@ -77,7 +78,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
# TODO: simplify implementation. use matrix notation like in 'transposed' mode
for
forward_read_field
in
self
.
_forward_input_fields
:
if
forward_read_field
in
self
.
_constant_fields
:
if
forward_read_field
in
self
.
_constant_fields
or
forward_read_field
.
name
in
self
.
_constant_fields
:
continue
diff_read_field
=
diff_read_fields
[
forward_read_field
]
...
...
@@ -247,6 +248,7 @@ Backward:
self
.
_forward_assignments
=
forward_assignments
self
.
_backward_assignments
=
None
self
.
_constant_fields
=
constant_fields
self
.
_constant_fields
+=
[
'indexVector'
]
self
.
_time_constant_fields
=
time_constant_fields
self
.
_kwargs
=
kwargs
self
.
op_name
=
op_name
...
...
@@ -393,7 +395,7 @@ Backward:
backward_assignments
=
[]
for
lhs
,
read_access
in
zip
(
diff_read_field_accesses
,
read_field_accesses
):
# don't differentiate for constant fields
if
read_access
.
field
in
self
.
_constant_fields
:
if
read_access
.
field
in
self
.
_constant_fields
or
read_access
.
field
.
name
in
self
.
_constant_fields
:
continue
rhs
=
sp
.
Matrix
(
sp
.
Matrix
([
e
.
rhs
for
e
in
forward_assignments
])).
diff
(
...
...
src/pystencils_autodiff/backends/python_bindings.py
View file @
98eb75fe
...
...
@@ -171,5 +171,5 @@ class PybindFunctionWrapping(JinjaCppFile):
super
().
__init__
({
'python_name'
:
function_node
.
function_name
,
'cpp_name'
:
function_node
.
function_name
,
'parameters'
:
[
p
.
symbol
.
name
for
p
in
function_node
.
get_parameters
()
if
hasattr
(
p
.
symbol
,
'dtype'
)
and
not
'meshFunctor'
in
p
.
symbol
.
name
]
if
hasattr
(
p
.
symbol
,
'dtype'
)
and
'meshFunctor'
not
in
p
.
symbol
.
name
]
})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment