Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Commit

Permalink
Integrate coach.py params with distributed Coach. (#42)
Browse files Browse the repository at this point in the history
* Integrate coach.py params with distributed Coach.
* Minor improvements
- Use enums instead of constants.
- Reduce code duplication.
- Ask experiment name with timeout.
  • Loading branch information
balajismaniam authored Nov 5, 2018
1 parent 95b4fc6 commit 7e70063
Show file tree
Hide file tree
Showing 13 changed files with 263 additions and 285 deletions.
7 changes: 7 additions & 0 deletions dist-coach-config.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[coach]
image = <insert-image-name>
memory_backend = redispubsub
data_store = s3
s3_end_point = s3.amazonaws.com
s3_bucket_name = <insert-s3-bucket-name>
s3_creds_file = <insert-path-for-s3-creds-file>
15 changes: 15 additions & 0 deletions rl_coach/base_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ class EmbeddingMergerType(Enum):
#Multiply = 3


# DistributedCoachSynchronizationType provides the synchronization type for distributed Coach.
# The default value is None, which means the algorithm or preset cannot be used with distributed Coach.
class DistributedCoachSynchronizationType(Enum):
    # In SYNC mode, the trainer waits for all the experiences to be gathered from distributed rollout workers before
    # training a new policy, and the rollout workers wait for a new policy before gathering experiences.
    SYNC = "sync"

    # In ASYNC mode, the trainer doesn't wait for any set of experiences to be gathered from distributed rollout workers,
    # and the rollout workers continuously gather experiences, loading new policies whenever they become available.
    ASYNC = "async"


def iterable_to_items(obj):
if isinstance(obj, dict) or isinstance(obj, OrderedDict) or isinstance(obj, types.MappingProxyType):
items = obj.items()
Expand Down Expand Up @@ -154,6 +166,9 @@ def __init__(self):
# intrinsic reward
self.scale_external_reward_by_intrinsic_reward_value = False

# Distributed Coach params
self.distributed_coach_synchronization_type = None


class PresetValidationParameters(Parameters):
def __init__(self):
Expand Down
176 changes: 175 additions & 1 deletion rl_coach/coach.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
sys.path.append('.')

import copy
from configparser import ConfigParser, Error
from rl_coach.core_types import EnvironmentSteps
import os
from rl_coach import logger
Expand All @@ -26,6 +27,7 @@
import atexit
import time
import sys
import json
from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters
from multiprocessing import Process
from multiprocessing.managers import BaseManager
Expand All @@ -35,6 +37,14 @@
from rl_coach.agents.human_agent import HumanAgentParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, RunType, RunTypeParameters
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
from rl_coach.memories.backend.memory_impl import construct_memory_params
from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.s3_data_store import S3DataStoreParameters
from rl_coach.data_stores.data_store_impl import get_data_store, construct_data_store_params
from rl_coach.training_worker import training_worker
from rl_coach.rollout_worker import rollout_worker, wait_for_checkpoint


if len(set(failed_imports)) > 0:
Expand Down Expand Up @@ -108,6 +118,7 @@ def display_all_presets_and_exit():
print(preset)
sys.exit(0)


def expand_preset(preset):
if preset.lower() in [p.lower() for p in list_all_presets()]:
preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
Expand Down Expand Up @@ -150,6 +161,49 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
if args.list:
display_all_presets_and_exit()

# Read args from config file for distributed Coach.
if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
coach_config = ConfigParser({
'image': '',
'memory_backend': 'redispubsub',
'data_store': 's3',
's3_end_point': 's3.amazonaws.com',
's3_bucket_name': '',
's3_creds_file': ''
})
try:
coach_config.read(args.distributed_coach_config_path)
args.image = coach_config.get('coach', 'image')
args.memory_backend = coach_config.get('coach', 'memory_backend')
args.data_store = coach_config.get('coach', 'data_store')
args.s3_end_point = coach_config.get('coach', 's3_end_point')
args.s3_bucket_name = coach_config.get('coach', 's3_bucket_name')
args.s3_creds_file = coach_config.get('coach', 's3_creds_file')
except Error as e:
screen.error("Error when reading distributed Coach config file: {}".format(e))

if args.image == '':
screen.error("Image cannot be empty.")

data_store_choices = ['s3']
if args.data_store not in data_store_choices:
screen.warning("{} data store is unsupported.".format(args.data_store))
screen.error("Supported data stores are {}.".format(data_store_choices))

memory_backend_choices = ['redispubsub']
if args.memory_backend not in memory_backend_choices:
screen.warning("{} memory backend is not supported.".format(args.memory_backend))
screen.error("Supported memory backends are {}.".format(memory_backend_choices))

if args.s3_bucket_name == '':
screen.error("S3 bucket name cannot be empty.")

if args.s3_creds_file == '':
args.s3_creds_file = None

if args.play and args.distributed_coach:
screen.error("Playing is not supported in distributed Coach.")

# replace a short preset name with the full path
if args.preset is not None:
args.preset = expand_preset(args.preset)
Expand Down Expand Up @@ -217,6 +271,94 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
graph_manager.improve()


def handle_distributed_coach_tasks(graph_manager, args):
    """Run the non-orchestrator roles of distributed Coach (trainer / rollout worker).

    Wires the JSON-serialized memory-backend and data-store parameters from the
    command line into ``graph_manager``, then dispatches to the worker entry
    point matching ``args.distributed_coach_run_type``.
    """
    checkpoint_dir = "/checkpoint"  # fixed checkpoint location inside the container

    # Memory backend: deserialize the params, tag them with this process' run
    # type and register them on the agent's memory parameters.
    memory_backend_params = None
    if args.memory_backend_params:
        memory_backend_params = json.loads(args.memory_backend_params)
        memory_backend_params['run_type'] = str(args.distributed_coach_run_type)
        graph_manager.agent_params.memory.register_var(
            'memory_backend_params', construct_memory_params(memory_backend_params))

    # Data store: deserialize the params and point them at the container
    # checkpoint directory.
    data_store_params = None
    if args.data_store_params:
        data_store_params = construct_data_store_params(json.loads(args.data_store_params))
        data_store_params.checkpoint_dir = checkpoint_dir
        graph_manager.data_store_params = data_store_params

    run_type = args.distributed_coach_run_type
    if run_type == RunType.TRAINER:
        training_worker(
            graph_manager=graph_manager,
            checkpoint_dir=checkpoint_dir
        )

    if run_type == RunType.ROLLOUT_WORKER:
        # Wait for the first policy checkpoint before starting to roll out.
        data_store = get_data_store(data_store_params) if args.data_store_params else None
        wait_for_checkpoint(checkpoint_dir=checkpoint_dir, data_store=data_store)

        rollout_worker(
            graph_manager=graph_manager,
            checkpoint_dir=checkpoint_dir,
            data_store=data_store,
            num_workers=args.num_workers
        )


def handle_distributed_coach_orchestrator(graph_manager, args):
    """Deploy the distributed-Coach trainer and rollout workers on Kubernetes.

    Re-invokes this script inside the cluster for each role, deploys the
    trainer and the rollout worker(s), tails the trainer logs until
    interrupted, and finally tears the deployment down.
    """
    checkpoint_dir = "/checkpoint"
    passthrough_args = sys.argv[1:]

    rollout_command = ['python3', 'rl_coach/coach.py',
                       '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + passthrough_args
    trainer_command = ['python3', 'rl_coach/coach.py',
                       '--distributed_coach_run_type', str(RunType.TRAINER)] + passthrough_args

    # Both deployments must share the experiment name, so forward it explicitly
    # when it was not already given on the command line.
    if '--experiment_name' not in rollout_command:
        rollout_command += ['--experiment_name', args.experiment_name]
    if '--experiment_name' not in trainer_command:
        trainer_command += ['--experiment_name', args.experiment_name]

    memory_backend_params = None
    if args.memory_backend == "redispubsub":
        memory_backend_params = RedisPubSubMemoryBackendParameters()

    ds_params_instance = None
    if args.data_store == "s3":
        ds_params_instance = S3DataStoreParameters(
            ds_params=DataStoreParameters("s3", "", ""),
            end_point=args.s3_end_point,
            bucket_name=args.s3_bucket_name,
            creds_file=args.s3_creds_file,
            checkpoint_dir=checkpoint_dir)

    worker_run_type_params = RunTypeParameters(
        args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers)
    trainer_run_type_params = RunTypeParameters(
        args.image, trainer_command, run_type=str(RunType.TRAINER))

    orchestration_params = KubernetesParameters(
        [worker_run_type_params, trainer_run_type_params],
        kubeconfig='~/.kube/config',
        memory_backend_parameters=memory_backend_params,
        data_store_params=ds_params_instance)
    orchestrator = Kubernetes(orchestration_params)

    # Bail out early if any deployment step fails.
    if not orchestrator.setup():
        print("Could not setup.")
        return

    if not orchestrator.deploy_trainer():
        print("Could not deploy trainer.")
        return
    print("Successfully deployed trainer.")

    if not orchestrator.deploy_worker():
        print("Could not deploy rollout worker(s).")
        return
    print("Successfully deployed rollout worker(s).")

    # Stream trainer logs until the user hits Ctrl-C, then clean up.
    try:
        orchestrator.trainer_logs()
    except KeyboardInterrupt:
        pass

    orchestrator.undeploy()


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--preset',
Expand Down Expand Up @@ -329,11 +471,35 @@ def main():
help="(int) A seed to use for running the experiment",
default=None,
type=int)
parser.add_argument('-dc', '--distributed_coach',
help="(flag) Use distributed Coach.",
action='store_true')
parser.add_argument('-dcp', '--distributed_coach_config_path',
help="(string) Path to config file when using distributed rollout workers."
"Only distributed Coach parameters should be provided through this config file."
"Rest of the parameters are provided using Coach command line options."
"Used only with --distributed_coach flag."
"Ignored if --distributed_coach flag is not used.",
type=str)
parser.add_argument('--memory_backend_params',
help=argparse.SUPPRESS,
type=str)
parser.add_argument('--data_store_params',
help=argparse.SUPPRESS,
type=str)
parser.add_argument('--distributed_coach_run_type',
help=argparse.SUPPRESS,
type=RunType,
default=RunType.ORCHESTRATOR,
choices=list(RunType))

args = parse_arguments(parser)

graph_manager = get_graph_manager_from_args(args)

if args.distributed_coach and not graph_manager.agent_params.algorithm.distributed_coach_synchronization_type:
screen.error("{} preset is not supported using distributed Coach.".format(args.preset))

# Intel optimized TF seems to run significantly faster when limiting to a single OMP thread.
# This will not affect GPU runs.
os.environ["OMP_NUM_THREADS"] = "1"
Expand All @@ -343,14 +509,22 @@ def main():
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_verbosity)

# turn off the summary at the end of the run if necessary
if not args.no_summary:
if not args.no_summary and not args.distributed_coach:
atexit.register(logger.summarize_experiment)
screen.change_terminal_title(args.experiment_name)

# open dashboard
if args.open_dashboard:
open_dashboard(args.experiment_path)

if args.distributed_coach and args.distributed_coach_run_type != RunType.ORCHESTRATOR:
handle_distributed_coach_tasks(graph_manager, args)
return

if args.distributed_coach and args.distributed_coach_run_type == RunType.ORCHESTRATOR:
handle_distributed_coach_orchestrator(graph_manager, args)
return

# Single-threaded runs
if args.num_workers == 1:
# Start the training or evaluation
Expand Down
4 changes: 3 additions & 1 deletion rl_coach/graph_managers/graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from rl_coach.logger import screen, Logger
from rl_coach.utils import set_cpu, start_shell_command_and_wait
from rl_coach.data_stores.data_store_impl import get_data_store
from rl_coach.orchestrators.kubernetes_orchestrator import RunType


class ScheduleParameters(Parameters):
Expand Down Expand Up @@ -361,9 +362,10 @@ def act(self, steps: PlayingStepsType) -> None:
self.verify_graph_was_created()

if hasattr(self, 'data_store_params') and hasattr(self.agent_params.memory, 'memory_backend_params'):
if self.agent_params.memory.memory_backend_params.run_type == "worker":
if self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
data_store = get_data_store(self.data_store_params)
data_store.load_from_store()

