diff --git a/docs/img/image.png b/docs/img/two_route_net.png
similarity index 100%
rename from docs/img/image.png
rename to docs/img/two_route_net.png
diff --git a/docs/img/vdn.png b/docs/img/vdn.png
new file mode 100644
index 000000000..20b1ea8b5
Binary files /dev/null and b/docs/img/vdn.png differ
diff --git a/tutorials/MarlAlgorithms/qmix_mutation.ipynb b/tutorials/MarlAlgorithms/qmix_mutation.ipynb
index 0054a5b3d..cd8477660 100644
--- a/tutorials/MarlAlgorithms/qmix_mutation.ipynb
+++ b/tutorials/MarlAlgorithms/qmix_mutation.ipynb
@@ -75,18 +75,15 @@
"import torch\n",
"\n",
"from tensordict.nn import TensorDictModule, TensorDictSequential\n",
- "from torch.distributions import Categorical\n",
"from torchrl.envs.libs.pettingzoo import PettingZooWrapper\n",
"from torchrl.envs.transforms import TransformedEnv, RewardSum\n",
"from torchrl.envs.utils import check_env_specs\n",
- "from torchrl.modules import MultiAgentMLP\n",
"from torch import nn\n",
"from torchrl._utils import logger as torchrl_logger\n",
"from torchrl.collectors import SyncDataCollector\n",
"from torchrl.data import TensorDictReplayBuffer\n",
"from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement\n",
"from torchrl.data.replay_buffers.storages import LazyTensorStorage\n",
- "from torchrl.envs import RewardSum, TransformedEnv\n",
"from torchrl.modules import EGreedyModule, QValueModule, SafeSequential\n",
"from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer\n",
"from torchrl.objectives import SoftUpdate, ValueEstimators\n",
diff --git a/tutorials/MarlAlgorithms/vdn_mutation.ipynb b/tutorials/MarlAlgorithms/vdn_mutation.ipynb
index 0005397c5..1a975675b 100644
--- a/tutorials/MarlAlgorithms/vdn_mutation.ipynb
+++ b/tutorials/MarlAlgorithms/vdn_mutation.ipynb
@@ -1,50 +1,98 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# VDN algorithm implementation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> In this notebook, we implement a state-of-the-art Multi Agent Reinforcement Leaning (MARL) algorithm **[VDN](https://arxiv.org/abs/1706.05296)** in our environment. **VDN** is a deep algorithm for cooperative MARL, particularly suited for situations where agents receive a single, shared reward. Value-decomposition networks are a step towards automatically decomposing complex learning problems into local, more readile learnable sub-problems.\n",
+ "\n",
+ "\n",
+ "> Tutorial based on [VDN TorchRL Tutorial](https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/qmix_vdn.py)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "> Picture taken from VDN [paper](https://arxiv.org/pdf/1706.05296)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### High-level overview of VDN algorithm\n",
+ "\n",
+ "The joint action-value function for the system can be additively decomposed into value functions accross agents:\n",
+ "\n",
+ "$$\n",
+ "Q((h^{1}, h^{2}, \\ldots, h^{d}), (a^{1}, a^{2}, \\ldots, a^{d})) \\approx \\sum_{i=1}^{d} \\tilde{Q}_i(h^{i}, a^{i}),\n",
+ "$$\n",
+ "\n",
+ "\n",
+ "where the $\\tilde{Q}_i$ depends only on each agent's local observations.\n",
+ "\n",
+ "**Value-Decomposition** outperforms both centralized and fully independent learning approaches. When combined with additional techniques, it consistently yields agents that significantly surpass their centralized and independent counterparts.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Simulation overview"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> We simulate our environment with an initial population of **20 human agents**. These agents navigate the environment and eventually converge on the fastest path. After this convergence, we will transition **10 of these human agents** into **machine agents**, specifically autonomous vehicles (AVs), which will then employ the QMIX reinforcement learning algorithms to further refine their learning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Imported libraries"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
- "import sys\n",
- "import os\n",
- "from tqdm import tqdm\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
"import os\n",
- "import pandas as pd\n",
- "from tensordict.nn import TensorDictModule, TensorDictSequential\n",
- "from tensordict.nn.distributions import NormalParamExtractor\n",
+ "import sys\n",
+ "import time\n",
"import torch\n",
+ "from torch import nn\n",
+ "\n",
+ "\n",
+ "from tensordict.nn import TensorDictModule, TensorDictSequential\n",
"from torchrl.collectors import SyncDataCollector\n",
- "from torch.distributions import Categorical\n",
"from torchrl.envs.libs.pettingzoo import PettingZooWrapper\n",
- "from torchrl.envs.transforms import TransformedEnv, RewardSum\n",
"from torchrl.envs.utils import check_env_specs\n",
- "from torchrl.data.replay_buffers import ReplayBuffer\n",
"from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement\n",
- "from torchrl.data.replay_buffers.storages import LazyTensorStorage\n",
- "from torchrl.modules import MultiAgentMLP, ProbabilisticActor\n",
- "from torchrl.objectives.value import GAE\n",
- "from torchrl.objectives import ClipPPOLoss, ValueEstimators\n",
- "from tqdm import tqdm\n",
- "from tensordict.nn import TensorDictModule, TensorDictSequential\n",
- "from torch import nn\n",
"from torchrl._utils import logger as torchrl_logger\n",
- "from torchrl.collectors import SyncDataCollector\n",
"from torchrl.data import TensorDictReplayBuffer\n",
- "from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement\n",
"from torchrl.data.replay_buffers.storages import LazyTensorStorage\n",
"from torchrl.envs import RewardSum, TransformedEnv\n",
- "from torchrl.envs.libs.vmas import VmasEnv\n",
- "from torchrl.envs.utils import ExplorationType, set_exploration_type\n",
"from torchrl.modules import EGreedyModule, QValueModule, SafeSequential\n",
- "from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer\n",
+ "from torchrl.modules.models.multiagent import MultiAgentMLP, VDNMixer\n",
"from torchrl.objectives import SoftUpdate, ValueEstimators\n",
"from torchrl.objectives.multiagent.qmixer import QMixerLoss\n",
- "import sys\n",
- "import os\n",
- "import json\n",
+ "\n",
"\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))\n",
"\n",
@@ -56,6 +104,13 @@
"os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Hyperparameters setting"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 2,
@@ -104,7 +159,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Environment creation"
+ "#### Environment initialization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> In this example, the environment initially contains only human agents."
]
},
{
@@ -134,8 +196,49 @@
"source": [
"params = get_params(\"params.json\")\n",
"\n",
- "env = TrafficEnvironment(params[kc.RUNNER], params[kc.ENVIRONMENT], params[kc.SIMULATOR], params[kc.AGENT_GEN], params[kc.AGENTS], params[kc.PLOTTER])\n",
- "\n",
+ "env = TrafficEnvironment(params[kc.RUNNER], params[kc.ENVIRONMENT], params[kc.SIMULATOR], params[kc.AGENT_GEN], params[kc.AGENTS], params[kc.PLOTTER])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of total agents is: 20 \n",
+ "\n",
+ "Agents are: [Human 0, Human 1, Human 2, Human 3, Human 4, Human 5, Human 6, Human 7, Human 8, Human 9, Human 10, Human 11, Human 12, Human 13, Human 14, Human 15, Human 16, Human 17, Human 18, Human 19] \n",
+ "\n",
+ "Number of human agents is: 20 \n",
+ "\n",
+ "Number of machine agents (autonomous vehicles) is: 0 \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Number of total agents is: \", len(env.all_agents), \"\\n\")\n",
+ "print(\"Agents are: \", env.all_agents, \"\\n\")\n",
+ "print(\"Number of human agents is: \", len(env.human_agents), \"\\n\")\n",
+ "print(\"Number of machine agents (autonomous vehicles) is: \", len(env.machine_agents), \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Reset the environment and the connection with SUMO"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"env.start()\n",
"env.reset()"
]
@@ -166,6 +269,13 @@
"#### Mutation"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> **Mutation**: a portion of human agents are converted into machine agents (autonomous vehicles). You can adjust the number of agents to be mutated in the /params.json
file."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 7,
@@ -177,7 +287,43 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of total agents is: 20 \n",
+ "\n",
+ "Agents are: [Machine 19, Machine 5, Machine 17, Machine 9, Machine 8, Machine 0, Machine 14, Machine 11, Machine 10, Machine 18, Human 1, Human 2, Human 3, Human 4, Human 6, Human 7, Human 12, Human 13, Human 15, Human 16] \n",
+ "\n",
+ "Number of human agents is: 10 \n",
+ "\n",
+ "Number of machine agents (autonomous vehicles) is: 10 \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Number of total agents is: \", len(env.all_agents), \"\\n\")\n",
+ "print(\"Agents are: \", env.all_agents, \"\\n\")\n",
+ "print(\"Number of human agents is: \", len(env.human_agents), \"\\n\")\n",
+ "print(\"Number of machine agents (autonomous vehicles) is: \", len(env.machine_agents), \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Create a group that contains all the machine (RL) agents.\n",
+ "\n",
+ "> **Hint:** the agents aren't competely independent in this example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -185,8 +331,22 @@
"for machines in env.machine_agents:\n",
" machine_list.append(str(machines.id))\n",
" \n",
- "group = {'agents': machine_list}\n",
- " \n",
+ "group = {'agents': machine_list}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### PettingZoo environment wrapper"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"env = PettingZooWrapper(\n",
" env=env,\n",
" use_mask=True,\n",
@@ -198,28 +358,25 @@
]
},
{
- "cell_type": "code",
- "execution_count": 9,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'agents': ['15', '4', '5', '18', '14', '7', '2', '1', '19', '11']}"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
"source": [
- "group"
+ "> The environment is defined by a series of metadata that describe what can be expected during its execution. \n",
+ "\n",
+ "There are four specs to look at:\n",
+ "\n",
+ "- action_spec
defines the action space;\n",
+ "\n",
+ "- reward_spec
defines the reward domain;\n",
+ "\n",
+ "- done_spec
defines the done domain;\n",
+ "\n",
+ "- observation_spec
which defines the domain of all other outputs from environment steps;"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -296,16 +453,63 @@
" device=cpu,\n",
" dtype=torch.bool,\n",
" domain=discrete), device=cpu, shape=torch.Size([10])), device=cpu, shape=torch.Size([]))\n",
- "env.group is: {'agents': ['15', '4', '5', '18', '14', '7', '2', '1', '19', '11']}\n"
+ "env.group is: {'agents': ['9', '18', '4', '19', '14', '15', '2', '10', '1', '0']}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"action_spec:\", env.full_action_spec, \"\\n\\n\")\n",
+ "print(\"reward_spec:\", env.full_reward_spec, \"\\n\\n\")\n",
+ "print(\"done_spec:\", env.full_done_spec, \"\\n\\n\")\n",
+ "print(\"observation_spec:\", env.observation_spec, \"\\n\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Agent group mapping"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "env.group is: {'agents': ['19', '5', '17', '9', '8', '0', '14', '11', '10', '18']} \n",
+ "\n",
+ "\n"
]
}
],
"source": [
- "print(\"action_spec:\", env.full_action_spec)\n",
- "print(\"reward_spec:\", env.full_reward_spec)\n",
- "print(\"done_spec:\", env.full_done_spec)\n",
- "print(\"observation_spec:\", env.observation_spec)\n",
- "print(\"env.group is: \", env.group_map)\n"
+ "print(\"env.group is: \", env.group_map, \"\\n\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Transforms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> We can append any TorchRL transform we need to our environment. These will modify its input/output in some desired way. In multi-agent contexts, it is paramount to provide explicitly the keys to modify.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we instatiate a RewardSum
transformer that will sum rewards over episode."
]
},
{
@@ -320,6 +524,13 @@
")"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The check_env_specs()
function runs a small rollout and compared it output against the environment specs. It will raise an error if the specs aren't properly defined."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 12,
@@ -353,6 +564,13 @@
"#### Policy network"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Instantiate an `MPL` that can be used in multi-agent contexts."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 14,
@@ -373,65 +591,36 @@
]
},
{
- "cell_type": "code",
- "execution_count": 15,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "10"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
"source": [
- "env.n_agents"
+ "> The neural network is wrapped in a `TensorDictModule`, which is responsible for managing the input and output interactions with the tensordict. Specifically, the module reads from the specified `in_keys`, processes the inputs through the neural network, and writes the resulting outputs to the defined `out_keys`. "
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "module = TensorDictModule(\n",
+ " net, in_keys=[(\"agents\", \"observation\")], out_keys=[(\"agents\", \"action_value\")]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "MultiAgentMLP(\n",
- " (agent_networks): ModuleList(\n",
- " (0): MLP(\n",
- " (0): Linear(in_features=2, out_features=256, bias=True)\n",
- " (1): Tanh()\n",
- " (2): Linear(in_features=256, out_features=256, bias=True)\n",
- " (3): Tanh()\n",
- " (4): Linear(in_features=256, out_features=2, bias=True)\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
"source": [
- "net"
+ "> **`QValueModule`** takes a tensor as input, which contains the `Q-values` (these values indicate how good it is to take each action in the given state). It identifies the action with the highest `Q-values` using the `argmax` operation."
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "module = TensorDictModule(\n",
- " net, in_keys=[(\"agents\", \"observation\")], out_keys=[(\"agents\", \"action_value\")]\n",
- " )\n",
- "\n",
"value_module = QValueModule(\n",
" action_value_key=(\"agents\", \"action_value\"),\n",
" out_keys=[\n",
@@ -441,11 +630,32 @@
" ],\n",
" spec=env.action_spec,\n",
" action_space=None,\n",
- ")\n",
- "\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> **`SafeSequential`** is a `TensordictModule` that will concatenate the parameter lists in a single list."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"qnet = SafeSequential(module, value_module)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> In the already made `Q network` the **`Epsilon-Greedy exploration module`** is added. This module randomly updates the actions in a tensordict given an epsilon greedy exploration strategy."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 18,
@@ -453,15 +663,24 @@
"outputs": [],
"source": [
"qnet_explore = TensorDictSequential(\n",
- " qnet,\n",
- " EGreedyModule(\n",
- " eps_init=0.3,\n",
- " eps_end=0,\n",
- " annealing_num_steps=int(total_frames * (1 / 2)),\n",
- " action_key=env.action_key,\n",
- " spec=env.action_spec,\n",
- " ),\n",
- " )"
+ " qnet,\n",
+ " EGreedyModule(\n",
+ " eps_init=0.3,\n",
+ " eps_end=0,\n",
+ " annealing_num_steps=int(total_frames * (1 / 2)),\n",
+ " action_key=env.action_key,\n",
+ " spec=env.action_spec,\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Mixer\n",
+ "\n",
+ "> `VDNMixer` mixes **the local Q values** of the agents into **a global Q value** by summing them together, accorbing to [VDN paper](https://arxiv.org/pdf/1706.05296)."
]
},
{
@@ -471,57 +690,33 @@
"outputs": [],
"source": [
"mixer = TensorDictModule(\n",
- " module=QMixer(\n",
- " state_shape=env.observation_spec[\n",
- " \"agents\", \"observation\"\n",
- " ].shape,\n",
- " mixing_embed_dim=32,\n",
- " n_agents=env.n_agents,\n",
- " device=device,\n",
- " ),\n",
- " in_keys=[(\"agents\", \"chosen_action_value\"), (\"agents\", \"observation\")],\n",
- " out_keys=[\"chosen_action_value\"],\n",
- " )"
+ " module=VDNMixer(\n",
+ " n_agents=env.n_agents,\n",
+ " device=device,\n",
+ " ),\n",
+ " in_keys=[(\"agents\", \"chosen_action_value\")],\n",
+ " out_keys=[\"chosen_action_value\"],\n",
+ ")"
]
},
{
- "cell_type": "code",
- "execution_count": 21,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "TensorDictModule(\n",
- " module=QMixer(\n",
- " (hyper_w_1): Linear(in_features=20, out_features=320, bias=True)\n",
- " (hyper_w_final): Linear(in_features=20, out_features=32, bias=True)\n",
- " (hyper_b_1): Linear(in_features=20, out_features=32, bias=True)\n",
- " (V): Sequential(\n",
- " (0): Linear(in_features=20, out_features=32, bias=True)\n",
- " (1): ReLU()\n",
- " (2): Linear(in_features=32, out_features=1, bias=True)\n",
- " )\n",
- " ),\n",
- " device=cpu,\n",
- " in_keys=[('agents', 'chosen_action_value'), ('agents', 'observation')],\n",
- " out_keys=['chosen_action_value'])"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
"source": [
- "mixer"
+ "#### Collector"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### Collector"
+ "Collectors perform the following operations:\n",
+ "\n",
+ "1. **Reset Environment**: Initialize the environment.\n",
+ "2. **Compute Action**: Determine the next action using the policy and the latest observation.\n",
+ "3. **Execute Step**: Step through the environment with the computed action.\n",
+ "\n",
+ "These operations repeat until the environment signals to stop."
]
},
{
@@ -547,6 +742,14 @@
"#### Replay buffer"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> In an off-policy setting, the replay buffer exceeds the number of frames utilized for policy updates, allowing agents to learn from previous rollouts as well.\n",
+ "\n"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 23,
@@ -567,6 +770,13 @@
"#### Qmix loss function"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> `QMixerLoss` mixes *local agent q values* into *a global q value* according to a mixing network and then uses DQN updated on the global value."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 24,
@@ -574,12 +784,14 @@
"outputs": [],
"source": [
"loss_module = QMixerLoss(qnet, mixer, delay_value=True)\n",
+ "\n",
"loss_module.set_keys(\n",
" action_value=(\"agents\", \"action_value\"),\n",
" local_value=(\"agents\", \"chosen_action_value\"),\n",
" global_value=\"chosen_action_value\",\n",
" action=env.action_key,\n",
")\n",
+ "\n",
"loss_module.make_value_estimator(ValueEstimators.TD0, gamma=gamma)\n",
"target_net_updater = SoftUpdate(loss_module, eps=1 - tau)\n",
"\n",
@@ -623,17 +835,16 @@
}
],
"source": [
- "import time\n",
- "\n",
"total_time = 0\n",
"total_frames = 0\n",
"sampling_start = time.time()\n",
+ "\n",
"for i, tensordict_data in enumerate(collector):\n",
" torchrl_logger.info(f\"\\nIteration {i}\")\n",
"\n",
" sampling_time = time.time() - sampling_start\n",
"\n",
- " # Remove agent dimension from reward (since it is shared in QMIX/VDN)\n",
+ " ## Generate the rollouts\n",
" tensordict_data.set(\n",
" (\"next\", \"reward\"), tensordict_data.get((\"next\", env.reward_key)).mean(-2)\n",
" )\n",
@@ -644,13 +855,17 @@
" )\n",
" del tensordict_data[\"next\", \"agents\", \"episode_reward\"]\n",
"\n",
+ "\n",
" current_frames = tensordict_data.numel()\n",
" total_frames += current_frames\n",
" data_view = tensordict_data.reshape(-1)\n",
" replay_buffer.extend(data_view)\n",
"\n",
+ "\n",
" training_tds = []\n",
" training_start = time.time()\n",
+ " \n",
+ " ## Update the policies of the learning agents\n",
" for _ in range(num_epochs):\n",
" for _ in range(frames_per_batch // minibatch_size):\n",
" subdata = replay_buffer.sample()\n",
@@ -680,6 +895,13 @@
" training_tds = torch.stack(training_tds) "
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Check `\\plots` directory to find the plots created from this experiment."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 21,
diff --git a/tutorials/PettingZooEnv/README.md b/tutorials/PettingZooEnv/README.md
index 4ed245ceb..e9492bb87 100644
--- a/tutorials/PettingZooEnv/README.md
+++ b/tutorials/PettingZooEnv/README.md
@@ -4,7 +4,7 @@
In these notebooks, we use a two-route network in our simulator [SUMO](https://eclipse.dev/sumo/), where agents (vehicles) navigate from their predefined origin to their predefined destination point, aiming to determine the fastest route.
-
+
---