-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
09de61b
commit 7e18dac
Showing
2 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,309 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "3dcd8e2d-94c3-4972-a2de-e0813bc02689", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Imports\n", | ||
"import scipy.io\n", | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"import torch.nn as nn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "7ff3167f-9472-492d-a236-6f254d075378", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the .mat file\n", | ||
"data = scipy.io.loadmat('data/condsForSimJ2moMuscles.mat')\n", | ||
"\n", | ||
"# Extract condsForSim struct\n", | ||
"conds_for_sim = data['condsForSim']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "e42fa7c4-279a-4ae1-922d-53207ed26455", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/tmp/ipykernel_204181/828212273.py:27: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:230.)\n", | ||
" go_envelope_all.append(torch.tensor(go_envelope_condition, dtype=torch.float32))\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Go Envelope Tensor Shape: torch.Size([216, 296, 1])\n", | ||
"Plan Tensor Shape: torch.Size([216, 296, 15])\n", | ||
"Muscle Tensor Shape: torch.Size([216, 296, 8])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Initialize lists to store data for all conditions\n", | ||
"go_envelope_all = []\n", | ||
"plan_all = []\n", | ||
"muscle_all = []\n", | ||
"\n", | ||
"# Get the number of conditions (rows) and delay durations (columns)\n", | ||
"num_conditions, num_delays = conds_for_sim.shape\n", | ||
"\n", | ||
"# Loop through each condition and extract data\n", | ||
"for i in range(num_conditions): # 27 conditions\n", | ||
" go_envelope_condition = []\n", | ||
" plan_condition = []\n", | ||
" muscle_condition = []\n", | ||
"\n", | ||
" for j in range(num_delays): # 8 delay durations\n", | ||
" condition = conds_for_sim[i, j]\n", | ||
"\n", | ||
" go_envelope = condition['goEnvelope']\n", | ||
" plan = condition['plan']\n", | ||
" muscle = condition['muscle']\n", | ||
"\n", | ||
" go_envelope_condition.append(go_envelope)\n", | ||
" plan_condition.append(plan)\n", | ||
" muscle_condition.append(muscle)\n", | ||
"\n", | ||
" # Stack data for each condition\n", | ||
" go_envelope_all.append(torch.tensor(go_envelope_condition, dtype=torch.float32))\n", | ||
" plan_all.append(torch.tensor(plan_condition, dtype=torch.float32))\n", | ||
" muscle_all.append(torch.tensor(muscle_condition, dtype=torch.float32))\n", | ||
"\n", | ||
"# Stack data for all conditions\n", | ||
"go_envelope_tensor = torch.stack(go_envelope_all)\n", | ||
"plan_tensor = torch.stack(plan_all)\n", | ||
"muscle_tensor = torch.stack(muscle_all)\n", | ||
"\n", | ||
"# Reshape to merge the first two dimensions\n", | ||
"go_envelope_tensor = go_envelope_tensor.reshape(-1, *go_envelope_tensor.shape[2:])\n", | ||
"plan_tensor = plan_tensor.reshape(-1, *plan_tensor.shape[2:])\n", | ||
"muscle_tensor = muscle_tensor.reshape(-1, *muscle_tensor.shape[2:])\n", | ||
"\n", | ||
"# Print shapes\n", | ||
"print(f\"Go Envelope Tensor Shape: {go_envelope_tensor.shape}\")\n", | ||
"print(f\"Plan Tensor Shape: {plan_tensor.shape}\")\n", | ||
"print(f\"Muscle Tensor Shape: {muscle_tensor.shape}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "d84ac60b-be75-45af-8d22-eca9c41fec6c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define a RNN model\n", | ||
"\n", | ||
"class SimpleRNN(nn.Module):\n", | ||
" def __init__(self, input_size, hidden_size, output_size):\n", | ||
" super(SimpleRNN, self).__init__()\n", | ||
" self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)\n", | ||
" self.fc = nn.Linear(hidden_size, output_size)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" out, _ = self.rnn(x)\n", | ||
" out = self.fc(out)\n", | ||
" return out\n", | ||
"\n", | ||
"# Assuming the sizes from data\n", | ||
"input_size = 16 # Calculated based on input shapes\n", | ||
"hidden_size = 64 \n", | ||
"output_size = 8 # Based on the output shape\n", | ||
"\n", | ||
"model = SimpleRNN(input_size, hidden_size, output_size)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "e8ecaf47-96ba-4120-87e6-39b6c9d08188", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Adjust the shape of go_envelope_tensor\n", | ||
"go_envelope_tensor_adjusted = go_envelope_tensor.squeeze(-1) # Removes the last dimension\n", | ||
"\n", | ||
"# Check dimensions after squeezing\n", | ||
"if go_envelope_tensor_adjusted.dim() == plan_tensor.dim() - 1:\n", | ||
" # Add an extra dimension to go_envelope_tensor_adjusted to match plan_tensor\n", | ||
" go_envelope_tensor_adjusted = go_envelope_tensor_adjusted.unsqueeze(-1)\n", | ||
"\n", | ||
" # Now concatenate along the last dimension\n", | ||
" input_tensor = torch.cat((go_envelope_tensor_adjusted, plan_tensor), dim=-1)\n", | ||
"else:\n", | ||
" raise RuntimeError(\"Dimension mismatch after adjustment\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "160bdb5c-2056-40f8-a2f3-597a7e771c53", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Train Input Size: torch.Size([172, 296, 16])\n", | ||
"Test Input Size: torch.Size([44, 296, 16])\n", | ||
"Train Target Size: torch.Size([172, 296, 8])\n", | ||
"Test Target Size: torch.Size([44, 296, 8])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Split data into training and testing sets\n", | ||
"train_size = int(0.8 * input_tensor.size(0))\n", | ||
"train_input = input_tensor[:train_size]\n", | ||
"test_input = input_tensor[train_size:]\n", | ||
"train_target = muscle_tensor[:train_size]\n", | ||
"test_target = muscle_tensor[train_size:]\n", | ||
"\n", | ||
"# Verify the sizes\n", | ||
"print(f\"Train Input Size: {train_input.size()}\")\n", | ||
"print(f\"Test Input Size: {test_input.size()}\")\n", | ||
"print(f\"Train Target Size: {train_target.size()}\")\n", | ||
"print(f\"Test Target Size: {test_target.size()}\")\n", | ||
"\n", | ||
"# Define loss function and optimizer\n", | ||
"criterion = nn.MSELoss()\n", | ||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "1d4aab79-08da-4fc1-be1c-cc580a5f2d70", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Epoch [0/500], Loss: 0.06945484131574631, Test Loss: 0.06215358525514603\n", | ||
"Epoch [10/500], Loss: 0.021324191242456436, Test Loss: 0.0196370966732502\n", | ||
"Epoch [20/500], Loss: 0.019416367635130882, Test Loss: 0.01877027191221714\n", | ||
"Epoch [30/500], Loss: 0.01791669800877571, Test Loss: 0.016670530661940575\n", | ||
"Epoch [40/500], Loss: 0.017395811155438423, Test Loss: 0.01648167334496975\n", | ||
"Epoch [50/500], Loss: 0.01716764271259308, Test Loss: 0.016113314777612686\n", | ||
"Epoch [60/500], Loss: 0.016996093094348907, Test Loss: 0.016008485108613968\n", | ||
"Epoch [70/500], Loss: 0.016899049282073975, Test Loss: 0.015784764662384987\n", | ||
"Epoch [80/500], Loss: 0.016822611913084984, Test Loss: 0.01577206887304783\n", | ||
"Epoch [90/500], Loss: 0.016755780205130577, Test Loss: 0.01563873328268528\n", | ||
"Epoch [100/500], Loss: 0.016685090959072113, Test Loss: 0.01555646862834692\n", | ||
"Epoch [110/500], Loss: 0.01658358983695507, Test Loss: 0.015401231124997139\n", | ||
"Epoch [120/500], Loss: 0.016355616971850395, Test Loss: 0.015009786002337933\n", | ||
"Epoch [130/500], Loss: 0.01487487368285656, Test Loss: 0.02703819051384926\n", | ||
"Epoch [140/500], Loss: 0.016456956043839455, Test Loss: 0.01516042836010456\n", | ||
"Epoch [150/500], Loss: 0.016544444486498833, Test Loss: 0.015271191485226154\n", | ||
"Epoch [160/500], Loss: 0.016446998342871666, Test Loss: 0.01525502372533083\n", | ||
"Epoch [170/500], Loss: 0.016341326758265495, Test Loss: 0.01504436507821083\n", | ||
"Epoch [180/500], Loss: 0.01623906008899212, Test Loss: 0.014940104447305202\n", | ||
"Epoch [190/500], Loss: 0.016118096187710762, Test Loss: 0.014749975875020027\n", | ||
"Epoch [200/500], Loss: 0.015950532630085945, Test Loss: 0.014500990509986877\n", | ||
"Epoch [210/500], Loss: 0.01567700505256653, Test Loss: 0.014078512787818909\n", | ||
"Epoch [220/500], Loss: 0.015037690289318562, Test Loss: 0.013052952475845814\n", | ||
"Epoch [230/500], Loss: 0.02549377828836441, Test Loss: 0.014654227532446384\n", | ||
"Epoch [240/500], Loss: 0.017849242314696312, Test Loss: 0.01578238233923912\n", | ||
"Epoch [250/500], Loss: 0.01701129786670208, Test Loss: 0.016172470524907112\n", | ||
"Epoch [260/500], Loss: 0.016579564660787582, Test Loss: 0.015089149586856365\n", | ||
"Epoch [270/500], Loss: 0.01641070283949375, Test Loss: 0.01525115966796875\n", | ||
"Epoch [280/500], Loss: 0.01629340648651123, Test Loss: 0.014967095106840134\n", | ||
"Epoch [290/500], Loss: 0.01620139181613922, Test Loss: 0.014865301549434662\n", | ||
"Epoch [300/500], Loss: 0.01611190102994442, Test Loss: 0.014730663038790226\n", | ||
"Epoch [310/500], Loss: 0.016006972640752792, Test Loss: 0.014566123485565186\n", | ||
"Epoch [320/500], Loss: 0.0158640518784523, Test Loss: 0.014335619285702705\n", | ||
"Epoch [330/500], Loss: 0.015639012679457664, Test Loss: 0.013991651125252247\n", | ||
"Epoch [340/500], Loss: 0.015176555141806602, Test Loss: 0.013243998400866985\n", | ||
"Epoch [350/500], Loss: 0.013796578161418438, Test Loss: 0.0115931062027812\n", | ||
"Epoch [360/500], Loss: 0.015857117250561714, Test Loss: 0.014362258836627007\n", | ||
"Epoch [370/500], Loss: 0.01589975878596306, Test Loss: 0.014321285299956799\n", | ||
"Epoch [380/500], Loss: 0.015759805217385292, Test Loss: 0.014182702638208866\n", | ||
"Epoch [390/500], Loss: 0.015622653998434544, Test Loss: 0.013964013196527958\n", | ||
"Epoch [400/500], Loss: 0.015410303138196468, Test Loss: 0.013604477979242802\n", | ||
"Epoch [410/500], Loss: 0.01506776176393032, Test Loss: 0.013094850815832615\n", | ||
"Epoch [420/500], Loss: 0.014383680187165737, Test Loss: 0.011991310864686966\n", | ||
"Epoch [430/500], Loss: 0.013644035905599594, Test Loss: 0.010731290094554424\n", | ||
"Epoch [440/500], Loss: 0.013392072170972824, Test Loss: 0.010501268319785595\n", | ||
"Epoch [450/500], Loss: 0.013526298105716705, Test Loss: 0.010595179162919521\n", | ||
"Epoch [460/500], Loss: 0.013375792652368546, Test Loss: 0.010180925950407982\n", | ||
"Epoch [470/500], Loss: 0.013830088078975677, Test Loss: 0.010914870537817478\n", | ||
"Epoch [480/500], Loss: 0.013349335640668869, Test Loss: 0.01034696213901043\n", | ||
"Epoch [490/500], Loss: 0.013634550385177135, Test Loss: 0.011291691102087498\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"criterion = nn.MSELoss()\n", | ||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", | ||
"\n", | ||
"num_epochs = 500\n", | ||
"\n", | ||
"# Training loop\n", | ||
"for epoch in range(num_epochs):\n", | ||
" model.train()\n", | ||
" optimizer.zero_grad()\n", | ||
" output = model(train_input)\n", | ||
" loss = criterion(output, train_target)\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" # Evaluation on test data\n", | ||
" model.eval()\n", | ||
" with torch.no_grad():\n", | ||
" test_output = model(test_input)\n", | ||
" test_loss = criterion(test_output, test_target)\n", | ||
"\n", | ||
" # Print loss every few epochs for monitoring\n", | ||
" if epoch % 10 == 0:\n", | ||
" print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss.item()}, Test Loss: {test_loss.item()}')\n", | ||
"\n", | ||
"# Save the model\n", | ||
"torch.save(model.state_dict(), 'simple_model_checkpoint.pth')\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8e0c04e1-5944-4e14-b058-740f0af10007", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Binary file not shown.