All discrete action systems running on pursuit.

jcformanek committed Mar 8, 2024
1 parent 3414723 commit b674489
Showing 6 changed files with 27 additions and 27 deletions.
12 changes: 8 additions & 4 deletions baselines/main.py
@@ -17,16 +17,17 @@
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.offline_dataset import download_and_unzip_vault
from og_marl.replay_buffers import FlashbaxReplayBuffer
+ from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "smac_v1", "Environment name.")
flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
flags.DEFINE_string("env", "pettingzoo", "Environment name.")
flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
flags.DEFINE_string("system", "qmix+cql", "System name.")
flags.DEFINE_string("system", "qmix", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
flags.DEFINE_integer("batch_size", 64, "Number of training steps.")
@@ -52,7 +53,7 @@ def main(_):
print("Vault not found. Exiting.")
return

- logger = WandbLogger(project="default", config=config)
+ logger = WandbLogger(config=config)

json_writer = JsonWriter(
"logs",
@@ -65,6 +66,9 @@
)

system_kwargs = {"add_agent_id_to_obs": True}
+ if FLAGS.scenario == "pursuit":
+ system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(FLAGS.system, env, logger, **system_kwargs)

system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)
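
Note (not part of the commit): a self-contained sketch of how the absl flag defaults above behave when the script is run; flag names mirror baselines/main.py, the override value is only an example.

from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "pettingzoo", "Environment name.")
flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.")
flags.DEFINE_string("system", "qmix", "System name.")

# Simulate `python main.py --system=qmix+cql`: defaults stay unless overridden.
FLAGS(["main.py", "--system=qmix+cql"])
print(FLAGS.env, FLAGS.scenario, FLAGS.system)  # pettingzoo pursuit qmix+cql
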
21 changes: 9 additions & 12 deletions examples/tf2/run_all_baselines.py
@@ -16,20 +16,19 @@
os.environ["SUPPRESS_GR_PROMPT"] = "1"

scenario_system_configs = {
# "smac_v1": {
# "3m": {
# "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq",
# "maicq", "dbc"],
# "datasets": ["Good"],
# "trainer_steps": 3000,
# "evaluate_every": 1000,
# },
# },
"smac_v1": {
"3m": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 2000,
"evaluate_every": 1000,
},
},
"pettingzoo": {
"pursuit": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 3000,
"trainer_steps": 2000,
"evaluate_every": 1000,
},
},
@@ -80,8 +79,6 @@

if scenario_name == "pursuit":
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
- if system_name in ["qmix", "qmix+cql", "qmix+bcq", "maicq"]:
- system_kwargs["state_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(system_name, env, logger, **system_kwargs)

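
Note (not part of the commit): a hedged sketch of how a nested config like scenario_system_configs is typically swept; the names env_name, scenario_name and system_name follow the hunk above, the loop itself is assumed.

# Illustrative only: iterate every (env, scenario, system, dataset) combination
# declared above and report the per-run training settings.
for env_name, scenarios in scenario_system_configs.items():
    for scenario_name, cfg in scenarios.items():
        for system_name in cfg["systems"]:
            for dataset in cfg["datasets"]:
                print(
                    env_name, scenario_name, system_name, dataset,
                    cfg["trainer_steps"], cfg["evaluate_every"],
                )
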
4 changes: 2 additions & 2 deletions og_marl/tf2/networks.py
@@ -128,12 +128,12 @@ def __init__(
def __call__(self, x: Tensor) -> Tensor:
"""Embed a pixel-styled input into a vector using a conv net.
- We assume the input has leading batch, time and agent dims. With trailing dims
+ We assume the input has trailing dims
being the width, height and channel dimensions of the input.
The output shape is then given as (B,T,N,Embed)
"""
- leading_dims = x.shape[:-3] # B,T,N
+ leading_dims = x.shape[:-3]
trailing_dims = x.shape[-3:] # W,H,C

x = tf.reshape(x, shape=(-1, *trailing_dims))
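
Note (not part of the commit): the docstring edit reflects that CNNEmbeddingNetwork only relies on the three trailing image dimensions; any leading dimensions are folded into the batch before the conv stack and restored afterwards. A minimal sketch of that reshape pattern with a dummy (B, T, N, W, H, C) input, the conv layers themselves elided:

import tensorflow as tf

x = tf.zeros((2, 5, 8, 7, 7, 3))  # batch=2, time=5, agents=8, 7x7 RGB frames

leading_dims = x.shape[:-3]   # (2, 5, 8) -- whatever precedes the image dims
trailing_dims = x.shape[-3:]  # (7, 7, 3) -- width, height, channels

x = tf.reshape(x, shape=(-1, *trailing_dims))     # (80, 7, 7, 3): conv-net ready
# ... conv layers + flatten would run here, producing (80, Embed) ...
embed = tf.reshape(x, shape=(*leading_dims, -1))  # back to (2, 5, 8, 147) here
print(embed.shape)
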
9 changes: 4 additions & 5 deletions og_marl/tf2/systems/bc.py
@@ -106,22 +106,21 @@ def _tf_select_actions(
agent_observation, i, len(self._environment.possible_agents)
)
agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension
- logits, next_rnn_states[agent] = self._policy_network(
- agent_observation, rnn_states[agent]
- )
+ embedding = self._policy_embedding_network(agent_observation)
+ logits, next_rnn_states[agent] = self._policy_network(embedding, rnn_states[agent])

probs = tf.nn.softmax(logits)

if legal_actions is not None:
- agent_legals = tf.expand_dims(legal_actions[agent], axis=0)
+ agent_legals = tf.cast(tf.expand_dims(legal_actions[agent], axis=0), "float32")
probs = (probs * agent_legals) / tf.reduce_sum(
probs * agent_legals
) # mask and renorm

action = tfp.distributions.Categorical(probs=probs).sample(1)

# Store agent action
- actions[agent] = action
+ actions[agent] = action[0]

return actions, next_rnn_states

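
Note (not part of the commit): the new tf.cast matters because the legal-action mask presumably arrives as an integer or boolean tensor while the softmax probabilities are float32, and TensorFlow will not multiply mismatched dtypes implicitly. A tiny sketch of the mask-and-renormalise step with made-up values:

import tensorflow as tf

probs = tf.constant([[0.2, 0.5, 0.3]])  # softmax output for a batch of one agent
legals = tf.constant([[1, 0, 1]])       # action 1 is illegal
legals = tf.cast(legals, "float32")     # cast so it can multiply float probs

masked = (probs * legals) / tf.reduce_sum(probs * legals)  # mask and renorm
print(masked.numpy())  # [[0.4 0.  0.6]]
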
4 changes: 2 additions & 2 deletions og_marl/tf2/systems/idrqn_bcq.py
@@ -185,8 +185,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
# Get trainable variables
variables = (
*self._q_network.trainable_variables,
- *self._q_embedding_network.trainable_variables
- * self._behaviour_cloning_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
+ *self._behaviour_cloning_network.trainable_variables,
*self._bc_embedding_network.trainable_variables,
)

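
Note (not part of the commit): the fix here is just a trailing comma. Without it, Python parses the two adjacent lines as one starred multiplication, *(A * B), which fails at runtime because two lists of variables cannot be multiplied together. A minimal reproduction with plain lists; the same missing comma is fixed in og_marl/tf2/systems/qmix_bcq.py below.

a, b = [1, 2], [3, 4]

ok = (*a, *b)   # (1, 2, 3, 4): two separately unpacked entries
try:
    bad = (*a       # no comma: parsed as *(a * b)
           * b,)
except TypeError as err:
    print(err)      # can't multiply sequence by non-int of type 'list'
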
4 changes: 2 additions & 2 deletions og_marl/tf2/systems/qmix_bcq.py
@@ -207,8 +207,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
*self._q_network.trainable_variables,
*self._q_embedding_network.trainable_variables,
*self._mixer.trainable_variables,
- *self._state_embedding_network.trainable_variables
- * self._behaviour_cloning_network.trainable_variables,
+ *self._state_embedding_network.trainable_variables,
+ *self._behaviour_cloning_network.trainable_variables,
)

# Compute gradients.
