diff --git a/tutorials/MLinPL/qmix_mutation.ipynb b/tutorials/MLinPL/qmix_mutation.ipynb index e9793a20f..607ea6097 100644 --- a/tutorials/MLinPL/qmix_mutation.ipynb +++ b/tutorials/MLinPL/qmix_mutation.ipynb @@ -606,13 +606,15 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "module = TensorDictModule(\n", - " net, in_keys=[(\"agents\", \"observation\")], out_keys=[(\"agents\", \"action_value\")]\n", - ")" + "## We need to wrap our MultiAgentMLP in a TensorDictModule\n", + "## in_keys =[(\"agents\", \"observation\")]\n", + "## out_keys=[(\"agents\", \"action_value\")]\n", + "\n", + "module = ..." ] }, { @@ -649,11 +651,13 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "qnet = SafeSequential(module, value_module)" + "## We need to wrap the two modules in a SafeSequential\n", + "\n", + "qnet = SafeSequential(...)" ] }, { @@ -732,13 +736,15 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "## We need to initialize the DataCollector\n", + "\n", "collector = SyncDataCollector(\n", - " env,\n", - " qnet_explore,\n", + " ...,\n", + " ...,\n", " device=device,\n", " storing_device=device,\n", " frames_per_batch=frames_per_batch,\n", @@ -790,11 +796,12 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "loss_module = QMixerLoss(qnet, mixer, delay_value=True)\n", + "## Create the QmixerLoss object\n", + "loss_module = QMixerLoss(..., ..., delay_value=True)\n", "\n", "loss_module.set_keys(\n", " action_value=(\"agents\", \"action_value\"),\n",