Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CEM #373

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open

CEM #373

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions genrl/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,7 @@
from genrl.agents.deep.sac.sac import SAC # noqa
from genrl.agents.deep.td3.td3 import TD3 # noqa
from genrl.agents.deep.vpg.vpg import VPG # noqa
from genrl.agents.modelbased.base import ModelBasedAgent # noqa
from genrl.agents.modelbased.cem.cem import CEM # noqa

from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa
Empty file.
47 changes: 47 additions & 0 deletions genrl/agents/modelbased/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from abc import ABC

import torch

from genrl.agents import BaseAgent


class Planner:
def __init__(self, initial_state, dynamics_model=None):
if dynamics_model is not None:
self.dynamics_model = dynamics_model
self.initial_state = initial_state

def _learn_dynamics_model(self, state):
raise NotImplementedError

def plan(self):
raise NotImplementedError

def execute_actions(self):
raise NotImplementedError


class ModelBasedAgent(BaseAgent):
def __init__(self, *args, planner=None, **kwargs):
super(ModelBasedAgent, self).__init__(*args, **kwargs)
self.planner = planner

def plan(self):
"""
To be used to plan out a sequence of actions
"""
if self.planner is not None:
raise ValueError("Provide a planner to plan for the environment")
self.planner.plan()

def generate_data(self):
"""
To be used to generate synthetic data via a model (may be learnt or specified beforehand)
"""
raise NotImplementedError

def value_equivalence(self, state_space):
"""
To be used for approximate value estimation methods e.g. Value Iteration Networks
"""
raise NotImplementedError
Empty file.
235 changes: 235 additions & 0 deletions genrl/agents/modelbased/cem/cem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from typing import Any, Dict

import numpy as np
import torch
import torch.nn.functional as F

from genrl.agents import ModelBasedAgent
from genrl.core import RolloutBuffer
from genrl.utils import get_env_properties, get_model, safe_mean


class CEM(ModelBasedAgent):
"""Cross Entropy method algorithm (CEM)

Attributes:
network (str): The type of network to be used
env (Environment): The environment the agent is supposed to act on
create_model (bool): Whether the model of the algo should be created when initialised
policy_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network of the policy
lr_policy (float): learning rate of the policy
percentile (float): Top percentile of rewards to consider as elite
simulations_per_epoch (int): Number of simulations to perform before taking a gradient step
rollout_size (int): Capacity of the replay buffer
render (bool): Whether to render the environment or not
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""

def __init__(
self,
*args,
network: str = "mlp",
percentile: float = 70,
simulations_per_epoch: int = 1000,
rollout_size,
**kwargs
):
super(CEM, self).__init__(*args, **kwargs)
self.network = network
self.rollout_size = rollout_size
self.rollout = RolloutBuffer(self.rollout_size, self.env)
self.percentile = percentile
self.simulations_per_epoch = simulations_per_epoch

self._create_model()
self.empty_logs()

def _create_model(self):
"""Function to initialize the Policy

This will create the Policy net for the CEM agent
"""
self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
self.env, self.network
)
self.agent = get_model("p", self.network)(
self.state_dim,
self.action_dim,
self.policy_layers,
"V",
discrete,
action_lim,
).to(self.device)
self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy)

def plan(self):
"""Function to plan out one episode

Returns:
states (:obj:`list` of :obj:`torch.Tensor`): Batch of states the agent encountered in the episode
actions (:obj:`list` of :obj:`torch.Tensor`): Batch of actions the agent took in the episode
rewards (:obj:`torch.Tensor`): The episode reward obtained
"""
state = self.env.reset()
self.rollout.reset()
states, actions = self.collect_rollouts(state)
return (states, actions, self.rewards[-1])

def select_elites(self, states_batch, actions_batch, rewards_batch):
"""Function to select the elite states and elite actions based on the episode reward

Args:
states_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of states
actions_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of actions
rewards_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of rewards

Returns:
elite_states (:obj:`torch.Tensor`): Elite batch of states based on episode reward
elite_actions (:obj:`torch.Tensor`): Actions the agent took during the elite batch of states