# perform several steps of playing
count_end = self.current_step_counter + steps
while self.current_step_counter < count_end:
Expand Down
25 changes: 24 additions & 1 deletion rl_coach/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import os
import re
import shutil
import signal
import time
import uuid
from subprocess import Popen, PIPE
from typing import Union

Expand Down Expand Up @@ -90,6 +92,23 @@ def error(self, text, crash=True):
def ask_input(self, title):
return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))

def ask_input_with_timeout(self, title, timeout, msg_if_timeout='Timeout expired.'):
    """Prompt the user for input, giving up after `timeout` seconds.

    :param title: prompt text to display
    :param timeout: number of seconds to wait for input
    :param msg_if_timeout: warning message printed when the timeout expires
    :return: the entered string, or None if the timeout expired

    NOTE(review): relies on SIGALRM, so this only works on Unix and from the
    main thread; elsewhere signal.signal will raise ValueError.
    """
    class TimeoutExpired(Exception):
        pass

    def timeout_alarm_handler(signum, frame):
        raise TimeoutExpired

    # Install the alarm handler, remembering the previous one so it can be
    # restored afterwards instead of leaking our handler to later SIGALRMs.
    prev_handler = signal.signal(signal.SIGALRM, timeout_alarm_handler)
    signal.alarm(timeout)

    try:
        return input("{}{}{}".format(Colors.BG_CYAN, title, Colors.END))
    except TimeoutExpired:
        self.warning(msg_if_timeout)
    finally:
        # Cancel any pending alarm and restore the previous handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, prev_handler)

def ask_yes_no(self, title: str, default: Union[None, bool] = None):
"""
Ask the user for a yes / no question and return True if the answer is yes and False otherwise.
Expand Down Expand Up @@ -333,10 +352,14 @@ def get_experiment_name(initial_experiment_name=''):
match = None
while match is None:
if initial_experiment_name == '':
experiment_name = screen.ask_input("Please enter an experiment name: ")
msg_if_timeout = "Timeout waiting for experiement name."
experiment_name = screen.ask_input_with_timeout("Please enter an experiment name: ", 60, msg_if_timeout)
else:
experiment_name = initial_experiment_name

if not experiment_name:
experiment_name = ''

experiment_name = experiment_name.replace(" ", "_")
match = re.match("^$|^[\w -/]{1,1000}$", experiment_name)

Expand Down
Loading

0 comments on commit 7e70063

Please sign in to comment.