-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Force differentiation w.r.t. all non-const pointer/array type paramet…
…ers. We have an old problem that we need the user to differentiate w.r.t. pointer/array parameters. In those cases, the user does not provide an adjoint for the variables and we have to initialize them ourselves. And while it's clear how to deal with numerical types, e.g. ``` f_grad_0(double x, double y, double* _d_x) { // the user only wants the derivative w.r.t. x double _d_y = 0; // so we initialize _d_y ourselves ... ``` we cannot do the same with array/pointer types: ``` f_grad_0(double x, double* y, double* _d_x) { // the user only wants the derivative w.r.t. x double* _d_y = ???; // we don't know the size of y and, therefore, cannot initialize _d_y ... ``` It's worth noting that while the user is not directly interested in ``_d_y`` from the last example, ``_d_y`` can be used in the derivative internally. However, we can guarantee that ``_d_y`` will not be useful if ``y`` is const. Here's why this is the case: In the reverse mode, "``y`` depends on ``x``" results in "``_d_x`` depends on ``_d_y``", e.g. ``` y = x; --> _d_x += _d_y; ``` Therefore, if ``y`` cannot be modified (i.e. is const), ``_d_y`` will not be used in the reverse pass. This PR makes a compromise and forces differentiation w.r.t. non-const array/pointer parameters while allowing not differentiating w.r.t. const ones.
- Loading branch information
1 parent
51670da
commit 2e32138
Showing
13 changed files
with
79 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
float func11(float* a, float b) { // expected-error {{Non-differentiable non-const pointer and array parameters are not supported. Please differentiate w.r.t. 'a' or mark it const.}} | ||
float sum = 0; | ||
sum += a[0] *= b; | ||
return sum; | ||
} | ||
|
||
float func12(float a, float b[]) { // expected-error {{Non-differentiable non-const pointer and array parameters are not supported. Please differentiate w.r.t. 'b' or mark it const.}} | ||
float sum = 0; | ||
sum += a *= b[1]; | ||
return sum; | ||
} | ||
|
||
int main() { | ||
clad::gradient(func11, "b"); | ||
clad::gradient(func12, "a"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters