diff --git a/tutorials/MarlAlgorithms/sac_mutation.ipynb b/tutorials/MarlAlgorithms/sac_mutation.ipynb
new file mode 100644
index 000000000..dcef4ac89
--- /dev/null
+++ b/tutorials/MarlAlgorithms/sac_mutation.ipynb
@@ -0,0 +1,883 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# SAC algorithm implementation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> In this notebook, we implement a state-of-the-art Multi Agent Reinforcement Leaning (MARL) algorithms **[SAC](https://arxiv.org/pdf/1801.01290)** in our environment. SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy RL framework. \n",
+ "\n",
+ "\n",
+ "> Tutorial based on [SAC TorchRL Tutorial](https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/sac.py)."
+ ]
+ },
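+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> As a brief reminder (notation taken from the SAC paper, not from code in this notebook): the maximum entropy objective augments the expected return with an entropy bonus weighted by a temperature $\\alpha$,\n",
+ "\n",
+ "$$ J(\\pi) = \\sum_{t} \\mathbb{E}_{(s_t, a_t) \\sim \\rho_\\pi} \\big[ r(s_t, a_t) + \\alpha \\, \\mathcal{H}\\big(\\pi(\\cdot \\mid s_t)\\big) \\big]. $$\n",
+ "\n",
+ "> The entropy term encourages exploration; in the discrete-action setting used below, the temperature $\\alpha$ is learned by the loss module."
+ ]
+ },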
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Simulation overview"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> We simulate our environment with an initial population of **20 human agents**. These agents navigate the environment and eventually converge on the fastest path. After this convergence, we will transition **10 of these human agents** into **machine agents**, specifically autonomous vehicles (AVs), which will then employ the QMIX reinforcement learning algorithms to further refine their learning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Imported libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "import matplotlib.pyplot as plt\n",
+ "import os\n",
+ "import time\n",
+ "import torch\n",
+ "\n",
+ "from tensordict.nn import TensorDictModule\n",
+ "from tensordict.nn.distributions import NormalParamExtractor\n",
+ "from torch import nn\n",
+ "from torchrl.envs.libs.pettingzoo import PettingZooWrapper\n",
+ "from torch.distributions import Categorical, OneHotCategorical\n",
+ "from torchrl._utils import logger as torchrl_logger\n",
+ "from torchrl.collectors import SyncDataCollector\n",
+ "from torchrl.data import TensorDictReplayBuffer\n",
+ "from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement\n",
+ "from torchrl.data.replay_buffers.storages import LazyTensorStorage\n",
+ "from torchrl.envs import RewardSum, TransformedEnv\n",
+ "from torchrl.envs.libs.vmas import VmasEnv\n",
+ "from torchrl.envs.utils import ExplorationType, set_exploration_type\n",
+ "from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator\n",
+ "from torchrl.modules.models.multiagent import MultiAgentMLP\n",
+ "from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators\n",
+ "\n",
+ "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))\n",
+ "\n",
+ "from RouteRL.keychain import Keychain as kc\n",
+ "from RouteRL.environment.environment import TrafficEnvironment\n",
+ "from RouteRL.services.plotter import Plotter\n",
+ "from RouteRL.utilities import get_params\n",
+ "\n",
+ "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Hyperparameters setting"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params = get_params(\"params.json\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "device is: cpu\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Devices\n",
+ "device = (\n",
+ " torch.device(0)\n",
+ " if torch.cuda.is_available()\n",
+ " else torch.device(\"cpu\")\n",
+ ")\n",
+ "\n",
+ "machine_agents = params[\"agent_generation_parameters\"][\"new_machines_after_mutation\"]\n",
+ "\n",
+ "\n",
+ "print(\"device is: \", device)\n",
+ "vmas_device = device # The device where the simulator is run\n",
+ "\n",
+ "# Sampling\n",
+ "frames_per_batch = 4 * machine_agents # Number of team frames collected per training iteration\n",
+ "n_iters = 10 # Number of sampling and training iterations - the episodes the plotter plots\n",
+ "total_frames = frames_per_batch * n_iters\n",
+ "\n",
+ "# Training\n",
+ "num_epochs = 10 # Number of optimization steps per training iteration\n",
+ "minibatch_size = 2 # Size of the mini-batches in each optimization step\n",
+ "lr = 3e-4 # Learning rate\n",
+ "max_grad_norm = 1.0 # Maximum norm for the gradients\n",
+ "memory_size = 8 * machine_agents # Size of the replay buffer\n",
+ "tau = 0.005\n",
+ "gamma = 0.99 # discount factor"
+ ]
+ },
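+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> To make these numbers concrete for this run (10 machine agents after mutation): each iteration collects `frames_per_batch = 4 * 10 = 40` team frames, the replay buffer holds at most `8 * 10 = 80` frames, and every training iteration performs `num_epochs * (frames_per_batch // minibatch_size) = 10 * 20 = 200` mini-batch updates."
+ ]
+ },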
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Environment initialization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> In this example, the environment initially contains only human agents."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[CONFIRMED] Environment variable exists: SUMO_HOME\n",
+ "[SUCCESS] Added module directory: C:\\Program Files (x86)\\Eclipse\\Sumo\\tools\n"
+ ]
+ }
+ ],
+ "source": [
+ "env = TrafficEnvironment(params[kc.RUNNER], params[kc.ENVIRONMENT], params[kc.SIMULATOR], params[kc.AGENT_GEN], params[kc.AGENTS], params[kc.PLOTTER])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of total agents is: 20 \n",
+ "\n",
+ "Agents are: [Human 0, Human 1, Human 2, Human 3, Human 4, Human 5, Human 6, Human 7, Human 8, Human 9, Human 10, Human 11, Human 12, Human 13, Human 14, Human 15, Human 16, Human 17, Human 18, Human 19] \n",
+ "\n",
+ "Number of human agents is: 20 \n",
+ "\n",
+ "Number of machine agents (autonomous vehicles) is: 0 \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Number of total agents is: \", len(env.all_agents), \"\\n\")\n",
+ "print(\"Agents are: \", env.all_agents, \"\\n\")\n",
+ "print(\"Number of human agents is: \", len(env.human_agents), \"\\n\")\n",
+ "print(\"Number of machine agents (autonomous vehicles) is: \", len(env.machine_agents), \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Reset the environment and the connection with SUMO"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "({}, {})"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.start()\n",
+ "env.reset()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Human learning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "num_episodes = 100\n",
+ "\n",
+ "for episode in range(num_episodes):\n",
+ " env.step()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Mutation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> **Mutation**: a portion of human agents are converted into machine agents (autonomous vehicles). You can adjust the number of agents to be mutated in the /params.json
file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env.mutation()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of total agents is: 20 \n",
+ "\n",
+ "Agents are: [Machine 14, Machine 5, Machine 8, Machine 7, Machine 11, Machine 13, Machine 16, Machine 17, Machine 15, Machine 19, Human 0, Human 1, Human 2, Human 3, Human 4, Human 6, Human 9, Human 10, Human 12, Human 18] \n",
+ "\n",
+ "Number of human agents is: 10 \n",
+ "\n",
+ "Number of machine agents (autonomous vehicles) is: 10 \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Number of total agents is: \", len(env.all_agents), \"\\n\")\n",
+ "print(\"Agents are: \", env.all_agents, \"\\n\")\n",
+ "print(\"Number of human agents is: \", len(env.human_agents), \"\\n\")\n",
+ "print(\"Number of machine agents (autonomous vehicles) is: \", len(env.machine_agents), \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Create a group that contains all the machine (RL) agents.\n",
+ "\n",
+ "> **Hint:** As a feature of TorchRL multiagent, we are able to control the grouping of agents. We can group agents together (stacking their tensors) to leverage vectorization when passing them through the same neural network. We can split agents in different groups where they are heterogenous or should be processed by different neural netowkrs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "machine_list = []\n",
+ "for machines in env.machine_agents:\n",
+ " machine_list.append(str(machines.id))\n",
+ " \n",
+ "group = {'agents': machine_list}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### PettingZoo environment wrapper"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env = PettingZooWrapper(\n",
+ " env=env,\n",
+ " use_mask=True, # Whether to use the mask in the outputs. It is important for AEC environments to mask out non-acting agents.\n",
+ " categorical_actions=True,\n",
+ " done_on_any = False, # Whether the environment’s done keys are set by aggregating the agent keys using any() (when True) or all() (when False).\n",
+ " group_map=group,\n",
+ " device=device\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> The environment is defined by a series of metadata that describe what can be expected during its execution. \n",
+ "\n",
+ "There are four specs to look at:\n",
+ "\n",
+ "- action_spec
defines the action space;\n",
+ "\n",
+ "- reward_spec
defines the reward domain;\n",
+ "\n",
+ "- done_spec
defines the done domain;\n",
+ "\n",
+ "- observation_spec
which defines the domain of all other outputs from environment steps;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "action_spec: CompositeSpec(\n",
+ " agents: CompositeSpec(\n",
+ " action: DiscreteTensorSpec(\n",
+ " shape=torch.Size([10]),\n",
+ " space=DiscreteBox(n=3),\n",
+ " device=cpu,\n",
+ " dtype=torch.int64,\n",
+ " domain=discrete), device=cpu, shape=torch.Size([10])), device=cpu, shape=torch.Size([])) \n",
+ "\n",
+ "\n",
+ "reward_spec: CompositeSpec(\n",
+ " agents: CompositeSpec(\n",
+ " reward: UnboundedContinuousTensorSpec(\n",
+ " shape=torch.Size([10, 1]),\n",
+ " space=None,\n",
+ " device=cpu,\n",
+ " dtype=torch.float32,\n",
+ " domain=continuous), device=cpu, shape=torch.Size([10])), device=cpu, shape=torch.Size([])) \n",
+ "\n",
+ "\n",
+ "done_spec: CompositeSpec(\n",
+ " done: DiscreteTensorSpec(\n",
+ " shape=torch.Size([1]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete),\n",
+ " terminated: DiscreteTensorSpec(\n",
+ " shape=torch.Size([1]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete),\n",
+ " truncated: DiscreteTensorSpec(\n",
+ " shape=torch.Size([1]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete),\n",
+ " agents: CompositeSpec(\n",
+ " done: DiscreteTensorSpec(\n",
+ " shape=torch.Size([10, 1]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete),\n",
+ " terminated: DiscreteTensorSpec(\n",
+ " shape=torch.Size([10, 1]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete),\n",
+ " truncated: DiscreteTensorSpec(\n",
+ " shape=torch.Size([10, 1]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete), device=cpu, shape=torch.Size([10])), device=cpu, shape=torch.Size([])) \n",
+ "\n",
+ "\n",
+ "observation_spec: CompositeSpec(\n",
+ " agents: CompositeSpec(\n",
+ " observation: BoundedTensorSpec(\n",
+ " shape=torch.Size([10, 3]),\n",
+ " space=ContinuousBox(\n",
+ " low=Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, contiguous=True),\n",
+ " high=Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, contiguous=True)),\n",
+ " device=cpu,\n",
+ " dtype=torch.float32,\n",
+ " domain=continuous),\n",
+ " mask: DiscreteTensorSpec(\n",
+ " shape=torch.Size([10]),\n",
+ " space=DiscreteBox(n=2),\n",
+ " device=cpu,\n",
+ " dtype=torch.bool,\n",
+ " domain=discrete), device=cpu, shape=torch.Size([10])), device=cpu, shape=torch.Size([])) \n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"action_spec:\", env.full_action_spec, \"\\n\\n\")\n",
+ "print(\"reward_spec:\", env.full_reward_spec, \"\\n\\n\")\n",
+ "print(\"done_spec:\", env.full_done_spec, \"\\n\\n\")\n",
+ "print(\"observation_spec:\", env.observation_spec, \"\\n\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Agent group mapping"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "env.group is: {'agents': ['14', '5', '8', '7', '11', '13', '16', '17', '15', '19']} \n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"env.group is: \", env.group_map, \"\\n\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Transforms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> We can append any TorchRL transform we need to our environment. These will modify its input/output in some desired way. In multi-agent contexts, it is paramount to provide explicitly the keys to modify.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we instatiate a RewardSum
transformer that will sum rewards over episode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env = TransformedEnv(\n",
+ " env,\n",
+ " RewardSum(in_keys=[env.reward_key], out_keys=[(\"agents\", \"episode_reward\")]),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The check_env_specs()
function runs a small rollout and compared it output against the environment specs. It will raise an error if the specs aren't properly defined."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reset_td = env.reset()"
+ ]
+ },
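+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> If you want to run the spec check mentioned above yourself, a minimal sketch is shown below (it is not executed in this notebook, and the short rollout it performs will step the underlying SUMO simulation):\n",
+ "\n",
+ "```python\n",
+ "from torchrl.envs.utils import check_env_specs\n",
+ "\n",
+ "check_env_specs(env)  # raises an error if a short rollout does not match the declared specs\n",
+ "```"
+ ]
+ },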
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Policy network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Instantiate an `MPL` that can be used in multi-agent contexts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "share_parameters_policy = False \n",
+ "\n",
+ "actor_net = torch.nn.Sequential(\n",
+ " MultiAgentMLP(\n",
+ " n_agent_inputs = env.observation_spec[\"agents\", \"observation\"].shape[-1],\n",
+ " n_agent_outputs = env.action_spec.space.n,\n",
+ " n_agents = env.n_agents,\n",
+ " centralised=False,\n",
+ " share_params=share_parameters_policy,\n",
+ " device=device,\n",
+ " depth=3,\n",
+ " num_cells=64,\n",
+ " activation_class=torch.nn.Tanh,\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "policy_module = TensorDictModule(\n",
+ " actor_net,\n",
+ " in_keys=[(\"agents\", \"observation\")],\n",
+ " out_keys=[(\"agents\", \"logits\")],\n",
+ ") "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "policy = ProbabilisticActor(\n",
+ " module=policy_module,\n",
+ " spec=env.action_spec,\n",
+ " in_keys=[(\"agents\", \"logits\")],\n",
+ " out_keys=[env.action_key],\n",
+ " distribution_class=Categorical,\n",
+ " return_log_prob=True,\n",
+ " log_prob_key=(\"agents\", \"sample_log_prob\"),\n",
+ ")"
+ ]
+ },
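+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> As a quick, optional sanity check (a sketch, assuming the `reset_td` obtained from `env.reset()` above), calling the actor on a tensordict adds the sampled actions and their log-probabilities:\n",
+ "\n",
+ "```python\n",
+ "with torch.no_grad():\n",
+ "    td = policy(reset_td.clone())\n",
+ "\n",
+ "print(td[\"agents\", \"action\"])           # one discrete action index per machine agent\n",
+ "print(td[\"agents\", \"sample_log_prob\"])  # log-probability of each sampled action\n",
+ "```"
+ ]
+ },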
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Critic network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> The critic reads the observations and returns the corresponding value estimates."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "centralised_critic = True\n",
+ "shared_parameters = True\n",
+ "\n",
+ "module = MultiAgentMLP(\n",
+ " n_agent_inputs=env.observation_spec[\"agents\", \"observation\"].shape[-1],\n",
+ " n_agent_outputs=env.action_spec.space.n,\n",
+ " n_agents=env.n_agents,\n",
+ " centralised=centralised_critic,\n",
+ " share_params=shared_parameters,\n",
+ " device=device,\n",
+ " depth=2,\n",
+ " num_cells=256,\n",
+ " activation_class=nn.Tanh,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "value_module = ValueOperator(\n",
+ " module=module,\n",
+ " in_keys=[(\"agents\", \"observation\")],\n",
+ " out_keys=[(\"agents\", \"action_value\")],\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Collector"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Collectors perform the following operations:\n",
+ "\n",
+ "1. **Reset Environment**: Initialize the environment.\n",
+ "2. **Compute Action**: Determine the next action using the policy and the latest observation.\n",
+ "3. **Execute Step**: Step through the environment with the computed action.\n",
+ "\n",
+ "These operations repeat until the environment signals to stop."
+ ]
+ },
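+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Conceptually (a sketch only, not part of the training pipeline), iterating over the collector is roughly equivalent to repeatedly rolling the policy out in the environment; `SyncDataCollector` additionally handles batching the collected frames, device placement and keeping the policy weights in sync:\n",
+ "\n",
+ "```python\n",
+ "with torch.no_grad():\n",
+ "    pseudo_batch = env.rollout(max_steps=frames_per_batch, policy=policy)\n",
+ "# pseudo_batch is a TensorDict holding observations, actions, rewards and done flags per step\n",
+ "```"
+ ]
+ },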
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "collector = SyncDataCollector(\n",
+ " env,\n",
+ " policy,\n",
+ " device=device,\n",
+ " storing_device=device,\n",
+ " frames_per_batch=frames_per_batch,\n",
+ " total_frames=total_frames,\n",
+ ") "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Replay buffer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "replay_buffer = TensorDictReplayBuffer(\n",
+ " storage=LazyTensorStorage(memory_size, device=device),\n",
+ " sampler=SamplerWithoutReplacement(),\n",
+ " batch_size=minibatch_size,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### SAC loss function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_module = DiscreteSACLoss(\n",
+ " actor_network=policy,\n",
+ " qvalue_network=value_module,\n",
+ " delay_qvalue=True,\n",
+ " num_actions=env.action_spec.space.n,\n",
+ " action_space=env.action_spec,\n",
+ " )\n",
+ "\n",
+ "loss_module.set_keys(\n",
+ " action_value=(\"agents\", \"action_value\"),\n",
+ " action=env.action_key,\n",
+ " reward=env.reward_key,\n",
+ " done=(\"agents\", \"done\"),\n",
+ " terminated=(\"agents\", \"terminated\"),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "loss_module.make_value_estimator(ValueEstimators.TD0, gamma=gamma)\n",
+ "target_net_updater = SoftUpdate(loss_module, eps=1 - tau)\n",
+ "\n",
+ "optim = torch.optim.Adam(loss_module.parameters(), lr)"
+ ]
+ },
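+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> With `SoftUpdate(loss_module, eps=1 - tau)`, the target Q-network parameters $\\theta'$ are updated after every optimization step as a Polyak average of the online parameters $\\theta$:\n",
+ "\n",
+ "$$ \\theta' \\leftarrow (1 - \\tau)\\,\\theta' + \\tau\\,\\theta, \\qquad \\tau = 0.005, $$\n",
+ "\n",
+ "> so the targets track the online critic slowly, which stabilizes the bootstrapped TD(0) value estimates."
+ ]
+ },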
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Training loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-11-19 16:02:48,989 [torchrl][INFO] \n",
+ "Iteration 0\n",
+ "2024-11-19 16:03:06,578 [torchrl][INFO] \n",
+ "Iteration 1\n",
+ "2024-11-19 16:03:24,198 [torchrl][INFO] \n",
+ "Iteration 2\n",
+ "2024-11-19 16:03:40,670 [torchrl][INFO] \n",
+ "Iteration 3\n",
+ "2024-11-19 16:03:57,879 [torchrl][INFO] \n",
+ "Iteration 4\n",
+ "2024-11-19 16:04:14,749 [torchrl][INFO] \n",
+ "Iteration 5\n",
+ "2024-11-19 16:04:31,722 [torchrl][INFO] \n",
+ "Iteration 6\n",
+ "2024-11-19 16:04:48,887 [torchrl][INFO] \n",
+ "Iteration 7\n",
+ "2024-11-19 16:05:06,673 [torchrl][INFO] \n",
+ "Iteration 8\n",
+ "2024-11-19 16:05:23,290 [torchrl][INFO] \n",
+ "Iteration 9\n"
+ ]
+ }
+ ],
+ "source": [
+ "total_time = 0\n",
+ "total_frames = 0\n",
+ "sampling_start = time.time()\n",
+ "for i, tensordict_data in enumerate(collector):\n",
+ " torchrl_logger.info(f\"\\nIteration {i}\")\n",
+ "\n",
+ " sampling_time = time.time() - sampling_start\n",
+ "\n",
+ " current_frames = tensordict_data.numel()\n",
+ " total_frames += current_frames\n",
+ " data_view = tensordict_data.reshape(-1)\n",
+ " replay_buffer.extend(data_view)\n",
+ "\n",
+ " training_tds = []\n",
+ " training_start = time.time()\n",
+ " for _ in range(num_epochs):\n",
+ " for _ in range(frames_per_batch // minibatch_size):\n",
+ " subdata = replay_buffer.sample()\n",
+ " loss_vals = loss_module(subdata)\n",
+ " training_tds.append(loss_vals.detach())\n",
+ "\n",
+ " loss_value = (\n",
+ " loss_vals[\"loss_actor\"]\n",
+ " + loss_vals[\"loss_alpha\"]\n",
+ " + loss_vals[\"loss_qvalue\"]\n",
+ " )\n",
+ "\n",
+ " loss_value.backward()\n",
+ "\n",
+ " total_norm = torch.nn.utils.clip_grad_norm_(\n",
+ " loss_module.parameters(), max_grad_norm\n",
+ " )\n",
+ " training_tds[-1].set(\"grad_norm\", total_norm.mean())\n",
+ "\n",
+ " optim.step()\n",
+ " optim.zero_grad()\n",
+ " target_net_updater.step()\n",
+ "\n",
+ " collector.update_policy_weights_()\n",
+ "\n",
+ " training_time = time.time() - training_start\n",
+ "\n",
+ " iteration_time = sampling_time + training_time\n",
+ " total_time += iteration_time\n",
+ " training_tds = torch.stack(training_tds)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Check `\\plots` directory to find the plots created from this experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:matplotlib.legend:No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n",
+ "WARNING:matplotlib.legend:No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from RouteRL.services import plotter\n",
+ "plotter(params[kc.PLOTTER])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "torchrl",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}