From ffb0d25e15a1e288cc75a4418e4c0907fa6e6101 Mon Sep 17 00:00:00 2001 From: Cody Balos Date: Tue, 8 Oct 2024 22:33:18 -0700 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Steven Roberts --- doc/arkode/guide/source/Mathematics.rst | 10 +++++----- doc/arkode/guide/source/Usage/ARKStep/ASA.rst | 3 +-- .../sunadjoint/SUNAdjointCheckpointScheme.rst | 18 +++++++++--------- doc/shared/sunadjoint/SUNAdjointStepper.rst | 2 +- doc/shared/sundials.bib | 2 +- 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/doc/arkode/guide/source/Mathematics.rst b/doc/arkode/guide/source/Mathematics.rst index 898b3bd034..a7987b07a6 100644 --- a/doc/arkode/guide/source/Mathematics.rst +++ b/doc/arkode/guide/source/Mathematics.rst @@ -2143,7 +2143,7 @@ for which we would like to compute the gradients :math:`\partial g(y(t_f),p)/\pa and/or :math:`\partial g(y(t_f),p)/\partial p`. The adjoint method is one approach to obtaining the gradients that is particularly efficient when there are relatively few functionals and a large number of parameters. With the adjoint method we solve the adjoint ODEs for :math:`\lambda(t) -\in \mathbb{R}^N` and :math:`\mu(t) \in \mathbb{R}^{N_s}`: +\in \mathbb{R}^N` and :math:`\mu(t) \in \mathbb{R}^{N_p}`: .. math:: \lambda'(t) &= -f_y^T(t, y, p) \lambda,\quad \lambda(t_F) = g_y^T(y(t_f), p), \\ @@ -2151,8 +2151,8 @@ large number of parameters. With the adjoint method we solve the adjoint ODEs fo :label: ARKODE_ADJOINT_ODE (For a detailed derivation see :cite:p:`hager2000runge,sanduDiscrete2006`). Here :math:`f_y \equiv -\partial f/\partial y` is the Jacobian with respect to the dependent variable and :math:`f_p \equiv -\partial f/\partial p` is the Jacobian with respect to the parameters. The ARKStep module in ARKODE +\partial f/\partial y \in \mathbb{R}^{N \times N}` is the Jacobian with respect to the dependent variable and :math:`f_p \equiv +\partial f/\partial p \in \mathbb{R}^{N \times N_p}` is the Jacobian with respect to the parameters. The ARKStep module in ARKODE provides adjoint sensitivity analysis based on the *discrete* formulation, i.e., given an s-stage explicit Runge--Kutta method (as in :eq:`ARKODE_ERK`, but without the embedding), the discrete adjoint to compute :math:`\lambda_n` and :math:`\mu_n` starting from :math:`\lambda_{n+1}` and @@ -2170,7 +2170,7 @@ After completing integration from :math:`t_n` all the way to :math:`t_0` using t formulation, the gradients are .. math:: - \frac{\partial g}{\partial y_{(t_0)}} = \lambda_0, \quad + \frac{\partial g}{\partial y(t_0)} = \lambda_0, \quad \frac{\partial g}{\partial p} = \mu_0 + \lambda_0 \left(\frac{\partial y(t_0)}{\partial p} \right). For more information on performing discrete adjoint sensitivity analysis see, :numref:`ARKODE.Usage.ARKStep.ASA`. @@ -2184,7 +2184,7 @@ sometimes used -- the *continuous* adjoint method. In the continuous approach, w sensitivity equations directly from the model and then we integrate them with a time integration method. This is the approach implemented in the SUNDIALS :ref:`CVODES ` and :ref:`IDAS ` packages. In the *discrete* approach, the model equations are -discretized with the time integration method first, and then we derive the adjoints of the +discretized with the time integration method first, and then we derive the sensitivities of the discretized equations. It is understood that the continuous adjoint method can be problematic in the context of optimization problems because the continuous adjoint method provides an approximation to the gradient of a continuous cost function while the optimizer is expecting the gradient of the diff --git a/doc/arkode/guide/source/Usage/ARKStep/ASA.rst b/doc/arkode/guide/source/Usage/ARKStep/ASA.rst index 19d8ba90e2..f964c8ac9d 100644 --- a/doc/arkode/guide/source/Usage/ARKStep/ASA.rst +++ b/doc/arkode/guide/source/Usage/ARKStep/ASA.rst @@ -7,8 +7,7 @@ The previous sections discuss using ARKStep for the integration of forward ODE m This section discusses how to use ARKStep for adjoint sensitivity analysis as introduced in :numref:`ARKODE.Mathematics.ASA`. To use ARKStep for ASA, users simply setup the forward integration as usual (following :numref:`ARKODE.Usage.Skeleton`) with one exception: -a :c:type:`SUNAdjointCheckpointScheme` object must be provided to the forward ARKStep stepper -creating a :c:type:`SUNAdjointCheckpointScheme` object and then calling +a :c:type:`SUNAdjointCheckpointScheme` object must be created and passed to :c:func:`ARKodeSetAdjointCheckpointScheme` before the call to the :c:func:`ARKodeEvolve` function. After the forward model integration code, a :c:type:`SUNAdjointStepper` object can be created for the adjoint model integration by calling :c:func:`ARKStepCreateAdjointStepper`. diff --git a/doc/shared/sunadjoint/SUNAdjointCheckpointScheme.rst b/doc/shared/sunadjoint/SUNAdjointCheckpointScheme.rst index 37f2ddac5e..50fa0ff8bf 100644 --- a/doc/shared/sunadjoint/SUNAdjointCheckpointScheme.rst +++ b/doc/shared/sunadjoint/SUNAdjointCheckpointScheme.rst @@ -44,7 +44,7 @@ A :c:type:`SUNAdjointCheckpointScheme` is a pointer to the .. c:struct:: SUNAdjointCheckpointScheme_Ops_ -. c:type:: struct SUNAdjointCheckpointScheme_Ops_ +.. c:type:: struct SUNAdjointCheckpointScheme_Ops_ .. c:member:: SUNErrCode (*shouldWeSave)(SUNAdjointCheckpointScheme, sunindextype step_num, sunindextype stage_num, sunrealtype t, sunbooleantype* yes_or_no) @@ -60,15 +60,15 @@ A :c:type:`SUNAdjointCheckpointScheme` is a pointer to the .. c:member:: SUNErrCode (*loadVector)(SUNAdjointCheckpointScheme, sunindextype step_num, sunindextype stage_num, sunbooleantype peek, N_Vector* out, sunrealtype* tout) - Function pointer to load a checkpoint state represented as a `N_Vector`. + Function pointer to load a checkpoint state represented as a :c:type:`N_Vector`. .. c:member:: SUNErrCode (*removeVector)(SUNAdjointCheckpointScheme, sunindextype step_num, sunindextype stage_num, N_Vector* out) - Function pointer to remove a checkpoint state represented as a `N_Vector`. + Function pointer to remove a checkpoint state represented as a :c:type:`N_Vector`. .. c:member:: SUNErrCode (*destroy)(SUNAdjointCheckpointScheme*) - Function pointer to destroy and free the memory for the `SUNAdjointCheckpointScheme` object. + Function pointer to destroy and free the memory for the :c:type:`SUNAdjointCheckpointScheme` object. .. c:member:: SUNErrCode (*enableDense)(SUNAdjointCheckpointScheme, sunbooleantype on_or_off) @@ -158,7 +158,7 @@ A :c:type:`SUNAdjointCheckpointScheme` is a pointer to the Enables or disables dense checkpointing (checkpointing every step/stage). :param cs: The :c:type:`SUNAdjointCheckpointScheme` object - :param on_or_off: if true, dense checkpointing will be turned on, ifalse it will be turned off. + :param on_or_off: if true, dense checkpointing will be turned on, if false it will be turned off. :return: A :c:type:`SUNErrCode` indicating failure or success. @@ -190,7 +190,7 @@ The SUNAdjointCheckpointScheme_Basic module has the following user-callable func .. c:function:: SUNErrCode SUNAdjointCheckpointScheme_Create_Basic(SUNDataIOMode io_mode, SUNMemoryHelper mem_helper, int64_t interval, int64_t estimate, sunbooleantype save_stages, sunbooleantype keep, SUNContext sunctx, SUNAdjointCheckpointScheme* check_scheme_ptr) - Creates a new `SUNAdjointCheckpointScheme` object that checkpoints at a fixed interval. + Creates a new :c:type:`SUNAdjointCheckpointScheme` object that checkpoints at a fixed interval. :param io_mode: The IO mode used for storing the checkpoints. :param mem_helper: Memory helper for managing memory. @@ -198,9 +198,9 @@ The SUNAdjointCheckpointScheme_Basic module has the following user-callable func :param estimate: An estimate of the total number of checkpoints needed. :param save_stages: If using a multistage method, should stages be saved with the step. :param keep: Keep data stored even after it is not needed anymore. - :param sunctx: The SUNContext for the simulation. + :param sunctx: The :c:type:`SUNContext` for the simulation. :param check_scheme_ptr: Pointer to the newly constructed object. - :return: A `SUNErrCode` indicating success or failure. + :return: A :c:type:`SUNErrCode` indicating success or failure. .. c:function:: SUNErrCode SUNAdjointCheckpointScheme_ShouldWeSave_Basic(SUNAdjointCheckpointScheme check_scheme, sunindextype step_num, sunindextype stage_num, sunrealtype t, sunbooleantype* yes_or_no) @@ -211,7 +211,7 @@ The SUNAdjointCheckpointScheme_Basic module has the following user-callable func :param stage_num: The current stage number (only nonzero for multistage methods). :param t: The current time. :param yes_or_no: On output, will be 1 if you should save, 0 otherwise. - :return: A `SUNErrCode` indicating success or failure. + :return: A :c:type:`SUNErrCode` indicating success or failure. .. c:function:: SUNErrCode SUNAdjointCheckpointScheme_InsertVector_Basic(SUNAdjointCheckpointScheme check_scheme, sunindextype step_num, sunindextype stage_num, sunrealtype t, N_Vector state) diff --git a/doc/shared/sunadjoint/SUNAdjointStepper.rst b/doc/shared/sunadjoint/SUNAdjointStepper.rst index f6358fc18a..1ea23f0efa 100644 --- a/doc/shared/sunadjoint/SUNAdjointStepper.rst +++ b/doc/shared/sunadjoint/SUNAdjointStepper.rst @@ -200,7 +200,7 @@ The :c:type:`SUNAdjointStepper` class has the following functions: :return: A :c:type:`SUNErrCode` indicating failure or success. -.. c:function:: SUNErrCode SUNAdjointStepper_SetJacTimesVecFn(SUNAdjointStepper adj_stepper, SUNJacTimesFn Jvp, SUNJacTimesFn JPvp) +.. c:function:: SUNErrCode SUNAdjointStepper_SetVecTimesJacFn(SUNAdjointStepper adj_stepper, SUNJacTimesFn Jvp, SUNJacTimesFn JPvp) Sets the function pointers to evaluate :math:`(df/dy)^T v` and :math:`(df/dp)^T v` diff --git a/doc/shared/sundials.bib b/doc/shared/sundials.bib index 39f9a2f17f..5cd1591b77 100644 --- a/doc/shared/sundials.bib +++ b/doc/shared/sundials.bib @@ -2407,7 +2407,7 @@ @article{hager2000runge @article{sanduDiscrete2006, year = {2006}, - title = {{On the Properties of Runge-Kutta Discrete Adjoints}}, + title = {On the Properties of {Runge-Kutta} Discrete Adjoints}, author = {Sandu, Adrian}, journal = {Lecture Notes in Computer Science}, issn = {0302-9743},