This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

RL in Large Discrete Action Spaces - Wolpertinger Agent (#394)
* Currently this is specific to the case of discretizing a continuous action space. It can easily be adapted to other cases by feeding the kNN differently and removing the use of the discretizing output action filter.
Gal Leibovich authored Sep 8, 2019
1 parent fc50398 commit 138ced2
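
For readers unfamiliar with the approach, the Wolpertinger architecture (Dulac-Arnold et al., 2015) combines a DDPG-style actor that proposes a continuous proto-action, a k-nearest-neighbour lookup that maps it to candidate discrete actions, and a critic whose Q-values pick the final action. The following is a minimal illustrative sketch of that action-selection step only; the names (actor, critic, discrete_action_embeddings, k) are placeholders and do not reflect the agent's actual API in this commit.

    import numpy as np

    def wolpertinger_select_action(state, actor, critic, discrete_action_embeddings, k=5):
        """Illustrative Wolpertinger action selection (not Coach's implementation).

        actor(state)              -> proto-action in the continuous action-embedding space
        critic(state, actions)    -> array of Q-values, one per candidate action
        discrete_action_embeddings: (num_actions, action_dim) array of action embeddings
        """
        # 1. the actor proposes a continuous proto-action
        proto_action = actor(state)

        # 2. a k-nearest-neighbour lookup maps it to k candidate discrete actions
        distances = np.linalg.norm(discrete_action_embeddings - proto_action, axis=1)
        candidate_ids = np.argsort(distances)[:k]

        # 3. the critic refines the choice: pick the candidate with the highest Q-value
        q_values = critic(state, discrete_action_embeddings[candidate_ids])
        return int(candidate_ids[np.argmax(q_values)])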
Showing 46 changed files with 1,193 additions and 51 deletions.
Binary file modified docs/_images/algorithms.png
Binary file added docs/_images/wolpertinger.png
1 change: 1 addition & 0 deletions docs/_modules/index.html
@@ -202,6 +202,7 @@ <h1>All modules for which code is available</h1>
<li><a href="rl_coach/agents/soft_actor_critic_agent.html">rl_coach.agents.soft_actor_critic_agent</a></li>
<li><a href="rl_coach/agents/td3_agent.html">rl_coach.agents.td3_agent</a></li>
<li><a href="rl_coach/agents/value_optimization_agent.html">rl_coach.agents.value_optimization_agent</a></li>
<li><a href="rl_coach/agents/wolpertinger_agent.html">rl_coach.agents.wolpertinger_agent</a></li>
<li><a href="rl_coach/architectures/architecture.html">rl_coach.architectures.architecture</a></li>
<li><a href="rl_coach/architectures/network_wrapper.html">rl_coach.architectures.network_wrapper</a></li>
<li><a href="rl_coach/base_parameters.html">rl_coach.base_parameters</a></li>
15 changes: 12 additions & 3 deletions docs/_modules/rl_coach/agents/agent.html
@@ -756,6 +756,9 @@ Source code for rl_coach.agents.agent

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">!=</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="p">,</span> <span class="n">EpisodicExperienceReplay</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">override_episode_rewards_with_the_last_transition_reward</span><span class="p">:</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">reward</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">reward</span>
<span class="bp">self</span><span class="o">.</span><span class="n">call_memory</span><span class="p">(</span><span class="s1">&#39;store_episode&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">store_transitions_only_when_episodes_are_terminated</span><span class="p">:</span>
<span class="k">for</span> <span class="n">transition</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">current_episode_buffer</span><span class="o">.</span><span class="n">transitions</span><span class="p">:</span>
@@ -910,7 +913,8 @@ Source code for rl_coach.agents.agent
<span class="c1"># update counters</span>
<span class="bp">self</span><span class="o">.</span><span class="n">training_iteration</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
<span class="n">batch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>

<span class="c1"># if the batch returned empty then there are not enough samples in the replay buffer -&gt; skip</span>
<span class="c1"># training step</span>
@@ -1020,7 +1024,8 @@ Source code for rl_coach.agents.agent
<span class="c1"># informed action</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># before choosing an action, first use the pre_network_filter to filter out the current state</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">update_filter_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">RunPhase</span><span class="o">.</span><span class="n">TEST</span>
<span class="n">curr_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">curr_state</span><span class="p">,</span> <span class="n">update_filter_internal_state</span><span class="p">)</span>

<span class="k">else</span><span class="p">:</span>
@@ -1048,6 +1053,10 @@ Source code for rl_coach.agents.agent
<span class="sd"> :return: The filtered state</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>

<span class="c1"># TODO actually we only want to run the observation filters. No point in running the reward filters as the</span>
<span class="c1"># filtered reward is being ignored anyway (and it might unncecessarily affect the reward filters&#39; internal</span>
<span class="c1"># state).</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span>
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_filter_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span></div>

@@ -1177,7 +1186,7 @@ Source code for rl_coach.agents.agent
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Allows setting a directive for the agent to follow. This is useful in hierarchy structures, where the agent</span>
<span class="sd"> has another master agent that is controlling it. In such cases, the master agent can define the goals for the</span>
<span class="sd"> slave agent, define it&#39;s observation, possible actions, etc. The directive type is defined by the agent</span>
<span class="sd"> slave agent, define its observation, possible actions, etc. The directive type is defined by the agent</span>
<span class="sd"> in-action-space.</span>

<span class="sd"> :param action: The action that should be set as the directive</span>
12 changes: 9 additions & 3 deletions docs/_modules/rl_coach/agents/clipped_ppo_agent.html
@@ -295,7 +295,9 @@ Source code for rl_coach.agents.clipped_ppo_agent
<span class="bp">self</span><span class="o">.</span><span class="n">optimization_epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="bp">self</span><span class="o">.</span><span class="n">normalization_stats</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clipping_decay_schedule</span> <span class="o">=</span> <span class="n">ConstantSchedule</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span></div>
<span class="bp">self</span><span class="o">.</span><span class="n">act_for_full_episodes</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span> <span class="o">=</span> <span class="kc">True</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span> <span class="o">=</span> <span class="kc">False</span></div>


<span class="k">class</span> <span class="nc">ClippedPPOAgentParameters</span><span class="p">(</span><span class="n">AgentParameters</span><span class="p">):</span>
@@ -486,7 +488,9 @@ Source code for rl_coach.agents.clipped_ppo_agent
<span class="n">network</span><span class="o">.</span><span class="n">set_is_training</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>

<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">memory</span><span class="o">.</span><span class="n">transitions</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_train</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">deep_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">)</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>

<span class="k">for</span> <span class="n">training_step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">num_consecutive_training_steps</span><span class="p">):</span>
@@ -512,7 +516,9 @@ Source code for rl_coach.agents.clipped_ppo_agent

<span class="k">def</span> <span class="nf">run_pre_network_filter_for_inference</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">:</span> <span class="n">StateType</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="p">:</span> <span class="nb">bool</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="n">dummy_env_response</span> <span class="o">=</span> <span class="n">EnvResponse</span><span class="p">(</span><span class="n">next_state</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">reward</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">game_over</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="kc">False</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>
<span class="n">update_internal_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">update_pre_network_filters_state_on_inference</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_network_filter</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
<span class="n">dummy_env_response</span><span class="p">,</span> <span class="n">update_internal_state</span><span class="o">=</span><span class="n">update_internal_state</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">next_state</span>

<span class="k">def</span> <span class="nf">choose_action</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">curr_state</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ap</span><span class="o">.</span><span class="n">algorithm</span><span class="o">.</span><span class="n">clipping_decay_schedule</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