"""
reward_threshold = np.percentile(rewards_batch, self.percentile)
elite_states = torch.cat(
[
s.unsqueeze(0).clone()
for i in range(len(states_batch))
if rewards_batch[i] >= reward_threshold
for s in states_batch[i]
],
dim=0,
)
elite_actions = torch.cat(
[
a.unsqueeze(0).clone()
for i in range(len(actions_batch))
if rewards_batch[i] >= reward_threshold
for a in actions_batch[i]
],
dim=0,
)

return elite_states, elite_actions

def select_action(self, state):
"""Select action given state

Action selection policy for the Cross Entropy agent

Args:
state (:obj:`torch.Tensor`): Current state of the agent

Returns:
action (:obj:`torch.Tensor`): Action taken by the agent
"""
state = torch.as_tensor(state).float()
action, dist = self.agent.get_action(state)
return action

def update_params(self):
"""Updates the the Policy network of the CEM agent

Function to update the policy network
"""
sess = [self.plan() for _ in range(self.simulations_per_epoch)]
batch_states, batch_actions, batch_rewards = zip(*sess)
elite_states, elite_actions = self.select_elites(
batch_states, batch_actions, batch_rewards
)
action_probs = self.agent.forward(elite_states.float().to(self.device))
loss = F.cross_entropy(
action_probs.view(-1, self.action_dim),
elite_actions.long().view(-1),
)
self.logs["crossentropy_loss"].append(loss.item())
loss.backward()
# torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
self.optim.step()

def get_traj_loss(self, values, dones):
# No need for this here
pass

def collect_rollouts(self, state: torch.Tensor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty similar to the OnPolicyAgent method. Shouldn't this return values and dones though? Not sure if this is a consequence of the algo.

"""Function to collect rollouts

Collects rollouts by playing the env like a human agent and inputs information into
the rollout buffer.

Args:
state (:obj:`torch.Tensor`): The starting state of the environment

Returns:
states (:obj:`list`): list of states the agent encountered during the episode
actions (:obj:`list`): list of actions the agent took in the corresponding states
"""
states = []
actions = []
for i in range(self.rollout_size):
action = self.select_action(state)
states.append(state)
actions.append(action)

next_state, reward, dones, _ = self.env.step(action)

if self.render:
self.env.render()

state = next_state

self.collect_rewards(dones, i)

if torch.any(dones.byte()):
break

return states, actions

def collect_rewards(self, dones: torch.Tensor, timestep: int):
"""Helper function to collect rewards

Runs through all the envs and collects rewards accumulated during rollouts

Args:
dones (:obj:`torch.Tensor`): Game over statuses of each environment
timestep (int): Timestep during rollout
"""
for i, done in enumerate(dones):
if done or timestep == self.rollout_size - 1:
self.rewards.append(self.env.episode_reward[i].detach().clone())

def get_hyperparams(self) -> Dict[str, Any]:
"""Get relevant hyperparameters to save

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
"lr_policy": self.lr_policy,
"rollout_size": self.rollout_size,
}
return hyperparams, self.agent.state_dict()

def _load_weights(self, weights) -> None:
self.agent.load_state_dict(weights)

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging

Returns:
logs (:obj:`dict`): Logging parameters for monitoring training
"""
logs = {
"crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
"mean_reward": safe_mean(self.rewards),
}

self.empty_logs()
return logs

def empty_logs(self):
"""Empties logs"""
self.logs = {}
self.logs["crossentropy_loss"] = []
self.rewards = []
1 change: 1 addition & 0 deletions tests/test_agents/test_modelbased/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tests.test_agents.test_modelbased.test_cem import TestCEM
23 changes: 23 additions & 0 deletions tests/test_agents/test_modelbased/test_cem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import shutil

from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


class TestCEM:
def test_CEM(self):
env = VectorEnv("CartPole-v0", 1)
algo = CEM(
"mlp",
env,
percentile=70,
policy_layers=[100],
rollout_size=100,
simulations_per_epoch=100,
)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")