Stateless schedules #345

Open · wants to merge 18 commits into base: dev

3 changes: 0 additions & 3 deletions hive/agents/ddpg.py
@@ -3,7 +3,6 @@
from hive.agents.qnets.utils import InitializationFn
from hive.agents.td3 import TD3
from hive.replays import BaseReplayBuffer
from hive.utils.loggers import Logger
from hive.utils.utils import LossFn, OptimizerFn


@@ -33,7 +32,6 @@ def __init__(
reward_clip: float = None,
soft_update_fraction: float = 0.005,
batch_size: int = 64,
logger: Logger = None,
log_frequency: int = 100,
update_frequency: int = 1,
action_noise: float = 0,
@@ -112,7 +110,6 @@ def __init__(
reward_clip=reward_clip,
soft_update_fraction=soft_update_fraction,
batch_size=batch_size,
logger=logger,
log_frequency=log_frequency,
update_frequency=update_frequency,
policy_update_frequency=1,
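
Across all three agents in this diff, the per-agent `logger: Logger = None` constructor argument is removed and replaced by the module-level `logger` imported from `hive.utils.loggers` (see the dqn.py and drqn.py changes below). The following is a minimal sketch of what such a shared logger could look like, assuming only the `log_scalar(name, value, prefix)` call that appears in this diff; the class name and the `register_backend` helper are hypothetical, and the real RLHive implementation may differ.

    # Sketch of a module-level logger singleton (hypothetical implementation).
    # Only log_scalar(name, value, prefix) is taken from the calls in this diff.
    class _GlobalLogger:
        def __init__(self):
            self._backends = []

        def register_backend(self, backend):
            # Attach a concrete writer, e.g. a TensorBoard or W&B wrapper.
            self._backends.append(backend)

        def log_scalar(self, name, value, prefix):
            # Fan the scalar out to every registered backend, namespaced by
            # the caller-supplied prefix (the agent id in this diff).
            for backend in self._backends:
                backend.log_scalar(f"{prefix}/{name}", value)

    # Shared instance that agents import instead of receiving a Logger argument.
    logger = _GlobalLogger()

With a shared instance like this, calls such as `logger.log_scalar("epsilon", epsilon, self.id)` work without a logger being threaded through every constructor, which is what makes the `logger=logger` plumbing above removable.
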
87 changes: 48 additions & 39 deletions hive/agents/dqn.py
@@ -1,9 +1,11 @@
import copy
import os
from collections import deque

import gymnasium as gym
import numpy as np
import torch
from gymnasium.vector.utils.numpy_utils import create_empty_array

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
@@ -13,8 +15,9 @@
calculate_output_dim,
create_init_weights_fn,
)
from hive.agents.utils import roll_state
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.loggers import logger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
@@ -53,7 +56,6 @@ def __init__(
min_replay_history: int = 5000,
batch_size: int = 32,
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
):
"""
@@ -108,6 +110,7 @@ def __init__(
super().__init__(
observation_space=observation_space, action_space=action_space, id=id
)
self._stack_size = stack_size
self._state_size = (
stack_size * self._observation_space.shape[0],
*self._observation_space.shape[1:],
@@ -124,6 +127,11 @@ def __init__(
self._replay_buffer = replay_buffer(
observation_shape=self._observation_space.shape,
observation_dtype=self._observation_space.dtype,
action_shape=self._action_space.shape,
action_dtype=self._action_space.dtype,
stack_size=stack_size,
gamma=discount_rate,
n_step=n_step,
)
self._discount_rate = discount_rate**n_step
self._grad_clip = grad_clip
@@ -134,13 +142,7 @@ def __init__(
loss_fn = torch.nn.SmoothL1Loss
self._loss_fn = loss_fn(reduction="none")
self._batch_size = batch_size
self._logger = logger
if self._logger is None:
self._logger = NullLogger([])
self._timescale = self.id
self._logger.register_timescale(
self._timescale, PeriodicSchedule(False, True, log_frequency)
)
self._log_schedule = PeriodicSchedule(False, True, log_frequency)
if update_period_schedule is None:
self._update_period_schedule = PeriodicSchedule(False, True, 1)
else:
@@ -189,6 +191,21 @@ def eval(self):
self._qnet.eval()
self._target_qnet.eval()

def preprocess_observation(self, observation, agent_traj_state):
if agent_traj_state is None:
observation_stack = create_empty_array(
self._observation_space, n=self._stack_size
)
else:
observation_stack = agent_traj_state["observation_stack"]
observation_stack = roll_state(observation, observation_stack)
state = (
torch.tensor(observation_stack, device=self._device, dtype=torch.float32)
.flatten(0, 1)
.unsqueeze(0)
)
return state, observation_stack

def preprocess_update_info(self, update_info):
"""Preprocesses the :obj:`update_info` before it goes into the replay buffer.
Clips the reward in update_info.
@@ -208,9 +225,8 @@ def preprocess_update_info(self, update_info):
"reward": update_info["reward"],
"terminated": update_info["terminated"],
"truncated": update_info["truncated"],
"source": update_info["source"],
}
if "agent_id" in update_info:
preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

return preprocessed_update_info

@@ -231,37 +247,35 @@ def preprocess_update_batch(self, batch):
return (batch["observation"],), (batch["next_observation"],), batch

@torch.no_grad()
def act(self, observation, agent_traj_state=None):
def act(self, observation, agent_traj_state, global_step):
"""Returns the action for the agent. If in training mode, follows an epsilon
greedy policy. Otherwise, returns the action with the highest Q-value.

Args:
observation: The current observation.
agent_traj_state: Contains necessary state information for the agent
to process current trajectory. This should be updated and returned.
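global_step: The global step count, used to evaluate the agent's schedules.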

Returns:
- action
- agent trajectory state
"""

# Determine and log the value of epsilon
if self._training:
if not self._learn_schedule.get_value():
if not self._learn_schedule(global_step):
epsilon = 1.0
else:
epsilon = self._epsilon_schedule.update()
if self._logger.update_step(self._timescale):
self._logger.log_scalar("epsilon", epsilon, self._timescale)
epsilon = self._epsilon_schedule(global_step)
if self._log_schedule(global_step):
logger.log_scalar("epsilon", epsilon, self.id)
else:
epsilon = self._test_epsilon

state, observation_stack = self.preprocess_observation(
observation, agent_traj_state
)

# Sample action. With epsilon probability choose random action,
# otherwise select the action with the highest q-value.
observation = torch.tensor(
np.expand_dims(observation, axis=0), device=self._device
).float()
qvals = self._qnet(observation)
qvals = self._qnet(state)
if self._rng.random() < epsilon:
action = self._rng.integers(self._action_space.n)
else:
@@ -270,14 +284,14 @@ def act(self, observation, agent_traj_state=None):

if (
self._training
and self._logger.should_log(self._timescale)
and self._log_schedule(global_step)
and agent_traj_state is None
):
self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
agent_traj_state = {}
logger.log_scalar("train_qval", torch.max(qvals), self.id)
agent_traj_state = {"observation_stack": observation_stack}
return action, agent_traj_state

def update(self, update_info, agent_traj_state=None):
def update(self, update_info, agent_traj_state, global_step):
"""
Updates the DQN agent.

@@ -297,15 +311,16 @@ def update(self, update_info, agent_traj_state=None):
return

# Add the most recent transition to the replay buffer.
self._replay_buffer.add(**self.preprocess_update_info(update_info))
transition = self.preprocess_update_info(update_info)
self._replay_buffer.add(**transition)

# Update the q network based on a sample batch from the replay buffer.
# If the replay buffer doesn't have enough samples yet, skip the update.
if (
self._learn_schedule.update()
self._learn_schedule(global_step)
and self._replay_buffer.size() > 0
and self._update_period_schedule.update()
and self._update_period_schedule(global_step)
):
batch = self._replay_buffer.sample(batch_size=self._batch_size)
(
@@ -330,8 +345,8 @@ def update(self, update_info, agent_traj_state=None):

loss = self._loss_fn(pred_qvals, q_targets).mean()

if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)
if self._log_schedule(global_step):
logger.log_scalar("train_loss", loss, self.id)

loss.backward()
if self._grad_clip is not None:
@@ -341,7 +356,7 @@ def update(self, update_info, agent_traj_state=None):
self._optimizer.step()

# Update target network
if self._target_net_update_schedule.update():
if self._target_net_update_schedule(global_step):
self._update_target()
return agent_traj_state

@@ -368,9 +383,6 @@ def save(self, dname):
"qnet": self._qnet.state_dict(),
"target_qnet": self._target_qnet.state_dict(),
"optimizer": self._optimizer.state_dict(),
"learn_schedule": self._learn_schedule,
"epsilon_schedule": self._epsilon_schedule,
"target_net_update_schedule": self._target_net_update_schedule,
"rng": self._rng,
},
os.path.join(dname, "agent.pt"),
@@ -384,8 +396,5 @@ def load(self, dname):
self._qnet.load_state_dict(checkpoint["qnet"])
self._target_qnet.load_state_dict(checkpoint["target_qnet"])
self._optimizer.load_state_dict(checkpoint["optimizer"])
self._learn_schedule = checkpoint["learn_schedule"]
self._epsilon_schedule = checkpoint["epsilon_schedule"]
self._target_net_update_schedule = checkpoint["target_net_update_schedule"]
self._rng = checkpoint["rng"]
self._replay_buffer.load(os.path.join(dname, "replay"))
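
The schedules above are no longer advanced with `.update()` and `.get_value()`; each one is invoked with the runner-supplied `global_step`, so its value is a pure function of that counter. That is also why the `learn_schedule`, `epsilon_schedule`, and `target_net_update_schedule` entries disappear from `save()` and `load()`: there is no internal schedule state left to checkpoint. A rough sketch of stateless `PeriodicSchedule` and `LinearSchedule` classes consistent with the calls in this diff is shown below; the exact constructor signatures and off-by-one conventions in the actual PR may differ.

    # Sketch of stateless schedules evaluated as pure functions of global_step.
    # PeriodicSchedule(off_value, on_value, period) mirrors the usage above;
    # LinearSchedule(init_value, end_value, steps) is assumed from RLHive's
    # existing convention and may not match the PR exactly.
    class PeriodicSchedule:
        def __init__(self, off_value, on_value, period):
            self._off_value = off_value
            self._on_value = on_value
            self._period = max(int(period), 1)

        def __call__(self, global_step):
            # Return on_value once every `period` steps, off_value otherwise.
            return self._on_value if global_step % self._period == 0 else self._off_value

    class LinearSchedule:
        def __init__(self, init_value, end_value, steps):
            self._init_value = init_value
            self._end_value = end_value
            self._steps = max(int(steps), 1)

        def __call__(self, global_step):
            # Linearly interpolate from init_value to end_value over `steps`
            # steps, clamping at the endpoints.
            frac = min(max(global_step / self._steps, 0.0), 1.0)
            return self._init_value + frac * (self._end_value - self._init_value)

Under these assumptions, `self._log_schedule(global_step)` in the agents above evaluates to `True` once every `log_frequency` steps without the schedule holding any mutable state.
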
64 changes: 35 additions & 29 deletions hive/agents/drqn.py
@@ -14,7 +14,7 @@
apply_to_tensor,
)
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.loggers import logger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
@@ -54,7 +54,6 @@ def __init__(
min_replay_history: int = 5000,
batch_size: int = 32,
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
store_hidden: bool = True,
burn_frames: int = 0,
@@ -157,13 +156,7 @@ def __init__(
loss_fn = torch.nn.SmoothL1Loss
self._loss_fn = loss_fn(reduction="none")
self._batch_size = batch_size
self._logger = logger
if self._logger is None:
self._logger = NullLogger([])
self._timescale = self.id
self._logger.register_timescale(
self._timescale, PeriodicSchedule(False, True, log_frequency)
)
self._log_schedule = PeriodicSchedule(False, True, log_frequency)
if update_period_schedule is None:
self._update_period_schedule = PeriodicSchedule(False, True, 1)
else:
@@ -207,6 +200,18 @@ def create_q_networks(self, representation_net, sequence_fn):
self._qnet.apply(self._init_fn)
self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)

def preprocess_observation(self, observation, agent_traj_state):
# Reset the hidden state at the beginning of an episode.
if agent_traj_state is None:
hidden_state = self._qnet.init_hidden(batch_size=1)
else:
hidden_state = agent_traj_state["hidden_state"]

state = torch.tensor(
np.expand_dims(observation, axis=(0, 1)), device=self._device
).float()
return state, hidden_state

def preprocess_update_info(self, update_info, hidden_state):
"""Preprocesses the :obj:`update_info` before it goes into the replay buffer.
Clips the reward in update_info.
@@ -265,7 +270,7 @@ def preprocess_update_batch(self, batch):
return (batch["observation"]), (batch["next_observation"]), batch

@torch.no_grad()
def act(self, observation, agent_traj_state=None):
def act(self, observation, agent_traj_state, global_step):
"""Returns the action for the agent. If in training mode, follows an epsilon
greedy policy. Otherwise, returns the action with the highest Q-value.

@@ -280,37 +285,38 @@ def act(self, observation, agent_traj_state=None):

# Determine and log the value of epsilon
if self._training:
if not self._learn_schedule.get_value():
if not self._learn_schedule(global_step):
epsilon = 1.0
else:
epsilon = self._epsilon_schedule.update()
if self._logger.update_step(self._timescale):
self._logger.log_scalar("epsilon", epsilon, self._timescale)
epsilon = self._epsilon_schedule(global_step)
if self._log_schedule(global_step):
logger.log_scalar("epsilon", epsilon, self.id)
else:
epsilon = self._test_epsilon

# Sample action. With epsilon probability choose random action,
# otherwise select the action with the highest q-value.
# Insert batch_size and sequence_len dimensions to observation
observation = torch.tensor(
np.expand_dims(observation, axis=(0, 1)), device=self._device
).float()
hidden_state = (
None if agent_traj_state is None else agent_traj_state["hidden_state"]
)
qvals, hidden_state = self._qnet(observation, hidden_state)
state, hidden_state = self.preprocess_observation(observation, agent_traj_state)
qvals, hidden_state = self._qnet(state, hidden_state)
if self._rng.random() < epsilon:
action = self._rng.integers(self._action_space.n)
else:
# Note: not explicitly handling the ties
action = torch.argmax(qvals).item()
if agent_traj_state is None:
if self._training and self._logger.should_log(self._timescale):
self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
if self._training and self._log_schedule(global_step):
logger.log_scalar("train_qval", torch.max(qvals), self.id)

if (
self._training
and self._log_schedule(global_step)
and agent_traj_state is None
):
logger.log_scalar("train_qval", torch.max(qvals), self.id)
return action, {"hidden_state": hidden_state}

def update(self, update_info, agent_traj_state=None):
def update(self, update_info, agent_traj_state, global_step):
"""
Updates the DRQN agent.

@@ -339,9 +345,9 @@ def update(self, update_info, agent_traj_state=None):
# If the replay buffer doesn't have enough samples yet, skip the update.
if (
self._learn_schedule.update()
self._learn_schedule(global_step)
and self._replay_buffer.size() > 0
and self._update_period_schedule.update()
and self._update_period_schedule(global_step)
):
batch = self._replay_buffer.sample(batch_size=self._batch_size)
(
@@ -385,8 +391,8 @@ def update(self, update_info, agent_traj_state=None):
interm_loss *= batch["mask"]
loss = interm_loss.sum() / batch["mask"].sum()

if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)
if self._log_schedule(global_step):
logger.log_scalar("train_loss", loss, self.id)

loss.backward()
if self._grad_clip is not None:
Expand All @@ -396,6 +402,6 @@ def update(self, update_info, agent_traj_state=None):
self._optimizer.step()

# Update target network
if self._target_net_update_schedule.update():
if self._target_net_update_schedule(global_step):
self._update_target()
return agent_traj_state
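
Since `act` and `update` now take `global_step` in both agents, only the runner needs to track how far training has progressed. The function below is a hypothetical sketch of how a runner could drive the new interface; the environment handling, the function itself, and the exact contents of `update_info` (including the value of the `"source"` field introduced in `preprocess_update_info` above) are assumptions for illustration, not code from this PR.

    # Hypothetical runner loop for the new stateless agent API. Only the
    # act/update signatures come from this diff; everything else is an
    # assumption made for the sake of the sketch.
    def run_training(agent, env, max_steps):
        observation, _ = env.reset()
        agent_traj_state = None
        for global_step in range(max_steps):
            # The agent receives the global step instead of keeping its own counters.
            action, agent_traj_state = agent.act(observation, agent_traj_state, global_step)
            next_observation, reward, terminated, truncated, _ = env.step(action)
            update_info = {
                "observation": observation,
                "action": action,
                "reward": reward,
                "terminated": terminated,
                "truncated": truncated,
                "source": "runner",  # the exact value of "source" is assumed here
            }
            agent_traj_state = agent.update(update_info, agent_traj_state, global_step)
            if terminated or truncated:
                observation, _ = env.reset()
                agent_traj_state = None
            else:
                observation = next_observation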