diff --git a/examples/distributed/offpolicy_distributed_primary.py b/examples/distributed/offpolicy_distributed_primary.py new file mode 100644 index 00000000..c6b726a6 --- /dev/null +++ b/examples/distributed/offpolicy_distributed_primary.py @@ -0,0 +1,93 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, +) +from genrl.core import ReplayBuffer +from genrl.agents import DDPG +from genrl.trainers import DistributedTrainer +from genrl.utils import Logger +import gym +import torch.distributed.rpc as rpc +import time + + +N_ACTORS = 1 +BUFFER_SIZE = 5000 +MAX_ENV_STEPS = 500 +TRAIN_STEPS = 5000 +BATCH_SIZE = 64 +INIT_BUFFER_SIZE = 1000 +WARMUP_STEPS = 1000 + + +def collect_experience(agent, parameter_server, experience_server, learner): + while not learner.is_completed(): + agent.load_weights(parameter_server.get_weights()) + obs = agent.env.reset() + done = False + for i in range(MAX_ENV_STEPS): + action = ( + agent.env.action_space.sample() + if i < WARMUP_STEPS + else agent.select_action(obs) + ) + next_obs, reward, done, _ = agent.env.step(action) + experience_server.push((obs, action, reward, next_obs, done)) + obs = next_obs + if done: + break + + +class MyTrainer(DistributedTrainer): + def __init__( + self, agent, train_steps, batch_size, init_buffer_size, log_interval=200 + ): + super(MyTrainer, self).__init__(agent) + self.train_steps = train_steps + self.batch_size = batch_size + self.init_buffer_size = init_buffer_size + self.log_interval = log_interval + + def train(self, parameter_server, experience_server): + while experience_server.__len__() < self.init_buffer_size: + time.sleep(1) + for i in range(self.train_steps): + batch = experience_server.sample(self.batch_size) + if batch is None: + continue + self.agent.update_params(1, batch) + parameter_server.store_weights(self.agent.get_weights()) + if i % self.log_interval == 0: + self.evaluate(i) + + +master = Master( + world_size=5, + address="localhost", + port=29502, + proc_start_method="fork", + rpc_backend=rpc.BackendType.TENSORPIPE, +) +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) +parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) +buffer = ReplayBuffer(BUFFER_SIZE) +experience_server = ExperienceServer("experience-0", master, buffer, rank=2) +trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE, INIT_BUFFER_SIZE) +learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) +actors = [ + ActorNode( + name=f"actor-{i}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + collect_experience=collect_experience, + rank=i + 4, + ) + for i in range(N_ACTORS) +] diff --git a/examples/distributed/offpolicy_distributed_secondary.py b/examples/distributed/offpolicy_distributed_secondary.py new file mode 100644 index 00000000..9a89f6d8 --- /dev/null +++ b/examples/distributed/offpolicy_distributed_secondary.py @@ -0,0 +1,80 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, + WeightHolder, +) +from genrl.core import ReplayBuffer +from genrl.agents import DDPG +from genrl.trainers import DistributedTrainer +import gym +import argparse +import torch.multiprocessing as mp + + +N_ACTORS = 2 +BUFFER_SIZE = 10 +MAX_ENV_STEPS = 100 +TRAIN_STEPS = 10 +BATCH_SIZE = 1 + + +def collect_experience(agent, experience_server_rref): + obs = agent.env.reset() + done = False + 
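+        # Roll out a single episode, pushing each transition to the remote experience server with a synchronous RPC call.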
for i in range(MAX_ENV_STEPS): + action = agent.select_action(obs) + next_obs, reward, done, info = agent.env.step(action) + experience_server_rref.rpc_sync().push((obs, action, reward, next_obs, done)) + obs = next_obs + if done: + break + + +# class MyTrainer(DistributedTrainer): +# def __init__(self, agent, train_steps, batch_size): +# super(MyTrainer, self).__init__(agent) +# self.train_steps = train_steps +# self.batch_size = batch_size + +# def train(self, parameter_server_rref, experience_server_rref): +# i = 0 +# while i < self.train_steps: +# batch = experience_server_rref.rpc_sync().sample(self.batch_size) +# if batch is None: +# continue +# self.agent.update_params(batch, 1) +# parameter_server_rref.rpc_sync().store_weights(self.agent.get_weights()) +# print(f"Trainer: {i + 1} / {self.train_steps} steps completed") +# i += 1 + + +mp.set_start_method("fork") + +master = Master(world_size=8, address="localhost", port=29500, secondary=True) +env = gym.make("Pendulum-v0") +agent = DDPG("mlp", env) +# parameter_server = ParameterServer( +# "param-0", master, WeightHolder(agent.get_weights()), rank=1 +# ) +# buffer = ReplayBuffer(BUFFER_SIZE) +# experience_server = ExperienceServer("experience-0", master, buffer, rank=2) +# trainer = MyTrainer(agent, TRAIN_STEPS, BATCH_SIZE) +# learner = LearnerNode( +# "learner-0", master, parameter_server, experience_server, trainer, rank=3 +# ) +actors = [ + ActorNode( + name=f"actor-{i+2}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + collect_experience=collect_experience, + rank=i + 6, + ) + for i in range(N_ACTORS) +] diff --git a/examples/distributed/onpolicy_distributed_primary.py b/examples/distributed/onpolicy_distributed_primary.py new file mode 100644 index 00000000..9709f4f1 --- /dev/null +++ b/examples/distributed/onpolicy_distributed_primary.py @@ -0,0 +1,236 @@ +from genrl.distributed import ( + Master, + ExperienceServer, + ParameterServer, + ActorNode, + LearnerNode, +) +from genrl.core.policies import MlpPolicy +from genrl.core.values import MlpValue +from genrl.trainers import DistributedTrainer +import gym +import torch.distributed.rpc as rpc +import torch +from genrl.utils import get_env_properties +import torch.nn.functional as F +import copy +import time + +N_ACTORS = 1 +BUFFER_SIZE = 5 +MAX_ENV_STEPS = 500 +TRAIN_STEPS = 50 + + +def get_advantages_returns(rewards, dones, values, gamma=0.99, gae_lambda=1): + buffer_size = len(rewards) + advantages = torch.zeros_like(rewards) + last_gae_lam = 0 + for step in reversed(range(buffer_size)): + if step == buffer_size - 1: + next_non_terminal = 1.0 - dones[-1] + next_value = values[-1] + else: + next_non_terminal = 1.0 - dones[step + 1] + next_value = values[step + 1] + delta = rewards[step] + gamma * next_value * next_non_terminal - values[step] + last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam + advantages[step] = last_gae_lam + returns = advantages + values + return advantages.detach(), returns.detach() + + +def unroll_trajs(trajectories): + size = sum([len(traj) for traj in trajectories]) + obs = torch.zeros(size, *trajectories[0].states[0].shape) + actions = torch.zeros(size) + rewards = torch.zeros(size) + dones = torch.zeros(size) + + i = 0 + for traj in trajectories: + for j in range(len(traj)): + obs[i] = torch.tensor(traj.states[j]) + actions[i] = torch.tensor(traj.actions[j]) + rewards[i] = torch.tensor(traj.rewards[j]) + dones[i] = 
torch.tensor(traj.dones[j]) + + return obs, actions, rewards, dones + + +class A2C: + def __init__( + self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5 + ): + self.env = env + self.policy = policy + self.value = value + self.policy_optim = policy_optim + self.value_optim = value_optim + self.grad_norm_limit = grad_norm_limit + + def select_action(self, obs: torch.Tensor, deterministic: bool = False): + logits = self.policy(torch.tensor(obs, dtype=torch.float)) + distribution = torch.distributions.Categorical(logits=logits) + action = torch.argmax(logits) if deterministic else distribution.sample() + return action.item() + + def update_params(self, trajectories): + obs, actions, rewards, dones = unroll_trajs(trajectories) + values = self.value(obs).view(-1) + dist = torch.distributions.Categorical(self.policy(obs)) + log_probs = dist.log_prob(actions) + entropy = dist.entropy() + advantages, returns = get_advantages_returns(rewards, dones, values) + + policy_loss = -torch.mean(advantages * log_probs) - torch.mean(entropy) + value_loss = F.mse_loss(returns, values) + + self.policy_optim.zero_grad() + policy_loss.backward() + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_norm_limit) + self.policy_optim.step() + + self.value_optim.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(self.value.parameters(), self.grad_norm_limit) + self.value_optim.step() + + def get_weights(self): + return {"policy": self.policy.state_dict(), "value": self.value.state_dict()} + + def load_weights(self, weights): + self.policy.load_state_dict(weights["policy"]) + self.value.load_state_dict(weights["value"]) + + +class Trajectory: + def __init__(self): + self.states = [] + self.actions = [] + self.rewards = [] + self.dones = [] + self.__len = 0 + + def add(self, state, action, reward, done): + self.states.append(state) + self.actions.append(action) + self.rewards.append(reward) + self.dones.append(done) + self.__len += 1 + + def __len__(self): + return self.__len + + +class TrajBuffer: + def __init__(self, size): + if size <= 0: + raise ValueError("Size of buffer must be larger than 0") + self._size = size + self._memory = [] + self._full = False + + def is_full(self): + return self._full + + def push(self, traj): + if not self.is_full(): + self._memory.append(traj) + if len(self._memory) >= self._size: + self._full = True + + def get(self, clear=True): + out = copy.deepcopy(self._memory) + if clear: + self._memory = [] + self._full = False + return out + + +def collect_experience(agent, parameter_server, experience_server, learner): + current_step = -1 + while not learner.is_completed(): + if not learner.current_train_step() > current_step: + time.sleep(0.5) + continue + current_step = learner.current_train_step() + traj = Trajectory() + agent.load_weights(parameter_server.get_weights()) + while not experience_server.is_full(): + obs = agent.env.reset() + done = False + for _ in range(MAX_ENV_STEPS): + action = agent.select_action(obs) + next_obs, reward, done, _ = agent.env.step(action) + traj.add(obs, action, reward, done) + obs = next_obs + if done: + break + experience_server.push(traj) + print("pushed a traj") + + +class MyTrainer(DistributedTrainer): + def __init__(self, agent, train_steps, log_interval=1): + super(MyTrainer, self).__init__(agent) + self.train_steps = train_steps + self.log_interval = log_interval + self._weights_available = True + self._current_train_step = 0 + + def current_train_step(self): + return self._current_train_step + 
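+    # Block until the actors have filled the trajectory buffer, run one A2C update on the collected trajectories, then publish the new weights so actors can refresh before the next collection round.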
+ def train(self, parameter_server, experience_server): + self._current_train_step = 0 + while True: + if experience_server.is_full(): + self._weights_available = False + trajectories = experience_server.get() + if trajectories is None: + continue + self.agent.update_params(trajectories) + parameter_server.store_weights(self.agent.get_weights()) + self._weights_available = True + if self._current_train_step % self.log_interval == 0: + self.evaluate(self._current_train_step) + self._current_train_step += 1 + if self._current_train_step >= self.train_steps: + break + + +master = Master( + world_size=N_ACTORS + 4, + address="localhost", + port=29500, + proc_start_method="fork", + rpc_backend=rpc.BackendType.TENSORPIPE, +) + +env = gym.make("CartPole-v0") +state_dim, action_dim, discrete, action_lim = get_env_properties(env, "mlp") +policy = MlpPolicy(state_dim, action_dim, (32, 32), discrete) +value = MlpValue(state_dim, action_dim, "V", (32, 32)) +policy_optim = torch.optim.Adam(policy.parameters(), lr=1e-3) +value_optim = torch.optim.Adam(value.parameters(), lr=1e-3) +agent = A2C(env, policy, value, policy_optim, value_optim) +buffer = TrajBuffer(BUFFER_SIZE) + +parameter_server = ParameterServer("param-0", master, agent.get_weights(), rank=1) +experience_server = ExperienceServer("experience-0", master, buffer, rank=2) +trainer = MyTrainer(agent, TRAIN_STEPS) +learner = LearnerNode("learner-0", master, "param-0", "experience-0", trainer, rank=3) +actors = [ + ActorNode( + name=f"actor-{i}", + master=master, + parameter_server_name="param-0", + experience_server_name="experience-0", + learner_name="learner-0", + agent=agent, + collect_experience=collect_experience, + rank=i + 4, + ) + for i in range(N_ACTORS) +] diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index 0f294042..660c87c4 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -80,7 +80,7 @@ def _reshape_batch(self, batch: List): """ return [*batch] - def sample_from_buffer(self, beta: float = None): + def sample_from_buffer(self, beta: float = None, batch = None): """Samples experiences from the buffer and converts them into usable formats Args: @@ -89,11 +89,12 @@ def sample_from_buffer(self, beta: float = None): Returns: batch (:obj:`list`): Replay experiences sampled from the buffer """ - # Samples from the buffer - if beta is not None: - batch = self.replay_buffer.sample(self.batch_size, beta=beta) - else: - batch = self.replay_buffer.sample(self.batch_size) + if batch is None: + # Samples from the buffer + if beta is not None: + batch = self.replay_buffer.sample(self.batch_size, beta=beta) + else: + batch = self.replay_buffer.sample(self.batch_size) states, actions, rewards, next_states, dones = self._reshape_batch(batch) @@ -106,7 +107,7 @@ def sample_from_buffer(self, beta: float = None): *[states, actions, rewards, next_states, dones, indices, weights] ) else: - raise NotImplementedError + batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones]) return batch def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor: @@ -277,4 +278,4 @@ def load_weights(self, weights) -> None: Args: weights (:obj:`dict`): Dictionary of different neural net weights """ - self.ac.load_state_dict(weights["weights"]) + self.ac.load_state_dict(weights) diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index b21b6808..4efbbfdf 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ 
-79,14 +79,14 @@ def _create_model(self) -> None: self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy) self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value) - def update_params(self, update_interval: int) -> None: + def update_params(self, update_interval: int, batch = None) -> None: """Update parameters of the model Args: update_interval (int): Interval between successive updates of the target model """ for timestep in range(update_interval): - batch = self.sample_from_buffer() + batch = self.sample_from_buffer(batch=batch) value_loss = self.get_q_loss(batch) self.logs["value_loss"].append(value_loss.item()) @@ -123,6 +123,9 @@ def get_hyperparams(self) -> Dict[str, Any]: } return hyperparams + def get_weights(self): + return self.ac.state_dict() + def get_logging_params(self) -> Dict[str, Any]: """Gets relevant parameters for logging diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index 0a5b6e7c..207c7d0b 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -57,6 +57,9 @@ def sample( :returns: (Tuple composing of `state`, `action`, `reward`, `next_state` and `done`) """ + if batch_size > len(self.memory): + return None + batch = random.sample(self.memory, batch_size) state, action, reward, next_state, done = map(np.stack, zip(*batch)) return [ @@ -70,7 +73,7 @@ def __len__(self) -> int: :returns: Length of replay memory """ - return self.pos + return len(self.memory) class PrioritizedBuffer: diff --git a/genrl/distributed/__init__.py b/genrl/distributed/__init__.py new file mode 100644 index 00000000..c3276db1 --- /dev/null +++ b/genrl/distributed/__init__.py @@ -0,0 +1,5 @@ +from genrl.distributed.core import Master, Node +from genrl.distributed.parameter_server import ParameterServer, WeightHolder +from genrl.distributed.experience_server import ExperienceServer +from genrl.distributed.actor import ActorNode +from genrl.distributed.learner import LearnerNode diff --git a/genrl/distributed/actor.py b/genrl/distributed/actor.py new file mode 100644 index 00000000..abc5a2eb --- /dev/null +++ b/genrl/distributed/actor.py @@ -0,0 +1,52 @@ +from genrl.distributed.core import Node +from genrl.distributed.core import get_proxy, store_rref +import torch.distributed.rpc as rpc + + +class ActorNode(Node): + def __init__( + self, + name, + master, + parameter_server_name, + experience_server_name, + learner_name, + agent, + collect_experience, + rank=None, + ): + super(ActorNode, self).__init__(name, master, rank) + self.init_proc( + target=self.act, + kwargs=dict( + parameter_server_name=parameter_server_name, + experience_server_name=experience_server_name, + learner_name=learner_name, + agent=agent, + collect_experience=collect_experience, + ), + ) + self.start_proc() + + @staticmethod + def act( + name, + world_size, + rank, + parameter_server_name, + experience_server_name, + learner_name, + agent, + collect_experience, + rpc_backend, + **kwargs, + ): + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) + print(f"{name}: RPC Initialised") + store_rref(name, rpc.RRef(agent)) + parameter_server = get_proxy(parameter_server_name) + experience_server = get_proxy(experience_server_name) + learner = get_proxy(learner_name) + print(f"{name}: Begining experience collection") + collect_experience(agent, parameter_server, experience_server, learner) + rpc.shutdown() diff --git a/genrl/distributed/core.py b/genrl/distributed/core.py new file mode 100644 index 00000000..1adf6805 --- /dev/null +++ 
b/genrl/distributed/core.py @@ -0,0 +1,181 @@ +import torch.distributed.rpc as rpc + +import threading + +import torch.multiprocessing as mp +import os +import time + +_rref_reg = {} +_global_lock = threading.Lock() + + +def _get_rref(idx): + global _rref_reg + with _global_lock: + if idx in _rref_reg.keys(): + return _rref_reg[idx] + else: + return None + + +def _store_rref(idx, rref): + global _rref_reg + with _global_lock: + if idx in _rref_reg.keys(): + raise Warning( + f"Re-assigning RRef for key: {idx}. Make sure you are not using duplicate names for nodes" + ) + _rref_reg[idx] = rref + + +def _get_num_rrefs(): + global _rref_reg + with _global_lock: + return len(_rref_reg.keys()) + + +def get_rref(idx): + rref = rpc.rpc_sync("master", _get_rref, args=(idx,)) + while rref is None: + time.sleep(0.5) + rref = rpc.rpc_sync("master", _get_rref, args=(idx,)) + return rref + + +def store_rref(idx, rref): + rpc.rpc_sync("master", _store_rref, args=(idx, rref)) + + +def get_proxy(idx): + return get_rref(idx).rpc_sync() + + +def set_environ(address, port): + os.environ["MASTER_ADDR"] = str(address) + os.environ["MASTER_PORT"] = str(port) + + +class Node: + def __init__(self, name, master, rank): + self._name = name + self.master = master + if rank >= 0 and rank < master.world_size: + self._rank = rank + elif rank >= master.world_size: + raise ValueError("Specified rank greater than allowed by world size") + else: + raise ValueError("Invalid value of rank") + self.p = None + + def __del__(self): + if self.p is None: + raise RuntimeWarning( + "Removing node when process was not initialised properly" + ) + else: + self.p.join() + + @staticmethod + def _target_wrapper(target, **kwargs): + pid = os.getpid() + print(f"Starting {kwargs['name']} with pid {pid}") + set_environ(kwargs["master_address"], kwargs["master_port"]) + target(**kwargs) + print(f"Shutdown {kwargs['name']} with pid {pid}") + + def init_proc(self, target, kwargs): + kwargs.update( + dict( + name=self.name, + master_address=self.master.address, + master_port=self.master.port, + world_size=self.master.world_size, + rank=self.rank, + rpc_backend=self.master.rpc_backend, + ) + ) + self.p = mp.Process(target=self._target_wrapper, args=(target,), kwargs=kwargs) + + def start_proc(self): + if self.p is None: + raise RuntimeError("Trying to start uninitialised process") + self.p.start() + + @property + def name(self): + return self._name + + @property + def rref(self): + return get_rref(self.name) + + @property + def rank(self): + return self._rank + + +class Master: + def __init__( + self, + world_size, + address="localhost", + port=29501, + secondary=False, + proc_start_method="fork", + rpc_backend=rpc.BackendType.PROCESS_GROUP, + ): + mp.set_start_method(proc_start_method) + set_environ(address, port) + self._world_size = world_size + self._address = address + self._port = port + self._secondary = secondary + self._rpc_backend = rpc_backend + + print( + "Configuration - {\n" + f"RPC Address : {self.address}\n" + f"RPC Port : {self.port}\n" + f"RPC World Size : {self.world_size}\n" + f"RPC Backend : {self.rpc_backend}\n" + f"Process Start Method : {proc_start_method}\n" + f"Seondary Master : {self.is_secondary}\n" + "}" + ) + + if not self._secondary: + self.p = mp.Process(target=self._run_master, args=(world_size, rpc_backend)) + self.p.start() + else: + self.p = None + + def __del__(self): + if not self.p is None: + self.p.join() + + @staticmethod + def _run_master(world_size, rpc_backend): + print(f"Starting master with pid 
{os.getpid()}") + rpc.init_rpc("master", rank=0, world_size=world_size, backend=rpc_backend) + rpc.shutdown() + + @property + def world_size(self): + return self._world_size + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + @property + def is_secondary(self): + return self._secondary + + @property + def rpc_backend(self): + return self._rpc_backend diff --git a/genrl/distributed/experience_server.py b/genrl/distributed/experience_server.py new file mode 100644 index 00000000..c2c7bccb --- /dev/null +++ b/genrl/distributed/experience_server.py @@ -0,0 +1,22 @@ +from genrl.distributed import Node +from genrl.distributed.core import store_rref + +import torch.distributed.rpc as rpc + + +class ExperienceServer(Node): + def __init__(self, name, master, buffer, rank=None): + super(ExperienceServer, self).__init__(name, master, rank) + self.init_proc( + target=self.run_paramater_server, + kwargs=dict(buffer=buffer), + ) + self.start_proc() + + @staticmethod + def run_paramater_server(name, world_size, rank, buffer, rpc_backend, **kwargs): + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) + print(f"{name}: Initialised RPC") + store_rref(name, rpc.RRef(buffer)) + print(f"{name}: Serving experience buffer") + rpc.shutdown() diff --git a/genrl/distributed/learner.py b/genrl/distributed/learner.py new file mode 100644 index 00000000..c9181052 --- /dev/null +++ b/genrl/distributed/learner.py @@ -0,0 +1,47 @@ +from genrl.distributed import Node +from genrl.distributed.core import get_proxy, store_rref + +import torch.distributed.rpc as rpc + + +class LearnerNode(Node): + def __init__( + self, + name, + master, + parameter_server_name, + experience_server_name, + trainer, + rank=None, + ): + super(LearnerNode, self).__init__(name, master, rank) + self.init_proc( + target=self.learn, + kwargs=dict( + parameter_server_name=parameter_server_name, + experience_server_name=experience_server_name, + trainer=trainer, + ), + ) + self.start_proc() + + @staticmethod + def learn( + name, + world_size, + rank, + parameter_server_name, + experience_server_name, + trainer, + rpc_backend, + **kwargs, + ): + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) + print(f"{name}: Initialised RPC") + store_rref(name, rpc.RRef(trainer)) + parameter_server = get_proxy(parameter_server_name) + experience_server = get_proxy(experience_server_name) + print(f"{name}: Beginning training") + trainer.train(parameter_server, experience_server) + trainer.set_completed(True) + rpc.shutdown() diff --git a/genrl/distributed/parameter_server.py b/genrl/distributed/parameter_server.py new file mode 100644 index 00000000..cde20d14 --- /dev/null +++ b/genrl/distributed/parameter_server.py @@ -0,0 +1,36 @@ +from genrl.distributed import Node +from genrl.distributed.core import store_rref + +import torch.distributed.rpc as rpc + + +class ParameterServer(Node): + def __init__(self, name, master, init_params, rank=None): + super(ParameterServer, self).__init__(name, master, rank) + self.init_proc( + target=self.run_paramater_server, + kwargs=dict(init_params=init_params), + ) + self.start_proc() + + @staticmethod + def run_paramater_server( + name, world_size, rank, init_params, rpc_backend, **kwargs + ): + rpc.init_rpc(name=name, world_size=world_size, rank=rank, backend=rpc_backend) + print(f"{name}: Initialised RPC") + params = WeightHolder(init_weights=init_params) + store_rref(name, rpc.RRef(params)) + print(f"{name}: 
Serving parameters") + rpc.shutdown() + + +class WeightHolder: + def __init__(self, init_weights): + self._weights = init_weights + + def store_weights(self, weights): + self._weights = weights + + def get_weights(self): + return self._weights diff --git a/genrl/trainers/__init__.py b/genrl/trainers/__init__.py index 7410831b..c5448cc3 100644 --- a/genrl/trainers/__init__.py +++ b/genrl/trainers/__init__.py @@ -3,3 +3,4 @@ from genrl.trainers.classical import ClassicalTrainer # noqa from genrl.trainers.offpolicy import OffPolicyTrainer # noqa from genrl.trainers.onpolicy import OnPolicyTrainer # noqa +from genrl.trainers.distributed import DistributedTrainer # noqa diff --git a/genrl/trainers/distributed.py b/genrl/trainers/distributed.py new file mode 100644 index 00000000..33c50b2d --- /dev/null +++ b/genrl/trainers/distributed.py @@ -0,0 +1,48 @@ +from genrl.utils import safe_mean +from genrl.utils import Logger + + +class DistributedTrainer: + def __init__(self, agent): + self.agent = agent + self.env = self.agent.env + self._completed_training_flag = False + self.logger = Logger(formats=["stdout"]) + + + def train(self, parameter_server, experience_server): + raise NotImplementedError + + def is_completed(self): + return self._completed_training_flag + + def set_completed(self, value=True): + self._completed_training_flag = value + + def evaluate(self, timestep, render: bool = False) -> None: + """Evaluate performance of Agent + + Args: + render (bool): Option to render the environment during evaluation + """ + episode_rewards = [] + for i in range(5): + state = self.env.reset() + done = False + episode_reward = 0 + while not done: + action = self.agent.select_action(state, deterministic=True) + next_state, reward, done, _ = self.env.step(action) + episode_reward += reward + state = next_state + episode_rewards.append(episode_reward) + episode_reward = 0 + + self.logger.write( + { + "timestep": timestep, + # **self.agent.get_logging_params(), + "Episode Reward": safe_mean(episode_rewards), + }, + "timestep", + )