New A2C example with entropy #26

Open
wants to merge 10 commits into base: master
2 changes: 1 addition & 1 deletion cherry/td.py
@@ -52,7 +52,7 @@ def discount(gamma, rewards, dones, bootstrap=0.0):

     msg = 'dones and rewards must have equal length.'
     assert rewards.size(0) == dones.size(0), msg
-    R = th.zeros_like(rewards[0]) + bootstrap
+    R = th.zeros_like(rewards) + bootstrap
     discounted = th.zeros_like(rewards)
     length = discounted.size(0)
     for t in reversed(range(length)):
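The one-line change above only alters how the running return `R` is initialized before the backward pass over the rollout. For orientation, here is a minimal stand-alone sketch of what a bootstrapped discount computes for vectorized rollouts; this is illustrative only, since the rest of cherry's `discount` body is collapsed in this diff, and the exact recurrence, shapes, and the helper name `discounted_returns` are assumptions:

```python
import torch as th

def discounted_returns(gamma, rewards, dones, bootstrap=0.0):
    # Hypothetical reference: rewards and dones are (T, num_envs) tensors,
    # bootstrap is the value estimate of the state following the last step.
    R = bootstrap * th.ones_like(rewards[0])   # running return, one entry per environment
    returns = th.zeros_like(rewards)
    for t in reversed(range(rewards.size(0))):
        # Reset the running return wherever an episode terminated at step t.
        R = rewards[t] + gamma * R * (1.0 - dones[t].float())
        returns[t] = R
    return returns
```

With the default `bootstrap=0.0` this reduces to plain discounted returns, which is how the previous example called `discount`; the new A2C example instead passes the critic's estimate of the state following the rollout.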
223 changes: 121 additions & 102 deletions examples/actor_critic_cartpole.py
@@ -1,112 +1,131 @@
#!/usr/bin/env python3

"""
Simple example of using cherry to solve cartpole with an actor-critic.

The code is an adaptation of the PyTorch reinforcement learning example.
"""

import random
import torch
import cherry
import gym
import numpy as np

from itertools import count
import statistics

NUM_ENVS = 6
STEPS = 5
TRAIN_STEPS = int(1e4)

class A2C(torch.nn.Module):
    def __init__(self, num_envs):
        super(A2C, self).__init__()

        self.num_envs = num_envs
        self.gamma = 0.99
        self.vf_coef = 0.25
        self.ent_coef = 0.01
        self.max_clip_norm = 0.5

    def select_action(self, state):
        probs, value = self(state)
        mass = torch.distributions.Categorical(probs)
        action = mass.sample()
        # Return the selected action, its log-probability, the value estimate and the categorical entropy
        return action, {"log_prob": mass.log_prob(action), "value": value, "entropy": mass.entropy()}


    def learn_step(self, replay, optimizer):
        policy_loss = []
        value_loss = []
        entropy_loss = []

        # Discount rewards and bootstrap them with the value estimate of the next state
        last_action, last_value = self(replay.next_state()[-1, :, :])
        # Bootstrap from zero if it is a terminal state
        last_value = (last_value[:, 0] * (1 - replay.done()[-1]))

        rewards = cherry.td.discount(self.gamma, replay.reward(), replay.done(), last_value)
        for sars, reward in zip(replay, rewards):
            log_prob = sars.log_prob.view(self.num_envs, -1)
            value = sars.value.view(self.num_envs, -1)
            entropy = sars.entropy.view(self.num_envs, -1)
            reward = reward.view(self.num_envs, -1)

            # Compute advantage
            advantage = reward - value

            # Compute policy gradient loss
            # (advantage.detach() so that no gradient flows back through the advantage path)
            policy_loss.append(-log_prob * advantage.detach())
            # Compute value estimation loss
            value_loss.append((reward - value)**2)
            # Compute entropy loss
            entropy_loss.append(entropy)

        # Compute means over the accumulated losses
        value_loss = torch.stack(value_loss).mean()
        policy_loss = torch.stack(policy_loss).mean()
        entropy_loss = torch.stack(entropy_loss).mean()

        # Take an optimization step
        optimizer.zero_grad()
        loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * entropy_loss
        loss.backward()
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_clip_norm)
        optimizer.step()
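For reference, a stand-alone sketch of the entropy-regularized objective that `learn_step` assembles, using dummy tensors and the coefficient defaults from the class above; purely illustrative, not part of the PR:

```python
import torch

num_envs = 6
log_prob = torch.randn(num_envs, 1)   # log pi(a_t | s_t), cached by select_action
value = torch.randn(num_envs, 1)      # V(s_t) from the value head
ret = torch.randn(num_envs, 1)        # discounted, bootstrapped return
entropy = torch.rand(num_envs, 1)     # H[pi(. | s_t)]

advantage = ret - value
policy_loss = (-log_prob * advantage.detach()).mean()  # actor term
value_loss = advantage.pow(2).mean()                   # critic term
entropy_loss = entropy.mean()                          # exploration bonus

# The entropy term is subtracted, so minimizing the loss pushes the policy
# towards higher entropy and delays premature convergence.
loss = policy_loss + 0.25 * value_loss - 0.01 * entropy_loss
```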




class A2CPolicy(A2C):
    def __init__(self, state_size, action_size, num_envs):
        super(A2CPolicy, self).__init__(num_envs)
        self.state_size = state_size
        self.action_size = action_size
        self.n_hidden = 128

        # Backbone net
        self.net = torch.nn.Sequential(
            torch.nn.Linear(self.state_size, self.n_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(self.n_hidden, self.n_hidden),
            torch.nn.LeakyReLU(),
        )

        # Action head (policy gradient)
        self.action_head = torch.nn.Sequential(
            torch.nn.Linear(self.n_hidden, self.action_size),
            torch.nn.Softmax(dim=1)
        )

        # Value estimation head (A2C)
        self.value_head = torch.nn.Sequential(
            torch.nn.Linear(self.n_hidden, 1),
        )

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import cherry.envs as envs
from cherry.td import discount
from cherry import normalize
import cherry.distributions as distributions

SEED = 567
GAMMA = 0.99
RENDER = False
V_WEIGHT = 0.5

random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)


class ActorCriticNet(nn.Module):
    def __init__(self, env):
        super(ActorCriticNet, self).__init__()
        self.affine1 = nn.Linear(env.state_size, 128)
        self.action_head = nn.Linear(128, env.action_size)
        self.value_head = nn.Linear(128, 1)
        self.distribution = distributions.ActionDistribution(env,
                                                             use_probs=True)

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.action_head(x)
        action_mass = self.distribution(F.softmax(action_scores, dim=1))
        value = self.value_head(x)
        return action_mass, value


def update(replay, optimizer):
    policy_loss = []
    value_loss = []

    # Discount and normalize rewards
    rewards = discount(GAMMA, replay.reward(), replay.done())
    rewards = normalize(rewards)

    # Compute losses
    for sars, reward in zip(replay, rewards):
        log_prob = sars.log_prob
        value = sars.value
        policy_loss.append(-log_prob * (reward - value.item()))
        value_loss.append(F.mse_loss(value, reward.detach()))

    # Take optimization step
    optimizer.zero_grad()
    loss = th.stack(policy_loss).sum() + V_WEIGHT * th.stack(value_loss).sum()
    loss.backward()
    optimizer.step()


def get_action_value(state, policy):
    mass, value = policy(state)
    action = mass.sample()
    info = {
        'log_prob': mass.log_prob(action),  # Cache log_prob for later
        'value': value
    }
    return action, info


        # Return both the action probabilities and the value estimations
        return self.action_head(self.net(x)), self.value_head(self.net(x))

if __name__ == '__main__':
    env = gym.vector.make('CartPole-v0', num_envs=1)
    env = envs.Logger(env, interval=1000)
    env = envs.Torch(env)
    env = envs.Runner(env)
    env.seed(SEED)

    policy = ActorCriticNet(env)
    optimizer = optim.Adam(policy.parameters(), lr=1e-2)
    running_reward = 10.0
    get_action = lambda state: get_action_value(state, policy)

    for episode in count(1):
        # We use the Runner collector, but could've written our own
        replay = env.run(get_action, episodes=1)

        # Update policy
        update(replay, optimizer)

        # Compute termination criterion
        running_reward = running_reward * 0.99 + len(replay) * 0.01
        if episode % 10 == 0:
            # Should start with 10.41, 12.21, 14.60, then 100:71.30, 200:135.74
            print(episode, running_reward)
        if running_reward > 190.0:
            print('Solved! Running reward now {} and '
                  'the last episode runs to {} time steps!'.format(running_reward,
                                                                   len(replay)))
            break
    env = gym.vector.make('CartPole-v0', num_envs=NUM_ENVS)
    env = cherry.envs.Logger(env, interval=1000)
    env = cherry.envs.Torch(env)

    policy = A2CPolicy(env.state_size, env.action_size, NUM_ENVS)
    optimizer = torch.optim.RMSprop(policy.parameters(), lr=7e-4, eps=1e-5, alpha=0.99)

    state = env.reset()
    for train_step in range(0, TRAIN_STEPS):
        replay = cherry.ExperienceReplay()
        for step in range(0, STEPS):
            action, info = policy.select_action(state)
            new_state, reward, done, _ = env.step(action)
            replay.append(state, action, reward, new_state, done, **info)
            state = new_state

        policy.learn_step(replay, optimizer)

    env = gym.make('CartPole-v0')
    env = cherry.envs.Torch(env)
    env = cherry.envs.Runner(env)
    while True:
        env.run(lambda state: policy.select_action(state), episodes=1, render=True)
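One possible variation for the rendering loop above, sketched under the assumption that the `Runner` wrapper accepts any callable returning an `(action, info)` pair as the examples do: act greedily on the action probabilities instead of sampling, which makes the rendered episodes more repeatable. This is illustrative only and not part of the PR.

```python
def greedy_action(state):
    # Pick the most probable action instead of sampling (illustrative only).
    probs, value = policy(state)
    action = probs.argmax(dim=1)
    return action, {"value": value}

# Hypothetical usage with the same Runner wrapper as above:
# env.run(greedy_action, episodes=1, render=True)
```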