Commit
Jannis Becktepe: Merge branch 'main' of https://github.com/automl/arl…
Github Actions committed May 30, 2024
1 parent 46d9bb4 commit 9e32a39
Showing 21 changed files with 257 additions and 221 deletions.
Binary files modified (not shown):
  main/.doctrees/api/arlbench.core.algorithms.doctree
  main/.doctrees/api/arlbench.core.algorithms.dqn.doctree
  main/.doctrees/api/arlbench.core.algorithms.dqn.dqn.doctree
  main/.doctrees/api/arlbench.core.algorithms.sac.doctree
  main/.doctrees/api/arlbench.core.algorithms.sac.sac.doctree
  main/.doctrees/arlbench.core.algorithms.doctree
  main/.doctrees/arlbench.core.algorithms.dqn.doctree
  main/.doctrees/arlbench.core.algorithms.sac.doctree
  main/.doctrees/environment.pickle
84 changes: 48 additions & 36 deletions main/_modules/arlbench/core/algorithms/dqn/dqn.html

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions main/_modules/arlbench/core/algorithms/ppo/ppo.html
@@ -995,11 +995,14 @@ Source code for arlbench.core.algorithms.ppo.ppo

      """One epoch of network updates using minibatches of the current transition batch.

      Args:
-         update_state (tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey]): (train_state, transition_batch, advantages, targets, rng) Current update state.
+         update_state (tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey]):
+             (train_state, transition_batch, advantages, targets, rng) Current update state.
          _ (None): Unused parameter (required for jax.lax.scan).

      Returns:
-         tuple[tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey], tuple[tuple | None, tuple | None]]: Tuple of (train_state, transition_batch, advantages, targets, rng) and (loss, grads) if tracked.
+         tuple[tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey],
+             tuple[tuple | None, tuple | None]]: Tuple of (train_state, transition_batch,
+             advantages, targets, rng) and (loss, grads) if tracked.
      """
      train_state, traj_batch, advantages, targets, rng = update_state
      rng, _rng = jax.random.split(rng)
@@ -1061,10 +1064,12 @@ Source code for arlbench.core.algorithms.ppo.ppo

      Args:
          train_state (PPOTrainState): PPO training state.
-         batch_info (tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]): Minibatch of transitions, advantages and targets.
+         batch_info (tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]):
+             Minibatch of transitions, advantages and targets.

      Returns:
-         tuple[PPOTrainState, tuple[tuple | None, tuple | None]]: Tuple of PPO train state and (loss, grads) if tracked.
+         tuple[PPOTrainState, tuple[tuple | None, tuple | None]]:
+             Tuple of PPO train state and (loss, grads) if tracked.
      """
      traj_batch, advantages, targets = batch_info
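The minibatch update above receives (transitions, advantages, targets) already split into minibatches. A hedged sketch of how such minibatches are commonly built in JAX pipelines (our own helper, not taken from ARLBench):

    import jax

    def make_minibatches(rng, batch, n_minibatches):
        """Shuffle a flat batch pytree, then reshape leaves to [n_minibatches, -1, ...]."""
        batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]
        perm = jax.random.permutation(rng, batch_size)
        shuffled = jax.tree_util.tree_map(lambda x: x[perm], batch)
        return jax.tree_util.tree_map(
            lambda x: x.reshape((n_minibatches, -1) + x.shape[1:]), shuffled
        )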
@@ -1093,7 +1098,8 @@ Source code for arlbench.core.algorithms.ppo.ppo

          targets (jnp.ndarray): Targets.

      Returns:
-         tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: Tuple of (total_loss, (value_loss, actor_loss, entropy)).
+         tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
+             Tuple of (total_loss, (value_loss, actor_loss, entropy)).
      """
      # Rerun network
      pi, value = self.network.apply(params, traj_batch.obs)
154 changes: 83 additions & 71 deletions main/_modules/arlbench/core/algorithms/sac/sac.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion main/api/arlbench.core.algorithms.dqn.dqn.html
@@ -850,7 +850,7 @@

  Parameters:
      runner_state (DQNRunnerState) – DQN runner state.
-     _ (None) – Unused parameter (buffer_state in other algorithms).
+     buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
      n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
      n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
      n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
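The corrected parameter list documents DQN.train's signature; the same change recurs in the two module-level pages below. A hypothetical usage sketch based only on the parameters above — how the algorithm and its states are constructed is an assumption, so everything stays commented:

    # dqn = DQN(config, env)                        # assumed constructor
    # runner_state, buffer_state = dqn.init(rng)    # assumed initializer
    # result = dqn.train(
    #     runner_state,
    #     buffer_state,
    #     n_total_timesteps=1_000_000,  # update steps = n_total_timesteps // n_envs
    #     n_eval_steps=100,
    #     n_eval_episodes=10,
    # )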
2 changes: 1 addition & 1 deletion main/api/arlbench.core.algorithms.dqn.html
@@ -757,7 +757,7 @@

  Parameters:
      runner_state (DQNRunnerState) – DQN runner state.
-     _ (None) – Unused parameter (buffer_state in other algorithms).
+     buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
      n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
      n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
      n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
37 changes: 19 additions & 18 deletions main/api/arlbench.core.algorithms.html
@@ -1121,7 +1121,7 @@

  Parameters:
      runner_state (DQNRunnerState) – DQN runner state.
-     _ (None) – Unused parameter (buffer_state in other algorithms).
+     buffer_state (PrioritisedTrajectoryBufferState) – Buffer state.
      n_total_timesteps (int, optional) – Total number of training timesteps. Update steps = n_total_timesteps // n_envs. Defaults to 1000000.
      n_eval_steps (int, optional) – Number of evaluation steps during training. Defaults to 100.
      n_eval_episodes (int, optional) – Number of evaluation episodes per evaluation during training. Defaults to 10.
@@ -1527,15 +1527,16 @@

  SAC.update_actor(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)  [source]

- _summary_.
+ Updates the actor network parameters.

  Parameters:
-     actor_train_state (SACTrainState) – _description_
-     critic_train_state (SACTrainState) – _description_
-     alpha_train_state (SACTrainState) – _description_
-     batch (Transition) – _description_
-     rng (chex.PRNGKey) – _description_
+     actor_train_state (SACTrainState) – Actor train state.
+     critic_train_state (SACTrainState) – Critic train state.
+     alpha_train_state (SACTrainState) – Alpha train state.
+     experience (TimeStep) – Experience (batch of TimeSteps).
+     is_weights (jnp.ndarray) – Importance sampling weights for PER.
+     rng (chex.PRNGKey) – Random number generator key.

  Returns:
@@ -1550,16 +1551,16 @@

  SAC.update_alpha(alpha_train_state, entropy)  [source]

- _summary_.
+ Update alpha network parameters.

  Parameters:
-     alpha_train_state (SACTrainState) – _description_
-     entropy (jnp.ndarray) – _description_
+     alpha_train_state (SACTrainState) – Alpha training state.
+     entropy (jnp.ndarray) – Entropy values.

  Returns:
-     _description_
+     Updated training state and metrics.

  Return type:
      tuple[SACTrainState, jnp.ndarray]
@@ -1570,19 +1571,19 @@

  SAC.update_critic(actor_train_state, critic_train_state, alpha_train_state, experience, is_weights, rng)  [source]

- _summary_.
+ Updates the critic network parameters.

  Parameters:
-     actor_train_state (SACTrainState) – _description_
-     critic_train_state (SACTrainState) – _description_
-     alpha_train_state (SACTrainState) – _description_
-     experience (Transition) – _description_
-     rng (chex.PRNGKey) – _description_
+     actor_train_state (SACTrainState) – Actor train state.
+     critic_train_state (SACTrainState) – Critic train state.
+     alpha_train_state (SACTrainState) – Alpha train state.
+     experience (Transition) – Experience (batch of transitions).
+     rng (chex.PRNGKey) – Random number generator key.

  Returns:
-     _description_
+     Updated training state and metrics.

  Return type:
      tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]