Batch RL (#238)
Gal Leibovich authored Mar 19, 2019
1 parent 4a8451f commit e3c7e52
Showing 38 changed files with 1,003 additions and 87 deletions.
Empty file added __init__.py
3 changes: 2 additions & 1 deletion docs_raw/source/usage.rst
@@ -113,7 +113,8 @@ In Coach, this can be done in two steps -

.. code-block:: python
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=\"<experiment dir>/replay_buffer.p\"'
from rl_coach.core_types import PickledReplayBuffer
coach -p Doom_Basic_BC -cp='agent.load_memory_from_file_path=PickledReplayBuffer(\"<experiment dir>/replay_buffer.p\")'
Visualizations
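For reference, the same parameter can also be set directly inside a preset rather than on the command line. A minimal sketch, assuming an AgentParameters object named agent_params as presets usually define; the exact CsvDataset constructor arguments are an assumption:

    from rl_coach.core_types import PickledReplayBuffer, CsvDataset

    # load a replay buffer that Coach previously pickled to disk
    agent_params.memory.load_memory_from_file_path = PickledReplayBuffer("<experiment dir>/replay_buffer.p")

    # or load transitions recorded outside of Coach from a CSV dataset (hypothetical path)
    # agent_params.memory.load_memory_from_file_path = CsvDataset("<path to dataset>.csv")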
79 changes: 62 additions & 17 deletions rl_coach/agents/agent.py
@@ -34,6 +34,9 @@
from rl_coach.utils import Signal, force_list
from rl_coach.utils import dynamic_import_and_instantiate_module_from_params
from rl_coach.memories.backend.memory_impl import get_memory_backend
from rl_coach.core_types import TimeTypes
from rl_coach.off_policy_evaluators.ope_manager import OpeManager
from rl_coach.core_types import PickledReplayBuffer, CsvDataset


class Agent(AgentInterface):
@@ -49,10 +52,10 @@ def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManage
and self.ap.memory.shared_memory
if self.shared_memory:
self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
self.name = agent_parameters.name
self.parent = parent
self.parent_level_manager = None
self.full_name_id = agent_parameters.full_name_id = self.name
# TODO this needs to be sorted out. Why the duplicates for the agent's name?
self.full_name_id = agent_parameters.full_name_id = self.name = agent_parameters.name

if type(agent_parameters.task_parameters) == DistributedTaskParameters:
screen.log_title("Creating agent - name: {} task id: {} (may take up to 30 seconds due to "
@@ -84,9 +87,17 @@ def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManage
self.memory.set_memory_backend(self.memory_backend)

if agent_parameters.memory.load_memory_from_file_path:
screen.log_title("Loading replay buffer from pickle. Pickle path: {}"
.format(agent_parameters.memory.load_memory_from_file_path))
self.memory.load(agent_parameters.memory.load_memory_from_file_path)
if isinstance(agent_parameters.memory.load_memory_from_file_path, PickledReplayBuffer):
screen.log_title("Loading a pickled replay buffer. Pickled file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_pickled(agent_parameters.memory.load_memory_from_file_path.filepath)
elif isinstance(agent_parameters.memory.load_memory_from_file_path, CsvDataset):
screen.log_title("Loading a replay buffer from a CSV file. CSV file path: {}"
.format(agent_parameters.memory.load_memory_from_file_path.filepath))
self.memory.load_csv(agent_parameters.memory.load_memory_from_file_path)
else:
raise ValueError('Trying to load a replay buffer using an unsupported method - {}. '
.format(agent_parameters.memory.load_memory_from_file_path))

if self.shared_memory and self.is_chief:
self.shared_memory_scratchpad.add(self.memory_lookup_name, self.memory)
@@ -147,6 +158,7 @@ def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManage
self.total_steps_counter = 0
self.running_reward = None
self.training_iteration = 0
self.training_epoch = 0
self.last_target_network_update_step = 0
self.last_training_phase_step = 0
self.current_episode = self.ap.current_episode = 0
@@ -184,6 +196,7 @@ def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManage
self.discounted_return = self.register_signal('Discounted Return')
if isinstance(self.in_action_space, GoalsSpace):
self.distance_from_goal = self.register_signal('Distance From Goal', dump_one_value_per_step=True)

# use seed
if self.ap.task_parameters.seed is not None:
random.seed(self.ap.task_parameters.seed)
@@ -193,6 +206,9 @@ def __init__(self, agent_parameters: AgentParameters, parent: Union['LevelManage
random.seed()
np.random.seed()

# batch rl
self.ope_manager = OpeManager() if self.ap.is_batch_rl_training else None

@property
def parent(self) -> 'LevelManager':
"""
@@ -228,6 +244,7 @@ def setup_logger(self) -> None:
format(graph_name=self.parent_level_manager.parent_graph_manager.name,
level_name=self.parent_level_manager.name,
agent_full_id='.'.join(self.full_name_id.split('/')))
self.agent_logger.set_index_name(self.parent_level_manager.parent_graph_manager.time_metric.value.name)
self.agent_logger.set_logger_filenames(self.ap.task_parameters.experiment_path, logger_prefix=logger_prefix,
add_timestamp=True, task_id=self.task_id)
if self.ap.visualization.dump_in_episode_signals:
@@ -387,7 +404,8 @@ def reset_evaluation_state(self, val: RunPhase) -> None:
elif ending_evaluation:
# we write to the next episode, because the current episode might already have been written
# to disk, in which case it won't be written again
self.agent_logger.set_current_time(self.current_episode + 1)
self.agent_logger.set_current_time(self.get_current_time() + 1)

evaluation_reward = self.accumulated_rewards_across_evaluation_episodes / self.num_evaluation_episodes_completed
self.agent_logger.create_signal_value(
'Evaluation Reward', evaluation_reward)
@@ -471,8 +489,11 @@ def update_log(self) -> None:
:return: None
"""
# log all the signals to file
self.agent_logger.set_current_time(self.current_episode)
current_time = self.get_current_time()
self.agent_logger.set_current_time(current_time)
self.agent_logger.create_signal_value('Training Iter', self.training_iteration)
self.agent_logger.create_signal_value('Episode #', self.current_episode)
self.agent_logger.create_signal_value('Epoch', self.training_epoch)
self.agent_logger.create_signal_value('In Heatup', int(self._phase == RunPhase.HEATUP))
self.agent_logger.create_signal_value('ER #Transitions', self.call_memory('num_transitions'))
self.agent_logger.create_signal_value('ER #Episodes', self.call_memory('length'))
@@ -485,13 +506,17 @@ def update_log(self) -> None:
if self._phase == RunPhase.TRAIN else np.nan)

self.agent_logger.create_signal_value('Update Target Network', 0, overwrite=False)
self.agent_logger.update_wall_clock_time(self.current_episode)
self.agent_logger.update_wall_clock_time(current_time)

# The following signals are created with meaningful values only when an evaluation phase is completed.
# Creating with default NaNs for any HEATUP/TRAIN/TEST episode which is not the last in an evaluation phase
self.agent_logger.create_signal_value('Evaluation Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Shaped Evaluation Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Success Rate', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Inverse Propensity Score', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Direct Method Reward', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Doubly Robust', np.nan, overwrite=False)
self.agent_logger.create_signal_value('Sequential Doubly Robust', np.nan, overwrite=False)

for signal in self.episode_signals:
self.agent_logger.create_signal_value("{}/Mean".format(signal.name), signal.get_mean())
@@ -500,8 +525,7 @@ def update_log(self) -> None:
self.agent_logger.create_signal_value("{}/Min".format(signal.name), signal.get_min())

# dump
if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0 \
and self.current_episode > 0:
if self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0:
self.agent_logger.dump_output_csv()

def handle_episode_ended(self) -> None:
@@ -537,7 +561,8 @@ def handle_episode_ended(self) -> None:
self.total_reward_in_current_episode >= self.spaces.reward.reward_success_threshold:
self.num_successes_across_evaluation_episodes += 1

if self.ap.visualization.dump_csv:
if self.ap.visualization.dump_csv and \
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.EpisodeNumber:
self.update_log()

if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
@@ -651,18 +676,22 @@ def train(self) -> float:
"""
loss = 0
if self._should_train():
self.training_epoch += 1
for network in self.networks.values():
network.set_is_training(True)

for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
# TODO: this should be network dependent
network_parameters = list(self.ap.network_wrappers.values())[0]
# TODO: this should be network dependent
network_parameters = list(self.ap.network_wrappers.values())[0]

# we either go sequentially through the entire replay buffer in the batch RL mode,
# or sample randomly for the basic RL case.
training_schedule = self.call_memory('get_shuffled_data_generator', network_parameters.batch_size) if \
self.ap.is_batch_rl_training else [self.call_memory('sample', network_parameters.batch_size) for _ in
range(self.ap.algorithm.num_consecutive_training_steps)]

for batch in training_schedule:
# update counters
self.training_iteration += 1

# sample a batch and train on it
batch = self.call_memory('sample', network_parameters.batch_size)
if self.pre_network_filter is not None:
batch = self.pre_network_filter.filter(batch, update_internal_state=False, deep_copy=False)

@@ -673,6 +702,7 @@
batch = Batch(batch)
total_loss, losses, unclipped_grads = self.learn_from_batch(batch)
loss += total_loss

self.unclipped_grads.add_sample(unclipped_grads)

# TODO: the learning rate decay should be done through the network instead of here
@@ -697,6 +727,12 @@
if self.imitation:
self.log_to_screen()

if self.ap.visualization.dump_csv and \
self.parent_level_manager.parent_graph_manager.time_metric == TimeTypes.Epoch:
# in batch RL or imitation learning the agent never acts, so we have to get the stats out here.
# we dump the data out every epoch
self.update_log()

for network in self.networks.values():
network.set_is_training(False)

@@ -1034,3 +1070,12 @@ def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
for network in self.networks.values():
savers.update(network.collect_savers(parent_path_suffix))
return savers

def get_current_time(self):
return {
TimeTypes.EpisodeNumber: self.current_episode,
TimeTypes.TrainingIteration: self.training_iteration,
TimeTypes.EnvironmentSteps: self.total_steps_counter,
TimeTypes.WallClockTime: self.agent_logger.get_current_wall_clock_time(),
TimeTypes.Epoch: self.training_epoch}[self.parent_level_manager.parent_graph_manager.time_metric]
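The train() changes above replace per-step random sampling with a full shuffled pass over the replay buffer whenever is_batch_rl_training is set, counting each pass as one training epoch. A minimal sketch of the two batching strategies described by the comment in the diff; illustrative only, not the actual get_shuffled_data_generator implementation in the memory classes:

    import random

    def shuffled_epoch_batches(transitions, batch_size):
        # batch RL: one full pass over the dataset per epoch, in shuffled order
        indices = list(range(len(transitions)))
        random.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            yield [transitions[i] for i in indices[start:start + batch_size]]

    def random_batches(transitions, batch_size, num_batches):
        # online RL: each batch is drawn uniformly at random from the buffer
        for _ in range(num_batches):
            yield random.sample(transitions, batch_size)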
9 changes: 9 additions & 0 deletions rl_coach/agents/agent_interface.py
@@ -173,3 +173,12 @@ def handle_episode_ended(self) -> None:
:return: None
"""
raise NotImplementedError("")

def run_off_policy_evaluation(self) -> None:
"""
Run off-policy evaluation estimators to evaluate the trained policy's performance against a dataset.
Should only be implemented for off-policy RL algorithms.
:return: None
"""
raise NotImplementedError("")
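run_off_policy_evaluation is the hook behind the new 'Inverse Propensity Score', 'Direct Method Reward', 'Doubly Robust' and 'Sequential Doubly Robust' signals logged in agent.py. As an illustration of what such estimators compute, here is a standard inverse propensity scoring (IPS) estimate; this is not Coach's OpeManager API, and the function name and signature are assumptions:

    import numpy as np

    def ips_estimate(rewards, behavior_probs, target_probs):
        # re-weight logged rewards by the importance ratio pi_target / pi_behavior
        # to estimate the average reward the evaluated policy would have collected
        weights = np.asarray(target_probs) / np.asarray(behavior_probs)
        return float(np.mean(weights * np.asarray(rewards)))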
3 changes: 3 additions & 0 deletions rl_coach/agents/bootstrapped_dqn_agent.py
@@ -64,6 +64,9 @@ def learn_from_batch(self, batch):
q_st_plus_1 = result[:self.ap.exploration.architecture_num_q_heads]
TD_targets = result[self.ap.exploration.architecture_num_q_heads:]

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# initialize with the current prediction so that we will
# only update the action that we have actually done in this transition
for i in range(self.ap.network_wrappers['main'].batch_size):
3 changes: 3 additions & 0 deletions rl_coach/agents/categorical_dqn_agent.py
@@ -100,6 +100,9 @@ def learn_from_batch(self, batch):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

# select the optimal actions for the next state
target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
1 change: 1 addition & 0 deletions rl_coach/agents/composite_agent.py
@@ -432,3 +432,4 @@ def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
savers.update(agent.collect_savers(
parent_path_suffix="{}.{}".format(parent_path_suffix, self.name)))
return savers

3 changes: 3 additions & 0 deletions rl_coach/agents/ddqn_agent.py
@@ -50,6 +50,9 @@ def learn_from_batch(self, batch):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# initialize with the current prediction so that we will
# only update the action that we have actually done in this transition
TD_errors = []
3 changes: 3 additions & 0 deletions rl_coach/agents/dqn_agent.py
@@ -81,6 +81,9 @@ def learn_from_batch(self, batch):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(TD_targets)

# only update the action that we have actually done in this transition
TD_errors = []
for i in range(self.ap.network_wrappers['main'].batch_size):
3 changes: 3 additions & 0 deletions rl_coach/agents/n_step_q_agent.py
@@ -123,6 +123,9 @@ def learn_from_batch(self, batch):
else:
assert False, 'The available values for targets_horizon are: 1-Step, N-Step'

# add Q value samples for logging
self.q_values.add_sample(state_value_head_targets)

# train
result = self.networks['main'].online_network.accumulate_gradients(batch.states(network_keys), [state_value_head_targets])

3 changes: 3 additions & 0 deletions rl_coach/agents/qr_dqn_agent.py
@@ -88,6 +88,9 @@ def learn_from_batch(self, batch):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.get_q_values(current_quantiles))

# get the optimal actions to take for the next states
target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)

3 changes: 3 additions & 0 deletions rl_coach/agents/rainbow_dqn_agent.py
@@ -95,6 +95,9 @@ def learn_from_batch(self, batch):
(self.networks['main'].online_network, batch.states(network_keys))
])

# add Q value samples for logging
self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))

# only update the action that we have actually done in this transition (using the Double-DQN selected actions)
target_actions = ddqn_selected_actions
m = np.zeros((self.ap.network_wrappers['main'].batch_size, self.z_values.size))
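The one-line addition repeated across the value-based agents above feeds the freshly computed targets into the agent's q_values signal, so that update_log can report per-episode or per-epoch Q-value statistics. A rough sketch of what such a signal accumulates; illustrative only, not the Signal class from rl_coach.utils:

    import numpy as np

    class QValuesSignal:
        # collects raw Q-value samples and summarizes them when the log is dumped
        def __init__(self, name):
            self.name = name
            self._samples = []

        def add_sample(self, sample):
            self._samples.append(np.asarray(sample).ravel())

        def get_mean(self):
            return float(np.mean(np.concatenate(self._samples))) if self._samples else float('nan')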
