diff --git a/baselines/main.py b/baselines/main.py
index 46f8bd9d..1d0aaa17 100644
--- a/baselines/main.py
+++ b/baselines/main.py
@@ -17,16 +17,17 @@
 from og_marl.loggers import JsonWriter, WandbLogger
 from og_marl.offline_dataset import download_and_unzip_vault
 from og_marl.replay_buffers import FlashbaxReplayBuffer
+from og_marl.tf2.networks import CNNEmbeddingNetwork
 from og_marl.tf2.systems import get_system
 from og_marl.tf2.utils import set_growing_gpu_memory
 
 set_growing_gpu_memory()
 
 FLAGS = flags.FLAGS
-flags.DEFINE_string("env", "smac_v1", "Environment name.")
-flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
+flags.DEFINE_string("env", "pettingzoo", "Environment name.")
+flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.")
 flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
-flags.DEFINE_string("system", "qmix+cql", "System name.")
+flags.DEFINE_string("system", "qmix", "System name.")
 flags.DEFINE_integer("seed", 42, "Seed.")
 flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
 flags.DEFINE_integer("batch_size", 64, "Number of training steps.")
@@ -52,7 +53,7 @@ def main(_):
         print("Vault not found. Exiting.")
         return
 
-    logger = WandbLogger(project="default", config=config)
+    logger = WandbLogger(config=config)
 
     json_writer = JsonWriter(
         "logs",
@@ -65,6 +66,9 @@ def main(_):
     )
 
     system_kwargs = {"add_agent_id_to_obs": True}
+    if FLAGS.scenario == "pursuit":
+        system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
+
     system = get_system(FLAGS.system, env, logger, **system_kwargs)
 
     system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)
diff --git a/examples/tf2/run_all_baselines.py b/examples/tf2/run_all_baselines.py
index fe708c22..102aa118 100644
--- a/examples/tf2/run_all_baselines.py
+++ b/examples/tf2/run_all_baselines.py
@@ -16,20 +16,19 @@
 os.environ["SUPPRESS_GR_PROMPT"] = "1"
 
 scenario_system_configs = {
-    # "smac_v1": {
-    #     "3m": {
-    #         "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq",
-    #             "maicq", "dbc"],
-    #         "datasets": ["Good"],
-    #         "trainer_steps": 3000,
-    #         "evaluate_every": 1000,
-    #     },
-    # },
+    "smac_v1": {
+        "3m": {
+            "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
+            "datasets": ["Good"],
+            "trainer_steps": 2000,
+            "evaluate_every": 1000,
+        },
+    },
     "pettingzoo": {
         "pursuit": {
             "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
             "datasets": ["Good"],
-            "trainer_steps": 3000,
+            "trainer_steps": 2000,
             "evaluate_every": 1000,
         },
     },
@@ -80,8 +79,6 @@
 
                     if scenario_name == "pursuit":
                         system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
-                        if system_name in ["qmix", "qmix+cql", "qmix+bcq", "maicq"]:
-                            system_kwargs["state_embedding_network"] = CNNEmbeddingNetwork()
 
                     system = get_system(system_name, env, logger, **system_kwargs)
 
diff --git a/og_marl/tf2/networks.py b/og_marl/tf2/networks.py
index 481ba32d..2f219ff1 100644
--- a/og_marl/tf2/networks.py
+++ b/og_marl/tf2/networks.py
@@ -128,12 +128,12 @@ def __init__(
     def __call__(self, x: Tensor) -> Tensor:
         """Embed a pixel-styled input into a vector using a conv net.
 
-        We assume the input has leading batch, time and agent dims. With trailing dims
+        We assume the input has trailing dims
         being the width, height and channel dimensions of the input.
 
         The output shape is then given as (B,T,N,Embed)
         """
-        leading_dims = x.shape[:-3]  # B,T,N
+        leading_dims = x.shape[:-3]
         trailing_dims = x.shape[-3:]  # W,H,C
 
         x = tf.reshape(x, shape=(-1, *trailing_dims))
diff --git a/og_marl/tf2/systems/bc.py b/og_marl/tf2/systems/bc.py
index 66015bf0..f0523b46 100644
--- a/og_marl/tf2/systems/bc.py
+++ b/og_marl/tf2/systems/bc.py
@@ -106,14 +106,13 @@ def _tf_select_actions(
                 agent_observation, i, len(self._environment.possible_agents)
             )
             agent_observation = tf.expand_dims(agent_observation, axis=0)  # add batch dimension
-            logits, next_rnn_states[agent] = self._policy_network(
-                agent_observation, rnn_states[agent]
-            )
+            embedding = self._policy_embedding_network(agent_observation)
+            logits, next_rnn_states[agent] = self._policy_network(embedding, rnn_states[agent])
 
             probs = tf.nn.softmax(logits)
 
             if legal_actions is not None:
-                agent_legals = tf.expand_dims(legal_actions[agent], axis=0)
+                agent_legals = tf.cast(tf.expand_dims(legal_actions[agent], axis=0), "float32")
                 probs = (probs * agent_legals) / tf.reduce_sum(
                     probs * agent_legals
                 )  # mask and renorm
@@ -121,7 +120,7 @@
 
             action = tfp.distributions.Categorical(probs=probs).sample(1)
 
             # Store agent action
-            actions[agent] = action
+            actions[agent] = action[0]
 
         return actions, next_rnn_states
diff --git a/og_marl/tf2/systems/idrqn_bcq.py b/og_marl/tf2/systems/idrqn_bcq.py
index 35170e88..a0cea06f 100644
--- a/og_marl/tf2/systems/idrqn_bcq.py
+++ b/og_marl/tf2/systems/idrqn_bcq.py
@@ -185,8 +185,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
         # Get trainable variables
         variables = (
             *self._q_network.trainable_variables,
-            *self._q_embedding_network.trainable_variables
-            * self._behaviour_cloning_network.trainable_variables,
+            *self._q_embedding_network.trainable_variables,
+            *self._behaviour_cloning_network.trainable_variables,
             *self._bc_embedding_network.trainable_variables,
         )
 
diff --git a/og_marl/tf2/systems/qmix_bcq.py b/og_marl/tf2/systems/qmix_bcq.py
index fd1fc18e..70ecad6c 100644
--- a/og_marl/tf2/systems/qmix_bcq.py
+++ b/og_marl/tf2/systems/qmix_bcq.py
@@ -207,8 +207,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
             *self._q_network.trainable_variables,
             *self._q_embedding_network.trainable_variables,
             *self._mixer.trainable_variables,
-            *self._state_embedding_network.trainable_variables
-            * self._behaviour_cloning_network.trainable_variables,
+            *self._state_embedding_network.trainable_variables,
+            *self._behaviour_cloning_network.trainable_variables,
         )
 
         # Compute gradients.
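Note on the networks.py hunk above: the docstring change reflects that the embedding network no longer assumes (B,T,N) leading dims; it only requires the trailing dims to be W,H,C, which is why the same CNNEmbeddingNetwork can be passed as an observation embedding across systems. Below is a minimal, self-contained sketch of that flatten-then-restore pattern, assuming TensorFlow/Keras; the class name PixelEmbedding, its layer sizes, and the example shapes are illustrative placeholders, not og_marl code.

import tensorflow as tf
from tensorflow.keras import layers


class PixelEmbedding(tf.keras.Model):
    """Illustrative stand-in for a CNN embedding network (layer sizes are placeholders)."""

    def __init__(self, output_dim: int = 64):
        super().__init__()
        self._conv = tf.keras.Sequential(
            [layers.Conv2D(16, 3, activation="relu"), layers.Flatten(), layers.Dense(output_dim)]
        )

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # Assumes all dims are statically known, e.g. a concrete eager tensor.
        leading_dims = x.shape[:-3]   # e.g. (B,T,N) for observations, (B,T) for a global state
        trailing_dims = x.shape[-3:]  # W,H,C
        x = tf.reshape(x, shape=(-1, *trailing_dims))  # merge leading dims into one batch dim
        embedding = self._conv(x)
        return tf.reshape(embedding, shape=(*leading_dims, -1))  # restore leading dims


# Example: per-agent observations of shape (batch, time, agents, 10, 10, 3)
# obs = tf.random.uniform((8, 20, 4, 10, 10, 3))
# PixelEmbedding(64)(obs).shape  # -> (8, 20, 4, 64)

With the updated defaults in baselines/main.py, a typical invocation would be along the lines of: python baselines/main.py --env=pettingzoo --scenario=pursuit --system=qmix --dataset=Good, where the pursuit scenario wires a CNNEmbeddingNetwork into system_kwargs automatically.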