Skip to content

Commit

Permalink
Merge pull request #404 from neuromatch/prepod-day-9
Browse files Browse the repository at this point in the history
d9
  • Loading branch information
SamueleBolotta authored Aug 12, 2024
2 parents fbc2449 + c97a010 commit 605ee69
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 126 deletions.
116 changes: 74 additions & 42 deletions tutorials/W2D5_Mysteries/W2D5_Tutorial1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,12 @@
" plt.show()\n",
"\n",
"# Function to test the model using the configured testing patterns\n",
"def testing(testing_patterns, n_samples, loaded_model, loaded_model_2, factor):\n",
"# Function to test the model using the configured testing patterns\n",
"def testing(testing_patterns, n_samples, loaded_model, loaded_model_2,factor):\n",
"\n",
" def generate_chance_level(shape):\n",
" chance_level = np.random.rand(*shape).tolist()\n",
" return chance_level\n",
" chance_level = np.random.rand(*shape).tolist()\n",
" return chance_level\n",
"\n",
" results_for_plotting = []\n",
" max_values_output_first_order = []\n",
Expand All @@ -506,43 +508,37 @@
" mse_losses_values = []\n",
" discrimination_performances = []\n",
"\n",
" # Assume you have a predefined device\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
" # Move models to the correct device\n",
" loaded_model.to(device)\n",
" loaded_model_2.to(device)\n",
"\n",
" # Iterate through each set of testing patterns and targets\n",
" for i in range(len(testing_patterns)):\n",
" with torch.no_grad(): # Ensure no gradients are computed during testing\n",
" # For low vision the stimulus threshold was set to 0.3 as can seen in the generate_patters function\n",
" threshold = 0.5\n",
" if i == 2:\n",
" threshold = 0.3\n",
"\n",
" # Obtain input data and move to the correct device\n",
" input_data = testing_patterns[i][0].to(device)\n",
" #For low vision the stimulus threshold was set to 0.3 as can seen in the generate_patters function\n",
" threshold=0.5\n",
" if i==2:\n",
" threshold=0.15\n",
"\n",
" # Obtain output from the first order model\n",
" hidden_representation, output_first_order = loaded_model(input_data)\n",
" input_data = testing_patterns[i][0]\n",
" hidden_representation, output_first_order = loaded_model(input_data)\n",
" output_second_order = loaded_model_2(input_data, output_first_order)\n",
"\n",
" delta = 100 * factor\n",
" delta=100*factor\n",
"\n",
" # Calculate discrimination performance\n",
" discrimination_performance = round(\n",
" (output_first_order[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean().item(), 2\n",
" )\n",
" print(\"driscriminator\")\n",
" print((output_first_order[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean())\n",
" discrimination_performance = round((output_first_order[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean().item(), 2)\n",
" discrimination_performances.append(discrimination_performance)\n",
"\n",
" # Generate chance level and move to the correct device\n",
" chance_level = torch.Tensor(generate_chance_level((200 * factor, 100))).to(device)\n",
" discrimination_random = round(\n",
" (chance_level[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean().item(), 2\n",
" )\n",
"\n",
" # Count all patterns in the dataset\n",
" chance_level = torch.Tensor( generate_chance_level((200*factor,100)))\n",
" discrimination_random= round((chance_level[delta:].argmax(dim=1) == input_data[delta:].argmax(dim=1)).to(float).mean().item(), 2)\n",
" print(\"chance level\" , discrimination_random)\n",
"\n",
"\n",
"\n",
" #count all patterns in the dataset\n",
" wagers = output_second_order[delta:].cpu()\n",
"\n",
" _, targets_2 = torch.max(testing_patterns[i][1], 1)\n",
Expand All @@ -555,11 +551,16 @@
" predicted_np = wagers.numpy().flatten()\n",
" targets_2_np = targets_2.numpy()\n",
"\n",
" #print(\"number of targets,\" , len(targets_2_np))\n",
"\n",
" print(predicted_np)\n",
" print(targets_2_np)\n",
"\n",
" # Calculate True Positives, True Negatives, False Positives, and False Negatives\n",
" TP = np.sum((predicted_np > threshold) & (targets_2_np > threshold))\n",
" TN = np.sum((predicted_np < threshold) & (targets_2_np < threshold))\n",
" FP = np.sum((predicted_np > threshold) & (targets_2_np < threshold))\n",
" FN = np.sum((predicted_np < threshold) & (targets_2_np > threshold))\n",
" TP = np.sum((predicted_np > threshold) & (targets_2_np > threshold))\n",
" TN = np.sum((predicted_np < threshold ) & (targets_2_np < threshold))\n",
" FP = np.sum((predicted_np > threshold) & (targets_2_np < threshold))\n",
" FN = np.sum((predicted_np < threshold) & (targets_2_np > threshold))\n",
"\n",
" # Compute precision, recall, F1 score, and accuracy for both high and low wager scenarios\n",
" precision_h, recall_h, f1_score_h, accuracy_h = compute_metrics(TP, TN, FP, FN)\n",
Expand All @@ -570,8 +571,8 @@
" results_for_plotting.append({\n",
" \"counts\": [[TP, FP, TP + FP]],\n",
" \"metrics\": [[precision_h, recall_h, f1_score_h, accuracy_h]],\n",
" \"title_results\": f\"Results Table - Set {i + 1}\",\n",
" \"title_metrics\": f\"Metrics Table - Set {i + 1}\"\n",
" \"title_results\": f\"Results Table - Set {i+1}\",\n",
" \"title_metrics\": f\"Metrics Table - Set {i+1}\"\n",
" })\n",
"\n",
" # Plot input and output of the first-order network\n",
Expand All @@ -587,25 +588,45 @@
" max_values_patterns_tensor.append(max_vals_pat.tolist())\n",
" max_indices_patterns_tensor.append(max_inds_pat.tolist())\n",
"\n",
" fig, axs = plt.subplots(1, 2, figsize=(15, 5))\n",
"\n",
" # Scatter plot of indices: patterns_tensor vs. output_first_order\n",
" axs[0].scatter(max_indices_patterns_tensor[i], max_indices_output_first_order[i], alpha=0.5)\n",
" axs[0].set_title(f'Stimuli location: Condition {i+1} - First Order Input vs. First Order Output')\n",
" axs[0].set_xlabel('First Order Input Indices')\n",
" axs[0].set_ylabel('First Order Output Indices')\n",
"\n",
" # Add quadratic fit to scatter plot\n",
" x_indices = max_indices_patterns_tensor[i]\n",
" y_indices = max_indices_output_first_order[i]\n",
" y_pred_indices = perform_quadratic_regression(x_indices, y_indices)\n",
" axs[0].plot(x_indices, y_pred_indices, color='skyblue')\n",
"\n",
"\n",
" # Calculate MSE loss for indices\n",
" mse_loss_indices = np.mean((np.array(x_indices) - np.array(y_indices)) ** 2)\n",
" mse_losses_indices.append(mse_loss_indices)\n",
"\n",
" # Scatter plot of values: patterns_tensor vs. output_first_order\n",
" axs[1].scatter(max_values_patterns_tensor[i], max_values_output_first_order[i], alpha=0.5)\n",
" axs[1].set_title(f'Stimuli Values: Condition {i+1} - First Order Input vs. First Order Output')\n",
" axs[1].set_xlabel('First Order Input Values')\n",
" axs[1].set_ylabel('First Order Output Values')\n",
"\n",
" # Add quadratic fit to scatter plot\n",
" x_values = max_values_patterns_tensor[i]\n",
" y_values = max_values_output_first_order[i]\n",
" y_pred_values = perform_quadratic_regression(x_values, y_values)\n",
" axs[1].plot(x_values, y_pred_values, color='skyblue')\n",
"\n",
" # Calculate MSE loss for values\n",
" mse_loss_values = np.mean((np.array(x_values) - np.array(y_values)) ** 2)\n",
" mse_losses_values.append(mse_loss_values)\n",
"\n",
" return f1_scores_wager, mse_losses_indices, mse_losses_values, discrimination_performances, results_for_plotting\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
" return f1_scores_wager, mse_losses_indices , mse_losses_values, discrimination_performances, results_for_plotting\n",
"\n",
"def generate_patterns(patterns_number, num_units, factor, condition = 0):\n",
" \"\"\"\n",
Expand All @@ -617,45 +638,56 @@
" # Returns lists of patterns, stimulus present/absent indicators, and second order targets\n",
" \"\"\"\n",
"\n",
" patterns_number=patterns_number*factor\n",
" patterns_number= patterns_number*factor\n",
"\n",
" patterns = [] # Store generated patterns\n",
" stim_present = [] # Indicators for when a stimulus is present in the pattern\n",
" stim_absent = [] # Indicators for when no stimulus is present\n",
" order_2_pr = [] # Second order network targets based on the presence or absence of stimulus\n",
"\n",
" baseline = 0\n",
" multiplier = 1\n",
" if condition == 0:\n",
" random_limit= 0.0\n",
" baseline = 0\n",
" multiplier = 1\n",
"\n",
" if condition == 1:\n",
" baseline = 0.020\n",
" random_limit= 0.02\n",
" baseline = 0.0012\n",
" multiplier = 1\n",
"\n",
" if condition == 2:\n",
" random_limit= 0.02\n",
" baseline = 0.0012\n",
" multiplier = 0.3\n",
"\n",
" # Generate patterns, half noise and half potential stimuli\n",
" for i in range(patterns_number):\n",
"\n",
" # First half: Noise patterns\n",
" if i < patterns_number // 2:\n",
" pattern = multiplier * np.random.uniform(0.0, 0.02, num_units) + baseline # Generate a noise pattern\n",
"\n",
" pattern = multiplier * np.random.uniform(0.0, random_limit, num_units) + baseline # Generate a noise pattern\n",
" patterns.append(pattern)\n",
" stim_present.append(np.zeros(num_units)) # Stimulus absent\n",
" order_2_pr.append([0.0 , 1.0]) # No stimulus, low wager\n",
"\n",
" # Second half: Stimulus patterns\n",
" else:\n",
" stimulus_number = random.randint(0, num_units - 1) # Choose a unit for potential stimulus\n",
" pattern = np.random.uniform(0.0, 0.02, num_units) + baseline\n",
" pattern[stimulus_number] = np.random.uniform(0.0, 1.0) # Set stimulus intensity\n",
" pattern = np.random.uniform(0.0, random_limit, num_units) + baseline\n",
" pattern[stimulus_number] = np.random.uniform(0.0, 1.0) * multiplier # Set stimulus intensity\n",
"\n",
" patterns.append(pattern)\n",
" present = np.zeros(num_units)\n",
" # Determine if stimulus is above discrimination threshold\n",
" if pattern[stimulus_number] >= 0.5:\n",
" if pattern[stimulus_number] >= multiplier/2:\n",
" order_2_pr.append([1.0 , 0.0]) # Stimulus detected, high wager\n",
" present[stimulus_number] = 1.0\n",
" else:\n",
" order_2_pr.append([0.0 , 1.0]) # Stimulus not detected, low wager\n",
" present[stimulus_number] = 0.0\n",
"\n",
" stim_present.append(present)\n",
" pattern[stimulus_number] = pattern[stimulus_number] * multiplier\n",
"\n",
"\n",
" patterns_tensor = torch.Tensor(patterns).to(device).requires_grad_(True)\n",
Expand Down
Loading

0 comments on commit 605ee69

Please sign in to comment.