basic pilot3 nbs

csinva committed Nov 8, 2023
1 parent 0009905 commit a2b4fea
Showing 7 changed files with 541 additions and 223 deletions.
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -36,9 +36,34 @@
"from sasc.config import FMRI_DIR"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Read all the info from stories into a single pickle file"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# load stuff\n",
"# double check all of these, intro paragraph may be the same...\n",
"output_file = join(sasc.config.RESULTS_DIR, \"pilot2_story_data.pkl\")\n",
"STORIES_DIR = join(sasc.config.RESULTS_DIR, \"stories\")\n",
"story_mapping = {\n",
" \"interactions/uts02___jun14___seed=1\": \"GenStory7_resps.npy\",\n",
" \"interactions/uts02___jun14___seed=4\": \"GenStory8_resps.npy\",\n",
" \"polysemantic/uts02___jun14___seed=6\": \"GenStory9_resps.npy\",\n",
" \"polysemantic/uts02___jun14___seed=1\": \"GenStory10_resps.npy\",\n",
"}"
]
},
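  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sanity sketch (added for illustration, not part of the original pipeline): fail fast if any directory named in `story_mapping` is missing under `STORIES_DIR`. This assumes `os` and `join` are imported as in the import cell (shown in full in the pilot3 notebook below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: check that every mapped story directory exists before processing\n",
    "missing = [s for s in story_mapping\n",
    "           if not os.path.isdir(join(STORIES_DIR, s))]\n",
    "assert not missing, f'missing story dirs: {missing}'"
   ]
  },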
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -47,51 +72,48 @@
"['/home/chansingh/automated-explanations/results/pilot2_story_data.pkl']"
]
},
"execution_count": 4,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load stuff\n",
"# double check all of these, intro paragraph may be the same...\n",
"story_mapping = {\n",
" \"interactions/uts02___jun14___seed=1\": \"GenStory7_resps.npy\",\n",
" \"interactions/uts02___jun14___seed=4\": \"GenStory8_resps.npy\",\n",
" \"polysemantic/uts02___jun14___seed=6\": \"GenStory9_resps.npy\",\n",
" \"polysemantic/uts02___jun14___seed=1\": \"GenStory10_resps.npy\",\n",
"}\n",
"\n",
"STORIES_DIR = join(sasc.config.RESULTS_DIR, \"stories\")\n",
"story_names = story_mapping.keys() # os.listdir(STORIES_DIR)\n",
"# cluster_neighbors = joblib.load(join(FMRI_DIR, \"voxel_neighbors_and_pcs\", \"cluster_neighbors_v1.pkl\"))\n",
"perfs = joblib.load(join(sasc.config.FMRI_DIR, 'rj_models', 'opt_model', 'new_setup_performance.jbl'))\n",
"perfs = joblib.load(join(sasc.config.FMRI_DIR, 'rj_models',\n",
" 'opt_model', 'new_setup_performance.jbl'))\n",
"\n",
"# process with timings\n",
"story_data = defaultdict(list)\n",
"for story_idx, story_name in enumerate(story_names):\n",
" story_data[\"timing\"].append(\n",
" pd.read_csv(join(STORIES_DIR, story_name, \"timings_processed.csv\"))\n",
" )\n",
" story_data[\"story_name_original\"].append(story_name)\n",
" story_data[\"story_name_new\"].append(story_mapping[story_name])\n",
" story_data[\"story_text\"].append(\n",
"# add keys\n",
"stories_data_dict = defaultdict(list)\n",
"for story_idx, story_name in enumerate(story_mapping.keys()):\n",
" # add scalar story descriptions\n",
" stories_data_dict[\"story_name_original\"].append(story_name)\n",
" stories_data_dict[\"story_setting\"].append(story_name.split(\"/\")[0])\n",
" stories_data_dict[\"story_name_new\"].append(story_mapping[story_name])\n",
" stories_data_dict[\"story_text\"].append(\n",
" open(join(STORIES_DIR, story_name, \"story.txt\"), \"r\").read()\n",
" )\n",
" prompts_paragraphs = joblib.load(\n",
" join(STORIES_DIR, story_name, \"prompts_paragraphs.pkl\")\n",
" )\n",
" story_data[\"prompts\"].append(prompts_paragraphs[\"prompts\"])\n",
" story_data[\"paragraphs\"].append(prompts_paragraphs[\"paragraphs\"])\n",
"\n",
" # add paragraph-level descriptions\n",
" stories_data_dict[\"timing\"].append(\n",
" pd.read_csv(join(STORIES_DIR, story_name, \"timings_processed.csv\"))\n",
" )\n",
" stories_data_dict[\"prompts\"].append(prompts_paragraphs[\"prompts\"])\n",
" stories_data_dict[\"paragraphs\"].append(prompts_paragraphs[\"paragraphs\"])\n",
"\n",
" # add paragraph-level metadata\n",
" # rows\n",
" # rows = pd.read_csv(join(STORIES_DIR, story_name, \"rows.csv\"))\n",
" rows = pd.read_pickle(join(STORIES_DIR, story_name, \"rows.pkl\"))\n",
" rows[\"voxel_num\"] = rows.apply(\n",
" lambda row: convert_module_num_to_voxel_num(row[\"module_num\"], row[\"subject\"]),\n",
" story_metadata_per_paragraph = pd.read_pickle(\n",
" join(STORIES_DIR, story_name, \"rows.pkl\"))\n",
" story_metadata_per_paragraph[\"voxel_num\"] = story_metadata_per_paragraph.apply(\n",
" lambda row: convert_module_num_to_voxel_num(\n",
" row[\"module_num\"], row[\"subject\"]),\n",
" axis=1,\n",
" )\n",
" rows = rows[\n",
" story_metadata_per_paragraph = story_metadata_per_paragraph[\n",
" [\n",
" \"expl\",\n",
" \"module_num\",\n",
@@ -104,9 +126,10 @@
" \"voxel_num\",\n",
" ]\n",
" ]\n",
" rows['test_corr_new'] = rows['voxel_num'].apply(lambda x: perfs[x])\n",
" story_metadata_per_paragraph['test_corr_new'] = story_metadata_per_paragraph['voxel_num'].apply(\n",
" lambda x: perfs[x])\n",
" # rows['cluster_nums'] = rows['voxel_num'].map(cluster_neighbors)\n",
" story_data[\"rows\"].append(rows)\n",
" stories_data_dict[\"rows\"].append(story_metadata_per_paragraph)\n",
"\n",
" if \"interactions\" in list(story_mapping.keys())[story_idx]:\n",
" rows1 = pd.read_pickle(join(STORIES_DIR, story_name, \"rows1.pkl\"))\n",
@@ -123,14 +146,45 @@
" ),\n",
" axis=1,\n",
" )\n",
" story_data['voxel_num1'].append(rows1['voxel_num'])\n",
" story_data['voxel_num2'].append(rows2['voxel_num'])\n",
" story_data['expl1'].append(rows1['expl'])\n",
" story_data['expl2'].append(rows2['expl'])\n",
" stories_data_dict['voxel_num1'].append(rows1['voxel_num'])\n",
" stories_data_dict['voxel_num2'].append(rows2['voxel_num'])\n",
" stories_data_dict['expl1'].append(rows1['expl'])\n",
" stories_data_dict['expl2'].append(rows2['expl'])\n",
"\n",
" else:\n",
" stories_data_dict['voxel_num1'].append([])\n",
" stories_data_dict['voxel_num2'].append([])\n",
" stories_data_dict['expl1'].append([])\n",
" stories_data_dict['expl2'].append([])\n",
"\n",
"\n",
"joblib.dump(story_data, join(sasc.config.RESULTS_DIR, \"pilot2_story_data.pkl\"))"
"joblib.dump(stories_data_dict, output_file)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "invalid syntax (4008562075.py, line 1)",
"output_type": "error",
"traceback": [
"\u001b[0;36m Cell \u001b[0;32mIn[17], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m [k: len(stories_data_dict[k]) for k in stories_data_dict.keys()]\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
]
}
],
"source": [
"[k: len(stories_data_dict[k]) for k in stories_data_dict.keys()]"
]
},
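  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Round-trip sketch (added for illustration): reload the dumped pickle with `joblib.load` and confirm that every key collected one entry per story. `output_file` and `story_mapping` are defined in the cells above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the saved pickle and verify one entry per story for each key\n",
    "reloaded = joblib.load(output_file)\n",
    "for k, v in reloaded.items():\n",
    "    assert len(v) == len(story_mapping), f'{k}: {len(v)} entries'\n",
    "print('reload check passed')"
   ]
  },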
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -149,7 +203,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {
196 changes: 196 additions & 0 deletions notebooks_stories/3_analyze_pilot2/01_load_results_pilot3.ipynb
@@ -0,0 +1,196 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from os.path import join\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import sys\n",
"import joblib\n",
"from scipy.special import softmax\n",
"import sasc.config\n",
"import numpy as np\n",
"from collections import defaultdict\n",
"from copy import deepcopy\n",
"import pandas as pd\n",
"# import story_helper\n",
"from sasc.modules.fmri_module import convert_module_num_to_voxel_num\n",
"from sasc.config import FMRI_DIR"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Read all the info from stories into a single pickle file"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# load stuff\n",
"# double check all of these, intro paragraph may be the same...\n",
"output_file = join(sasc.config.RESULTS_DIR, \"pilot3_story_data.pkl\")\n",
"STORIES_DIR = join(sasc.config.RESULTS_DIR, \"stories\")\n",
"story_mapping = {\n",
" 'default/uts03___jun14___seed=5': 'GenStory12_resps.npy',\n",
" 'default/uts03___jun14___seed=1': 'GenStory13_resps.npy',\n",
"\n",
" 'interactions/uts03___jun14___seed=5': 'GenStory14_resps.npy',\n",
" 'interactions/uts03___jun14___seed=6': 'GenStory15_resps.npy',\n",
"\n",
" 'polysemantic/uts03___jun14___seed=3': 'GenStory16_resps.npy',\n",
" 'polysemantic/uts03___jun14___seed=7': 'GenStory17_resps.npy',\n",
"}"
]
},
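  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same sanity sketch as in the pilot2 notebook (added for illustration): confirm each mapped story directory exists before processing; `os` and `join` are imported at the top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: check that every mapped story directory exists before processing\n",
    "missing = [s for s in story_mapping\n",
    "           if not os.path.isdir(join(STORIES_DIR, s))]\n",
    "assert not missing, f'missing story dirs: {missing}'"
   ]
  },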
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['/home/chansingh/automated-explanations/results/pilot3_story_data.pkl']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# cluster_neighbors = joblib.load(join(FMRI_DIR, \"voxel_neighbors_and_pcs\", \"cluster_neighbors_v1.pkl\"))\n",
"perfs = joblib.load(join(sasc.config.FMRI_DIR, 'rj_models',\n",
" 'opt_model', 'new_setup_performance.jbl'))\n",
"\n",
"# add keys\n",
"stories_data_dict = defaultdict(list)\n",
"for story_idx, story_name in enumerate(story_mapping.keys()):\n",
" # add scalar story descriptions\n",
" stories_data_dict[\"story_name_original\"].append(story_name)\n",
" stories_data_dict[\"story_setting\"].append(story_name.split(\"/\")[0])\n",
" stories_data_dict[\"story_name_new\"].append(story_mapping[story_name])\n",
" stories_data_dict[\"story_text\"].append(\n",
" open(join(STORIES_DIR, story_name, \"story.txt\"), \"r\").read()\n",
" )\n",
" prompts_paragraphs = joblib.load(\n",
" join(STORIES_DIR, story_name, \"prompts_paragraphs.pkl\")\n",
" )\n",
"\n",
" # add paragraph-level descriptions\n",
" stories_data_dict[\"timing\"].append(\n",
" pd.read_csv(join(STORIES_DIR, story_name, \"timings_processed.csv\"))\n",
" )\n",
" stories_data_dict[\"prompts\"].append(prompts_paragraphs[\"prompts\"])\n",
" stories_data_dict[\"paragraphs\"].append(prompts_paragraphs[\"paragraphs\"])\n",
"\n",
" # add paragraph-level metadata\n",
" # rows\n",
" # rows = pd.read_csv(join(STORIES_DIR, story_name, \"rows.csv\"))\n",
" story_metadata_per_paragraph = pd.read_pickle(\n",
" join(STORIES_DIR, story_name, \"rows.pkl\"))\n",
" story_metadata_per_paragraph[\"voxel_num\"] = story_metadata_per_paragraph.apply(\n",
" lambda row: convert_module_num_to_voxel_num(\n",
" row[\"module_num\"], row[\"subject\"]),\n",
" axis=1,\n",
" )\n",
" story_metadata_per_paragraph = story_metadata_per_paragraph[\n",
" [\n",
" \"expl\",\n",
" \"module_num\",\n",
" \"top_explanation_init_strs\",\n",
" \"subject\",\n",
" \"fmri_test_corr\",\n",
" # \"top_score_synthetic\",\n",
" \"top_score_normalized\",\n",
" \"roi_anat\",\n",
" \"roi_func\",\n",
" \"voxel_num\",\n",
" ]\n",
" ]\n",
" story_metadata_per_paragraph['test_corr_new'] = story_metadata_per_paragraph['voxel_num'].apply(\n",
" lambda x: perfs[x])\n",
" # rows['cluster_nums'] = rows['voxel_num'].map(cluster_neighbors)\n",
" stories_data_dict[\"rows\"].append(story_metadata_per_paragraph)\n",
"\n",
" if \"interactions\" in list(story_mapping.keys())[story_idx]:\n",
" rows1 = pd.read_pickle(join(STORIES_DIR, story_name, \"rows1.pkl\"))\n",
" rows2 = pd.read_pickle(join(STORIES_DIR, story_name, \"rows2.pkl\"))\n",
" rows1[\"voxel_num\"] = rows1.apply(\n",
" lambda row: convert_module_num_to_voxel_num(\n",
" row[\"module_num\"], row[\"subject\"]\n",
" ),\n",
" axis=1,\n",
" )\n",
" rows2[\"voxel_num\"] = rows2.apply(\n",
" lambda row: convert_module_num_to_voxel_num(\n",
" row[\"module_num\"], row[\"subject\"]\n",
" ),\n",
" axis=1,\n",
" )\n",
" stories_data_dict['voxel_num1'].append(rows1['voxel_num'])\n",
" stories_data_dict['voxel_num2'].append(rows2['voxel_num'])\n",
" stories_data_dict['expl1'].append(rows1['expl'])\n",
" stories_data_dict['expl2'].append(rows2['expl'])\n",
" else:\n",
" stories_data_dict['voxel_num1'].append([])\n",
" stories_data_dict['voxel_num2'].append([])\n",
" stories_data_dict['expl1'].append([])\n",
" stories_data_dict['expl2'].append([])\n",
"\n",
"\n",
"joblib.dump(stories_data_dict, output_file)"
]
}
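  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Round-trip sketch, mirroring the pilot2 notebook (added for illustration): reload the dumped pickle and confirm one entry per story for every key; `output_file` and `story_mapping` are defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the saved pickle and verify one entry per story for each key\n",
    "reloaded = joblib.load(output_file)\n",
    "for k, v in reloaded.items():\n",
    "    assert len(v) == len(story_mapping), f'{k}: {len(v)} entries'\n",
    "print('reload check passed')"
   ]
  }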
],
"metadata": {
"kernelspec": {
"display_name": ".llm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a9ff692d44ea03fd8a03facee7621117bbbb82def09bacaacf0a2cbc238b7b91"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}