clean up time trimming
csinva committed Nov 22, 2023
1 parent 7ba8cff commit 1246476
Showing 12 changed files with 738 additions and 530 deletions.
202 changes: 202 additions & 0 deletions notebooks_stories/2_analyze_pilot/01_check_alignment.ipynb
@@ -0,0 +1,202 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2023-11-21 15:15:00,370] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/chansingh/imodelsx/.venv/lib/python3.11/site-packages/thinc/compat.py:36: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n",
" hasattr(torch, \"has_mps\")\n",
"/home/chansingh/imodelsx/.venv/lib/python3.11/site-packages/thinc/compat.py:37: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n",
" and torch.has_mps # type: ignore[attr-defined]\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from os.path import join\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import sys\n",
"import joblib\n",
"from scipy.special import softmax\n",
"import sasc.config\n",
"import numpy as np\n",
"from collections import defaultdict\n",
"from copy import deepcopy\n",
"import pandas as pd\n",
"import sasc.viz\n",
"from sasc import analyze_helper\n",
"from sasc.modules.fmri_module import convert_module_num_to_voxel_num\n",
"from sasc.config import FMRI_DIR, RESULTS_DIR\n",
"import dvu\n",
"dvu.set_style()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load pilot pickle"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"stories_data_dict = joblib.load(\n",
" join(sasc.config.RESULTS_DIR, 'pilot_story_data.pkl'))\n",
"pilot_data_dir = '/home/chansingh/mntv1/deep-fMRI/story_data/20230504'\n",
"\n",
"# stories_data_dict = joblib.load(\n",
"# join(sasc.config.RESULTS_DIR, 'pilot3_story_data.pkl'))\n",
"# pilot_data_dir = '/home/chansingh/mntv1/deep-fMRI/story_data/20231106'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Read all the info from stories into a single pickle file"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/6 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 6/6 [00:29<00:00, 5.00s/it]\n"
]
}
],
"source": [
"# load responses\n",
"default_story_idxs = np.where(\n",
" np.array(stories_data_dict['story_setting']) == 'default')[0]\n",
"resp_np_files = [stories_data_dict['story_name_new'][i].replace('_resps', '')\n",
" for i in default_story_idxs]\n",
"resps_dict = {\n",
" k: np.load(join(pilot_data_dir, k))\n",
" for k in tqdm(resp_np_files)\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Let's check the alignment"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"timings_list = stories_data_dict['timing']\n",
"story_names_list = list(resps_dict.keys())\n",
"resps = list(resps_dict.values())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GenStory1.npy resp_length 435 story_trs 445 story_length 890.2769680636738\n",
"GenStory2.npy resp_length 382 story_trs 392 story_length 784.0274450678589\n",
"GenStory3.npy resp_length 322 story_trs 331 story_length 663.8185519243381\n",
"GenStory4.npy resp_length 405 story_trs 415 story_length 830.2452811179473\n",
"GenStory5.npy resp_length 407 story_trs 417 story_length 834.1917651885974\n",
"GenStory6.npy resp_length 470 story_trs 480 story_length 960.9583298519948\n"
]
}
],
"source": [
"TRIM = 5\n",
"for i in range(len(resps)):\n",
" t = timings_list[i]\n",
" duration_secs = t['time_running'].max()\n",
" print(story_names_list[i], 'resp_length',\n",
" resps[i].shape[0], 'story_trs',\n",
" int(duration_secs // 2), 'story_length', duration_secs) # , 'timings',\n",
" diff = int(duration_secs // 2) - resps[i].shape[0]\n",
" assert abs(diff - TRIM * 2) <= 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Let's check the paragraph<>timing match"
]
},
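{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch for the paragraph <-> timing check (not run here).\n",
"# Assumptions beyond what this notebook establishes: 'story_text' holds each\n",
"# story's text with paragraphs separated by blank lines (as in the related\n",
"# analysis code), and each timing DataFrame has one 'time_running' row per\n",
"# presented word -- treat this as a starting point, not a verified check.\n",
"for i in default_story_idxs:\n",
"    paragraphs = stories_data_dict['story_text'][i].split('\\n\\n')\n",
"    timing = stories_data_dict['timing'][i]\n",
"    n_words = sum(len(p.split()) for p in paragraphs)\n",
"    print(stories_data_dict['story_name_new'][i],\n",
"          'paragraphs:', len(paragraphs),\n",
"          'words in text:', n_words,\n",
"          'rows in timing:', len(timing))\n"
]
},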
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".llm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a9ff692d44ea03fd8a03facee7621117bbbb82def09bacaacf0a2cbc238b7b91"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -26,6 +26,13 @@
"# from sasc.modules.fmri_module import convert_module_num_to_voxel_num"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## NOTE: THESE ARE NOT APPROPRIATELY ADJUSTED FOR THE TRIM"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -46,20 +53,24 @@
"source": [
"def get_resp_chunks_list(story_data, resps_dict):\n",
" resp_chunks_list = []\n",
" for story_num in range(6): # range(1, 7)\n",
" for story_num in range(6): # range(1, 7)\n",
" rows = story_data[\"rows\"][story_num]\n",
" paragraphs = story_data[\"story_text\"][story_num].split(\"\\n\\n\")\n",
" timing = story_data[\"timing\"][story_num]\n",
"\n",
" resp_story = resps_dict[story_data[\"story_name_new\"][story_num]].T # (voxels, time)\n",
" resp_chunks = sasc.analyze_helper.get_resp_chunks(timing, paragraphs, resp_story, apply_offset=False)\n",
" # (voxels, time)\n",
" resp_story = resps_dict[story_data[\"story_name_new\"][story_num]].T\n",
" resp_chunks = sasc.analyze_helper.get_resp_chunks(\n",
" timing, paragraphs, resp_story, apply_offset=False)\n",
" resp_chunks[0] *= np.nan\n",
"\n",
" args = np.argsort(rows[\"expl\"].values)\n",
" resp_chunks_list.append([resp_chunks[i][rows['voxel_num']] for i in args])\n",
" \n",
" resp_chunks_list.append(\n",
" [resp_chunks[i][rows['voxel_num']] for i in args])\n",
"\n",
" return resp_chunks_list\n",
"\n",
"\n",
"resp_chunks_list = get_resp_chunks_list(story_data, resps_dict)\n",
"# resp_chunks_arr = np.array(resp_chunks_list).mean(axis=0)\n",
"expls = story_data[\"rows\"][0].sort_values(by=\"expl\")[\"expl\"].values"
@@ -101,13 +112,13 @@
"C = 6\n",
"R = 3\n",
"for voxel_num in range(n_voxels):\n",
" plt.subplot(R, C, voxel_num + 1) \n",
" plt.subplot(R, C, voxel_num + 1)\n",
" resps_rep = []\n",
" for story_num in range(n_stories):\n",
" # print(resp_chunks_list[story_num][voxel_num]) \n",
" # print(resp_chunks_list[story_num][voxel_num])\n",
" resps_rep.append(resp_chunks_list[story_num][voxel_num][voxel_num])\n",
" resps_rep = sorted(resps_rep, key=lambda x: len(x))\n",
" \n",
"\n",
" # interpolate each story to 100 time points\n",
" resps_rep_interp = []\n",
" for resp_rep in resps_rep:\n",
@@ -117,7 +128,6 @@
" resps_rep_interp = np.array(resps_rep_interp)\n",
" resps_rep_mean = np.nanmean(resps_rep_interp, axis=0)\n",
"\n",
"\n",
" # print('shape', resps_rep_mean.shape)\n",
" if viz_mean:\n",
" plt.plot(resps_rep_mean, color='gray', alpha=0.5)\n",
@@ -132,7 +142,7 @@
" for i, resp in enumerate(resps_rep):\n",
" plt.plot(resp, color=cmap(i / len(resps_rep)), alpha=0.5)\n",
" plt.ylim(-3, 3)\n",
" \n",
"\n",
" if voxel_num % C != 0:\n",
" plt.yticks([])\n",
" else:\n",
@@ -143,7 +153,7 @@
" plt.title(expls[voxel_num], fontsize='small')\n",
"\n",
"\n",
"plt.subplot(R, C, voxel_num + 2) \n",
"plt.subplot(R, C, voxel_num + 2)\n",
"plt.title('Mean', color='C0')\n",
"plt.plot(np.array(resps_rep_means).mean(axis=0), color='C0')\n",
"plt.grid()\n",
@@ -211,7 +221,8 @@
" count += 1\n",
" counts_by_len.append(count)\n",
" resp_mean /= count\n",
" plt.plot(resp_mean, label=x, color=cmap(i / len(np.unique(lens))), alpha=0.5, lw=3)\n",
" plt.plot(resp_mean, label=x, color=cmap(\n",
" i / len(np.unique(lens))), alpha=0.5, lw=3)\n",
" resp_means_by_len.append(np.mean(resp_mean))\n",
" resp_means_by_len_5.append(np.mean(resp_mean[5:]))\n",
"plt.ylabel('Mean response')\n",
@@ -610,7 +621,8 @@
}
],
"source": [
"out[out['Paragraph length (TRs)'] < 23]['Response mean (excluding 1st 5 TRs)'].mean()"
"out[out['Paragraph length (TRs)'] <\n",
" 23]['Response mean (excluding 1st 5 TRs)'].mean()"
]
},
{
@@ -630,7 +642,8 @@
}
],
"source": [
"out[out['Paragraph length (TRs)'] >= 23]['Response mean (excluding 1st 5 TRs)'].mean()"
"out[out['Paragraph length (TRs)'] >=\n",
" 23]['Response mean (excluding 1st 5 TRs)'].mean()"
]
}
],