CEM #373
Open
hades-rp2010 wants to merge 18 commits into SforAiDl:master from hades-rp2010:CEM
Commits (18, all by hades-rp2010):
1d49049  Single actor critic shared params
ef4a179  Shared layers for multi ACs
2ecd086  Merge branch 'master' of https://github.com/SforAiDl/genrl
53450a8  Fix lint errors (1)
274aff9  Fixed tests
38f95f0  Changes to dicstrings and classes
835819e  Renaming Multi -> Two and comments
c94a9a1  Merge branch 'master' of https://github.com/SforAiDl/genrl
bf71710  Adding tutorial
fc356b9  Small change
844c53d  Index
d3830e0  Up to date
a90e8d0  CEM agent
6cb6d5c  Merge branch 'master' of https://github.com/SforAiDl/genrl into CEM
3b2067d  Training CEM without rollouts
f86b046  Fix Codacy (1)
f5a189d  Docstrings
4b11c16  Adding device
Empty file.
@@ -0,0 +1,47 @@
from abc import ABC

import torch

from genrl.agents import BaseAgent


class Planner:
    def __init__(self, initial_state, dynamics_model=None):
        if dynamics_model is not None:
            self.dynamics_model = dynamics_model
        self.initial_state = initial_state

    def _learn_dynamics_model(self, state):
        raise NotImplementedError

    def plan(self):
        raise NotImplementedError

    def execute_actions(self):
        raise NotImplementedError


class ModelBasedAgent(BaseAgent):
    def __init__(self, *args, planner=None, **kwargs):
        super(ModelBasedAgent, self).__init__(*args, **kwargs)
        self.planner = planner

    def plan(self):
        """
        To be used to plan out a sequence of actions
        """
        # A planner must be supplied before planning can be delegated to it
        if self.planner is None:
            raise ValueError("Provide a planner to plan for the environment")
        self.planner.plan()

    def generate_data(self):
        """
        To be used to generate synthetic data via a model (may be learnt or specified beforehand)
        """
        raise NotImplementedError

    def value_equivalence(self, state_space):
        """
        To be used for approximate value estimation methods e.g. Value Iteration Networks
        """
        raise NotImplementedError
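For context, ModelBasedAgent only delegates to whatever planner it is handed; a concrete planner would subclass Planner and fill in plan() / execute_actions(). Below is a minimal sketch of such a subclass, assuming only the interface defined above (the RandomPlanner name and its toy logic are illustrative, not part of this PR):

# Illustrative sketch only (not part of this PR): a trivial planner that fills in
# the Planner interface above. RandomPlanner and its internals are assumptions
# made purely for demonstration.
import random


class RandomPlanner(Planner):
    def __init__(self, initial_state, n_actions, horizon=5, dynamics_model=None):
        super().__init__(initial_state, dynamics_model)
        self.n_actions = n_actions
        self.horizon = horizon
        self.planned_actions = []

    def plan(self):
        # A real planner would score candidate sequences under self.dynamics_model;
        # here we simply sample a random sequence of discrete actions.
        self.planned_actions = [random.randrange(self.n_actions) for _ in range(self.horizon)]
        return self.planned_actions

    def execute_actions(self):
        # Return planned actions one at a time, in the order they were planned.
        return self.planned_actions.pop(0) if self.planned_actions else None


# agent = ModelBasedAgent(..., planner=RandomPlanner(initial_state=None, n_actions=2))
# agent.plan()  # delegates to the planner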
Empty file.
@@ -0,0 +1,235 @@
from typing import Any, Dict

import numpy as np
import torch
import torch.nn.functional as F

from genrl.agents import ModelBasedAgent
from genrl.core import RolloutBuffer
from genrl.utils import get_env_properties, get_model, safe_mean


class CEM(ModelBasedAgent):
    """Cross Entropy Method algorithm (CEM)

    Attributes:
        network (str): The type of network to be used
        env (Environment): The environment the agent is supposed to act on
        create_model (bool): Whether the model of the algo should be created when initialised
        policy_layers (:obj:`tuple` of :obj:`int`): Layers in the neural network of the policy
        lr_policy (float): Learning rate of the policy
        percentile (float): Top percentile of rewards to consider as elite
        simulations_per_epoch (int): Number of simulations to perform before taking a gradient step
        rollout_size (int): Capacity of the rollout buffer
        render (bool): Whether to render the environment or not
        device (str): Hardware being used for training. Options:
            ["cuda" -> GPU, "cpu" -> CPU]
    """

    def __init__(
        self,
        *args,
        network: str = "mlp",
        percentile: float = 70,
        simulations_per_epoch: int = 1000,
        rollout_size: int,
        **kwargs
    ):
        super(CEM, self).__init__(*args, **kwargs)
        self.network = network
        self.rollout_size = rollout_size
        self.rollout = RolloutBuffer(self.rollout_size, self.env)
        self.percentile = percentile
        self.simulations_per_epoch = simulations_per_epoch

        self._create_model()
        self.empty_logs()

    def _create_model(self):
        """Function to initialize the Policy

        This will create the Policy net for the CEM agent
        """
        self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
            self.env, self.network
        )
        self.agent = get_model("p", self.network)(
            self.state_dim,
            self.action_dim,
            self.policy_layers,
            "V",
            discrete,
            action_lim,
        ).to(self.device)
        self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy)

    def plan(self):
        """Function to plan out one episode

        Returns:
            states (:obj:`list` of :obj:`torch.Tensor`): Batch of states the agent encountered in the episode
            actions (:obj:`list` of :obj:`torch.Tensor`): Batch of actions the agent took in the episode
            rewards (:obj:`torch.Tensor`): The episode reward obtained
        """
        state = self.env.reset()
        self.rollout.reset()
        states, actions = self.collect_rollouts(state)
        return (states, actions, self.rewards[-1])

    def select_elites(self, states_batch, actions_batch, rewards_batch):
        """Function to select the elite states and elite actions based on the episode reward

        Args:
            states_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of states
            actions_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of actions
            rewards_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of rewards

        Returns:
            elite_states (:obj:`torch.Tensor`): Elite batch of states based on episode reward
            elite_actions (:obj:`torch.Tensor`): Actions the agent took during the elite batch of states
        """
        # Episodes whose return clears the percentile threshold are kept as elites
        reward_threshold = np.percentile(rewards_batch, self.percentile)
        elite_states = torch.cat(
            [
                s.unsqueeze(0).clone()
                for i in range(len(states_batch))
                if rewards_batch[i] >= reward_threshold
                for s in states_batch[i]
            ],
            dim=0,
        )
        elite_actions = torch.cat(
            [
                a.unsqueeze(0).clone()
                for i in range(len(actions_batch))
                if rewards_batch[i] >= reward_threshold
                for a in actions_batch[i]
            ],
            dim=0,
        )

        return elite_states, elite_actions

    def select_action(self, state):
        """Select action given state

        Action selection policy for the Cross Entropy agent

        Args:
            state (:obj:`torch.Tensor`): Current state of the agent

        Returns:
            action (:obj:`torch.Tensor`): Action taken by the agent
        """
        state = torch.as_tensor(state).float()
        action, dist = self.agent.get_action(state)
        return action

    def update_params(self):
        """Updates the Policy network of the CEM agent

        Function to update the policy network
        """
        sess = [self.plan() for _ in range(self.simulations_per_epoch)]
        batch_states, batch_actions, batch_rewards = zip(*sess)
        elite_states, elite_actions = self.select_elites(
            batch_states, batch_actions, batch_rewards
        )
        action_probs = self.agent.forward(elite_states.float().to(self.device))
        loss = F.cross_entropy(
            action_probs.view(-1, self.action_dim),
            elite_actions.long().view(-1),
        )
        self.logs["crossentropy_loss"].append(loss.item())
        self.optim.zero_grad()  # clear gradients from the previous update
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
        self.optim.step()

    def get_traj_loss(self, values, dones):
        # No need for this here
        pass

    def collect_rollouts(self, state: torch.Tensor):
        """Function to collect rollouts

        Collects rollouts by playing the env like a human agent and inputs information into
        the rollout buffer.

        Args:
            state (:obj:`torch.Tensor`): The starting state of the environment

        Returns:
            states (:obj:`list`): List of states the agent encountered during the episode
            actions (:obj:`list`): List of actions the agent took in the corresponding states
        """
        states = []
        actions = []
        for i in range(self.rollout_size):
            action = self.select_action(state)
            states.append(state)
            actions.append(action)

            next_state, reward, dones, _ = self.env.step(action)

            if self.render:
                self.env.render()

            state = next_state

            self.collect_rewards(dones, i)

            if torch.any(dones.byte()):
                break

        return states, actions

    def collect_rewards(self, dones: torch.Tensor, timestep: int):
        """Helper function to collect rewards

        Runs through all the envs and collects rewards accumulated during rollouts

        Args:
            dones (:obj:`torch.Tensor`): Game over statuses of each environment
            timestep (int): Timestep during rollout
        """
        for i, done in enumerate(dones):
            if done or timestep == self.rollout_size - 1:
                self.rewards.append(self.env.episode_reward[i].detach().clone())

    def get_hyperparams(self) -> Dict[str, Any]:
        """Get relevant hyperparameters to save

        Returns:
            hyperparams (:obj:`dict`): Hyperparameters to be saved
            weights (:obj:`torch.Tensor`): Neural network weights
        """
        hyperparams = {
            "network": self.network,
            "lr_policy": self.lr_policy,
            "rollout_size": self.rollout_size,
        }
        return hyperparams, self.agent.state_dict()

    def _load_weights(self, weights) -> None:
        self.agent.load_state_dict(weights)

    def get_logging_params(self) -> Dict[str, Any]:
        """Gets relevant parameters for logging

        Returns:
            logs (:obj:`dict`): Logging parameters for monitoring training
        """
        logs = {
            "crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
            "mean_reward": safe_mean(self.rewards),
        }

        self.empty_logs()
        return logs

    def empty_logs(self):
        """Empties logs"""
        self.logs = {}
        self.logs["crossentropy_loss"] = []
        self.rewards = []
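The core of update_params above is the elite-selection / cross-entropy step. As a self-contained illustration, the sketch below runs that step once on toy data, independent of GenRL; the linear policy and the fake batch are assumptions made only for demonstration.

# Toy, self-contained sketch of the CEM update (not GenRL code): keep the
# top-percentile episodes and fit the policy to their actions with a
# cross-entropy loss.
import numpy as np
import torch
import torch.nn.functional as F

state_dim, action_dim, percentile = 4, 2, 70
policy = torch.nn.Linear(state_dim, action_dim)  # stand-in for the MLP policy
optim = torch.optim.Adam(policy.parameters(), lr=1e-2)

# Fake batch: 8 episodes of 5 states/actions each, plus a scalar return per episode.
states = [torch.randn(5, state_dim) for _ in range(8)]
actions = [torch.randint(0, action_dim, (5,)) for _ in range(8)]
rewards = np.random.rand(8)

# Elite selection: keep every episode whose return clears the percentile threshold.
threshold = np.percentile(rewards, percentile)
elite_states = torch.cat([s for s, r in zip(states, rewards) if r >= threshold])
elite_actions = torch.cat([a for a, r in zip(actions, rewards) if r >= threshold])

# Cross-entropy step: push the policy towards the elite actions.
optim.zero_grad()
loss = F.cross_entropy(policy(elite_states), elite_actions)
loss.backward()
optim.step()
print("cross-entropy loss:", loss.item())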
@@ -0,0 +1 @@
from tests.test_agents.test_modelbased.test_cem import TestCEM
@@ -0,0 +1,23 @@
import shutil

from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


class TestCEM:
    def test_CEM(self):
        env = VectorEnv("CartPole-v0", 1)
        algo = CEM(
            "mlp",
            env,
            percentile=70,
            policy_layers=[100],
            rollout_size=100,
            simulations_per_epoch=100,
        )
        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1
        )
        trainer.train()
        shutil.rmtree("./logs")
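For reference, training outside the test harness looks the same; a minimal usage sketch assuming the constructor and trainer signatures exercised by the test above (the layer sizes, rollout size, and epoch count are arbitrary choices):

# Usage sketch (assumes the GenRL API shown in the test above).
from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer

env = VectorEnv("CartPole-v0", 1)
agent = CEM(
    "mlp",
    env,
    percentile=70,
    policy_layers=[64, 64],
    rollout_size=200,
    simulations_per_epoch=100,
)
trainer = OnPolicyTrainer(agent, env, log_mode=["csv"], logdir="./logs", epochs=50)
trainer.train()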
Looks pretty similar to the OnPolicyAgent method. Shouldn't this return values and dones though? Not sure if this is a consequence of the algo.