"""Train a small RNN that maps go-envelope and plan inputs to muscle activity.

Script form of the W1D1 Tutorial 5 notebook.  The run recorded in the
original notebook printed these shapes after loading the data:

    Go Envelope Tensor Shape: torch.Size([216, 296, 1])
    Plan Tensor Shape: torch.Size([216, 296, 15])
    Muscle Tensor Shape: torch.Size([216, 296, 8])

and reached a test loss of roughly 0.011 after 500 epochs.
"""

# Imports
import numpy as np
import scipy.io
import torch
import torch.nn as nn


def stack_condition_fields(conds_for_sim,
                           fields=("goEnvelope", "plan", "muscle")):
    """Return one float32 tensor per field, with the struct grid flattened.

    Args:
        conds_for_sim: 2-D struct array as returned by ``scipy.io.loadmat``;
            rows are conditions and columns are delay durations (27 and 8
            according to the original notebook's comments).
        fields: Names of the struct fields to extract.

    Returns:
        A tuple of tensors in the order of ``fields``.  The two grid
        dimensions are merged row-major into the leading dimension, so entry
        ``(i, j)`` ends up at index ``i * num_delays + j``.
    """
    num_conditions, num_delays = conds_for_sim.shape

    stacked = []
    for field in fields:
        # Build ONE ndarray before handing the data to torch.  Calling
        # torch.tensor() on a list of ndarrays takes a very slow path and
        # emits the UserWarning that was recorded in the original notebook.
        grid = np.array(
            [
                [conds_for_sim[i, j][field] for j in range(num_delays)]
                for i in range(num_conditions)
            ],
            dtype=np.float32,
        )
        # Merge the (condition, delay) axes into a single leading axis.
        flat = grid.reshape(-1, *grid.shape[2:])
        stacked.append(torch.from_numpy(flat))
    return tuple(stacked)


class SimpleRNN(nn.Module):
    """Single-layer ``nn.RNN`` followed by a per-time-step linear read-out."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, time, features).
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``(batch, time, input_size)`` to ``(batch, time, output_size)``."""
        hidden_states, _ = self.rnn(x)
        # The linear layer acts on the last dimension, i.e. on every time step.
        return self.fc(hidden_states)


def build_input_tensor(go_envelope_tensor, plan_tensor):
    """Concatenate the go envelope and the plan along the feature dimension.

    Args:
        go_envelope_tensor: Go-envelope tensor.  A tensor with one dimension
            fewer than ``plan_tensor`` gets a trailing singleton dimension.
        plan_tensor: Plan tensor.

    Returns:
        ``torch.cat((go_envelope, plan), dim=-1)``.

    Raises:
        RuntimeError: If the two tensors disagree in rank or in any dimension
            other than the last one.
    """
    if go_envelope_tensor.dim() == plan_tensor.dim() - 1:
        go_envelope_tensor = go_envelope_tensor.unsqueeze(-1)

    same_rank = go_envelope_tensor.dim() == plan_tensor.dim()
    if not same_rank or go_envelope_tensor.shape[:-1] != plan_tensor.shape[:-1]:
        raise RuntimeError("Dimension mismatch after adjustment")

    return torch.cat((go_envelope_tensor, plan_tensor), dim=-1)


def split_train_test(inputs, targets, train_fraction=0.8):
    """Split along the first dimension into leading/trailing parts.

    The split is sequential (no shuffling): the first
    ``int(train_fraction * N)`` rows are used for training and the remaining
    rows for testing.

    Returns:
        ``(train_input, test_input, train_target, test_target)``.
    """
    # NOTE(review): rows are ordered condition-major, so the test set is made
    # of the trailing conditions -- confirm this hold-out is intentional.
    train_size = int(train_fraction * inputs.size(0))
    return (
        inputs[:train_size],
        inputs[train_size:],
        targets[:train_size],
        targets[train_size:],
    )


def train_model(model, train_input, train_target, test_input, test_target,
                num_epochs=500, lr=0.001, log_every=10):
    """Full-batch training with Adam on an MSE loss.

    Every ``log_every`` epochs the test loss is evaluated and both losses are
    printed.  The training loss shown is the one computed before that epoch's
    optimizer step; the test loss is computed after it.

    Returns:
        The trained ``model`` (it is updated in place).
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        output = model(train_input)
        loss = criterion(output, train_target)
        loss.backward()
        optimizer.step()

        # Only evaluate on the test data when the result is reported.
        if epoch % log_every == 0:
            model.eval()
            with torch.no_grad():
                test_loss = criterion(model(test_input), test_target)
            print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss.item()}, Test Loss: {test_loss.item()}')

    return model


if __name__ == "__main__":
    # Load the .mat file and extract the condsForSim struct
    data = scipy.io.loadmat('data/condsForSimJ2moMuscles.mat')
    conds_for_sim = data['condsForSim']

    go_envelope_tensor, plan_tensor, muscle_tensor = stack_condition_fields(
        conds_for_sim
    )

    # Print shapes
    print(f"Go Envelope Tensor Shape: {go_envelope_tensor.shape}")
    print(f"Plan Tensor Shape: {plan_tensor.shape}")
    print(f"Muscle Tensor Shape: {muscle_tensor.shape}")

    input_tensor = build_input_tensor(go_envelope_tensor, plan_tensor)

    # Layer sizes follow the data instead of being hard-coded (the recorded
    # run had 16 input features and 8 output features).
    input_size = input_tensor.size(-1)
    hidden_size = 64
    output_size = muscle_tensor.size(-1)

    model = SimpleRNN(input_size, hidden_size, output_size)

    # Split data into training and testing sets
    train_input, test_input, train_target, test_target = split_train_test(
        input_tensor, muscle_tensor
    )

    # Verify the sizes
    print(f"Train Input Size: {train_input.size()}")
    print(f"Test Input Size: {test_input.size()}")
    print(f"Train Target Size: {train_target.size()}")
    print(f"Test Target Size: {test_target.size()}")

    num_epochs = 500
    train_model(
        model, train_input, train_target, test_input, test_target,
        num_epochs=num_epochs, lr=0.001,
    )

    # Save the model
    # NOTE(review): the original change also committed a binary
    # simple_model_checkpoint.pth; presumably it is this generated file and
    # does not need to be version-controlled -- confirm.
    torch.save(model.state_dict(), 'simple_model_checkpoint.pth')