From 5faba0a05d68e61aab3bc143e9767222d5f304bc Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Mon, 3 Apr 2023 23:25:17 +0000 Subject: [PATCH 1/7] cleanup --- format_beh.ipynb | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/format_beh.ipynb b/format_beh.ipynb index e64fb7f..10e87e2 100644 --- a/format_beh.ipynb +++ b/format_beh.ipynb @@ -38,6 +38,11 @@ "outputs": [], "source": [ "def get_X_by_bins(bin_size, data):\n", + " \"\"\"\n", + " bin_size: in miliseconds, bin size\n", + " data: dataframe for behavioral data from object features csv\n", + " Returns: new dataframe with one-hot encoding of features, feedback\n", + " \"\"\"\n", " max_time = np.max(valid_beh[\"TrialEnd\"].values)\n", " max_bin_idx = int(max_time / bin_size) + 1\n", " columns = FEATURES + [\"CORRECT\", \"INCORRECT\"]\n", @@ -73,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -135,11 +140,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 80, + "metadata": {}, + "outputs": [], + "source": [ + "intervals = get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ - "get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" + "intervals.to_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" ] } ], From 740755f84aadd510c5461a7fb8867e498ee8ee0d Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Tue, 4 Apr 2023 22:07:49 +0000 Subject: [PATCH 2/7] beh utils, work on generating design mat --- behavioral_utils.py | 70 ++++++ format_beh.ipynb | 558 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 548 insertions(+), 80 deletions(-) create mode 100644 behavioral_utils.py diff --git a/behavioral_utils.py b/behavioral_utils.py new file mode 100644 index 0000000..13d064b --- /dev/null +++ b/behavioral_utils.py @@ -0,0 +1,70 @@ +from constants import FEATURES +import numpy as np +import pandas as pd + +def get_X_by_bins(bin_size, data): + """ + bin_size: in miliseconds, bin size + data: dataframe for behavioral data from object features csv + Returns: new dataframe with one-hot encoding of features, feedback + """ + max_time = np.max(data["TrialEnd"].values) + max_bin_idx = int(max_time / bin_size) + 1 + columns = FEATURES + ["CORRECT", "INCORRECT"] + types = ["f4" for _ in columns] + zipped = list(zip(columns, types)) + dtype = np.dtype(zipped) + arr = np.zeros((max_bin_idx), dtype=dtype) + + for _, row in data.iterrows(): + # grab features of item chosen + item_chosen = int(row["ItemChosen"]) + color = row[f"Item{item_chosen}Color"] + shape = row[f"Item{item_chosen}Shape"] + pattern = row[f"Item{item_chosen}Pattern"] + + chosen_time = row["FeedbackOnset"] - 800 + chosen_bin = int(chosen_time / bin_size) + arr[chosen_bin][color] = 1 + arr[chosen_bin][shape] = 1 + arr[chosen_bin][pattern] = 1 + + feedback_bin = int(row["FeedbackOnset"] / bin_size) + # print(feedback_bin) + if row["Response"] == "Correct": + arr[feedback_bin]["CORRECT"] = 1 + else: + arr[feedback_bin]["INCORRECT"] = 1 + df = pd.DataFrame(arr) + df["bin_idx"] = np.arange(len(df)) + return df + + +def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, post_interval=0, bin_size=50): + """Per trial, finds time interval surrounding some event in the behavioral data + + Args: + behavioral_data: 
Dataframe describing each trial, must contain + columns: TrialNumber, whatever 'event' param describes + event: name of event to align around, must be present as a + column name in behavioral_data Dataframe + pre_interval: number of miliseconds before event + post_interval: number of miliseconds after event + + Returns: + DataFrame with num_trials length, columns: TrialNumber, + IntervalStartTime, IntervalEndTime + """ + trial_event_times = behavioral_data[["TrialNumber", event]] + + intervals = np.empty((len(trial_event_times), 3)) + intervals[:, 0] = trial_event_times["TrialNumber"] + intervals[:, 1] = trial_event_times[event] - pre_interval + intervals[:, 2] = trial_event_times[event] + post_interval + intervals_df = pd.DataFrame(columns=["TrialNumber", "IntervalStartTime", "IntervalEndTime"]) + intervals_df["TrialNumber"] = trial_event_times["TrialNumber"].astype(int) + intervals_df["IntervalStartTime"] = trial_event_times[event] - pre_interval + intervals_df["IntervalEndTime"] = trial_event_times[event] + post_interval + intervals_df["IntervalStartBin"] = (intervals_df["IntervalStartTime"] / bin_size).astype(int) + intervals_df["IntervalEndBin"] = (intervals_df["IntervalEndTime"] / bin_size).astype(int) + return intervals_df \ No newline at end of file diff --git a/format_beh.ipynb b/format_beh.ipynb index 10e87e2..e33c992 100644 --- a/format_beh.ipynb +++ b/format_beh.ipynb @@ -2,16 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", "import numpy as np\n", "import pandas as pd\n", "from spike_tools import (\n", " general as spike_general,\n", " analysis as spike_analysis,\n", ")\n", + "import behavioral_utils\n", "from constants import FEATURES\n", "\n", "species = 'nhp'\n", @@ -22,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -33,65 +37,46 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "def get_X_by_bins(bin_size, data):\n", - " \"\"\"\n", - " bin_size: in miliseconds, bin size\n", - " data: dataframe for behavioral data from object features csv\n", - " Returns: new dataframe with one-hot encoding of features, feedback\n", - " \"\"\"\n", - " max_time = np.max(valid_beh[\"TrialEnd\"].values)\n", - " max_bin_idx = int(max_time / bin_size) + 1\n", - " columns = FEATURES + [\"CORRECT\", \"INCORRECT\"]\n", - " types = [\"f4\" for _ in columns]\n", - " zipped = list(zip(columns, types))\n", - " dtype = np.dtype(zipped)\n", - " arr = np.zeros((max_bin_idx), dtype=dtype)\n", - "\n", - " for _, row in data.iterrows():\n", - " # grab features of item chosen\n", - " item_chosen = int(row[\"ItemChosen\"])\n", - " color = row[f\"Item{item_chosen}Color\"]\n", - " shape = row[f\"Item{item_chosen}Shape\"]\n", - " pattern = row[f\"Item{item_chosen}Pattern\"]\n", - "\n", - " chosen_time = row[\"FeedbackOnset\"] - 800\n", - " chosen_bin = int(chosen_time / bin_size)\n", - " arr[chosen_bin][color] = 1\n", - " arr[chosen_bin][shape] = 1\n", - " arr[chosen_bin][pattern] = 1\n", - "\n", - " feedback_bin = int(row[\"FeedbackOnset\"] / bin_size)\n", - " # print(feedback_bin)\n", - " if row[\"Response\"] == \"Correct\":\n", - " arr[feedback_bin][\"CORRECT\"] = 1\n", - " else:\n", - " arr[feedback_bin][\"INCORRECT\"] = 1\n", - " df = pd.DataFrame(arr)\n", - " df[\"bin_idx\"] = np.arange(len(df))\n", - " return 
df\n", - " \n" + "x_by_bins = behavioral_utils.get_X_by_bins(50, valid_beh)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "x_by_bins.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Grab bin idxs of interval around fb onset" ] }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "res = get_X_by_bins(50, valid_beh)" + "intervals = behavioral_utils.get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" ] }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ - "res.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')" + "intervals.to_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" ] }, { @@ -99,61 +84,474 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Grab bin idxs of interval around fb onset" + "### Grab design matrix for intervals " ] }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TrialNumberIntervalStartTimeIntervalEndTimeIntervalStartBinIntervalEndBin
00754088.0757088.01508115141
11757892.0760892.01515715217
22762080.0765080.01524115301
33766144.0769144.01532215382
44771945.0774945.01543815498
..................
1745174513652545.013655545.0273050273110
1746174613656221.013659221.0273124273184
1747174713668992.013671992.0273379273439
1748174813673866.013676866.0273477273537
1749174913684869.013687869.0273697273757
\n", + "

1749 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " TrialNumber IntervalStartTime IntervalEndTime IntervalStartBin \\\n", + "0 0 754088.0 757088.0 15081 \n", + "1 1 757892.0 760892.0 15157 \n", + "2 2 762080.0 765080.0 15241 \n", + "3 3 766144.0 769144.0 15322 \n", + "4 4 771945.0 774945.0 15438 \n", + "... ... ... ... ... \n", + "1745 1745 13652545.0 13655545.0 273050 \n", + "1746 1746 13656221.0 13659221.0 273124 \n", + "1747 1747 13668992.0 13671992.0 273379 \n", + "1748 1748 13673866.0 13676866.0 273477 \n", + "1749 1749 13684869.0 13687869.0 273697 \n", + "\n", + " IntervalEndBin \n", + "0 15141 \n", + "1 15217 \n", + "2 15301 \n", + "3 15382 \n", + "4 15498 \n", + "... ... \n", + "1745 273110 \n", + "1746 273184 \n", + "1747 273439 \n", + "1748 273537 \n", + "1749 273757 \n", + "\n", + "[1749 rows x 5 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def get_trial_intervals(behavioral_data, event=\"FeedbackOnset\", pre_interval=0, post_interval=0, bin_size=50):\n", - " \"\"\"Per trial, finds time interval surrounding some event in the behavioral data\n", - "\n", - " Args:\n", - " behavioral_data: Dataframe describing each trial, must contain\n", - " columns: TrialNumber, whatever 'event' param describes\n", - " event: name of event to align around, must be present as a\n", - " column name in behavioral_data Dataframe\n", - " pre_interval: number of miliseconds before event\n", - " post_interval: number of miliseconds after event\n", - "\n", - " Returns:\n", - " DataFrame with num_trials length, columns: TrialNumber,\n", - " IntervalStartTime, IntervalEndTime\n", - " \"\"\"\n", - " trial_event_times = behavioral_data[[\"TrialNumber\", event]]\n", - "\n", - " intervals = np.empty((len(trial_event_times), 3))\n", - " intervals[:, 0] = trial_event_times[\"TrialNumber\"]\n", - " intervals[:, 1] = trial_event_times[event] - pre_interval\n", - " intervals[:, 2] = trial_event_times[event] + post_interval\n", - " intervals_df = pd.DataFrame(columns=[\"TrialNumber\", \"IntervalStartTime\", \"IntervalEndTime\"])\n", - " intervals_df[\"TrialNumber\"] = trial_event_times[\"TrialNumber\"].astype(int)\n", - " intervals_df[\"IntervalStartTime\"] = trial_event_times[event] - pre_interval\n", - " intervals_df[\"IntervalEndTime\"] = trial_event_times[event] + post_interval\n", - " intervals_df[\"IntervalStartBin\"] = (intervals_df[\"IntervalStartTime\"] / bin_size).astype(int)\n", - " intervals_df[\"IntervalEndBin\"] = (intervals_df[\"IntervalEndTime\"] / bin_size).astype(int)\n", - " return intervals_df\n" + "intervals" ] }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
CIRCLESQUARESTARTRIANGLECYANGREENMAGENTAYELLOWESCHERPOLKADOTRIPPLESWIRLCORRECTINCORRECTbin_idx
00.00.00.00.00.00.00.00.00.00.00.00.00.00.00
10.00.00.00.00.00.00.00.00.00.00.00.00.00.01
20.00.00.00.00.00.00.00.00.00.00.00.00.00.02
30.00.00.00.00.00.00.00.00.00.00.00.00.00.03
40.00.00.00.00.00.00.00.00.00.00.00.00.00.04
................................................
2738250.00.00.00.00.00.00.00.00.00.00.00.00.00.0273825
2738260.00.00.00.00.00.00.00.00.00.00.00.00.00.0273826
2738270.00.00.00.00.00.00.00.00.00.00.00.00.00.0273827
2738280.00.00.00.00.00.00.00.00.00.00.00.00.00.0273828
2738290.00.00.00.00.00.00.00.00.00.00.00.00.00.0273829
\n", + "

273830 rows × 15 columns

\n", + "
" + ], + "text/plain": [ + " CIRCLE SQUARE STAR TRIANGLE CYAN GREEN MAGENTA YELLOW ESCHER \\\n", + "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "... ... ... ... ... ... ... ... ... ... \n", + "273825 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "273826 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "273827 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "273828 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "273829 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " POLKADOT RIPPLE SWIRL CORRECT INCORRECT bin_idx \n", + "0 0.0 0.0 0.0 0.0 0.0 0 \n", + "1 0.0 0.0 0.0 0.0 0.0 1 \n", + "2 0.0 0.0 0.0 0.0 0.0 2 \n", + "3 0.0 0.0 0.0 0.0 0.0 3 \n", + "4 0.0 0.0 0.0 0.0 0.0 4 \n", + "... ... ... ... ... ... ... \n", + "273825 0.0 0.0 0.0 0.0 0.0 273825 \n", + "273826 0.0 0.0 0.0 0.0 0.0 273826 \n", + "273827 0.0 0.0 0.0 0.0 0.0 273827 \n", + "273828 0.0 0.0 0.0 0.0 0.0 273828 \n", + "273829 0.0 0.0 0.0 0.0 0.0 273829 \n", + "\n", + "[273830 rows x 15 columns]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "intervals = get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" + "x_by_bins" ] }, { "cell_type": "code", - "execution_count": 82, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "intervals.to_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" + "# pre_interval, post_interval, bin_size\n", + "# for interval in intervals\n", + "# generate bin idx\n", + "window_size = \n", + "\n", + "for row in intervals.iterrows():\n", + " trial_bins = np.arange(row.IntervalStartBin, row.IntervalEndBin)\n" ] } ], From b0449f7df487c3a05b7843e9b7cfaab9da7ca695 Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Wed, 5 Apr 2023 00:27:28 +0000 Subject: [PATCH 3/7] finished design matrix logic, refactored into data utils --- behavioral_utils.py => data_utils.py | 66 +++- format_beh.ipynb | 486 ++------------------------- 2 files changed, 93 insertions(+), 459 deletions(-) rename behavioral_utils.py => data_utils.py (52%) diff --git a/behavioral_utils.py b/data_utils.py similarity index 52% rename from behavioral_utils.py rename to data_utils.py index 13d064b..57f7fa7 100644 --- a/behavioral_utils.py +++ b/data_utils.py @@ -2,13 +2,13 @@ import numpy as np import pandas as pd -def get_X_by_bins(bin_size, data): +def get_behavior_by_bins(bin_size, beh): """ bin_size: in miliseconds, bin size data: dataframe for behavioral data from object features csv Returns: new dataframe with one-hot encoding of features, feedback """ - max_time = np.max(data["TrialEnd"].values) + max_time = np.max(beh["TrialEnd"].values) max_bin_idx = int(max_time / bin_size) + 1 columns = FEATURES + ["CORRECT", "INCORRECT"] types = ["f4" for _ in columns] @@ -16,7 +16,7 @@ def get_X_by_bins(bin_size, data): dtype = np.dtype(zipped) arr = np.zeros((max_bin_idx), dtype=dtype) - for _, row in data.iterrows(): + for _, row in beh.iterrows(): # grab features of item chosen item_chosen = int(row["ItemChosen"]) color = row[f"Item{item_chosen}Color"] @@ -40,6 +40,28 @@ def get_X_by_bins(bin_size, data): return df +def get_spikes_by_bins(bin_size, spike_times): + """Given a bin_size and a series of spike times, return spike counts by bin. 
+ Args: + bin_size: size of bins in miliseconds + spike_times: dataframe with unit_id, spike times. + Returns: + df with bin_idx, unit_* as columns, filled with spike counts + """ + + units = np.unique(spike_times.UnitID.values) + time_stamp_max = int(spike_times.SpikeTime.max()) + 1 + + num_time_bins = int(time_stamp_max/bin_size) + 1 + bins = np.arange(num_time_bins) * bin_size + + df = pd.DataFrame(data={'bin_idx': np.arange(num_time_bins)[:-1]}) + for unit in units: + unit_spike_times = spike_times[spike_times.UnitID==unit].SpikeTime.values + unit_spike_counts, bin_edges = np.histogram(unit_spike_times, bins=bins) + df[f'unit_{unit}'] = unit_spike_counts + return df + def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, post_interval=0, bin_size=50): """Per trial, finds time interval surrounding some event in the behavioral data @@ -67,4 +89,40 @@ def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, intervals_df["IntervalEndTime"] = trial_event_times[event] + post_interval intervals_df["IntervalStartBin"] = (intervals_df["IntervalStartTime"] / bin_size).astype(int) intervals_df["IntervalEndBin"] = (intervals_df["IntervalEndTime"] / bin_size).astype(int) - return intervals_df \ No newline at end of file + return intervals_df + + +def get_design_matrix(spikes_by_bins, beh_by_bins, columns, tau_pre, tau_post): + """ + Reformats data as a design matrix dataframe, where for each of the specified columns, + additional columns are added for each of the time points between tau_pre and tau_post + Args: + spike_by_bins: df with bin_idx, unit_* as columns + beh_by_bins: df with bin_idx, behavioral vars of interest as columns + columns: columns to include, must be present in either spike_by_bins or beh_by_bins + tau_pre: number of bins to look in the past + tau_post: number of bins to look in the future + Returns: + df with bin_idx, columns for each time points between tau_pre and tau_post + """ + joint = pd.merge(spikes_by_bins, beh_by_bins, on="bin_idx", how="inner") + res = pd.DataFrame() + taus = np.arange(-tau_pre, tau_post) + for tau in taus: + shift_idx = -1 * tau + column_names = [f"{x}_{tau}" for x in columns] + res[column_names] = joint.shift(shift_idx)[columns] + res["bin_idx"] = joint["bin_idx"] + return res + + +def get_interval_bins(intervals): + """ + Gets all the bins belonging to all the intervals + Args: + intervals: df with trialnumber, IntervalStartBin, IntervalEndBin + Returns: + np array of all bins for all trials falling between startbin and endbin + """ + interval_bins = intervals.apply(lambda x: np.arange(x.IntervalStartBin, x.IntervalEndBin).astype(int), axis=1) + return np.concatenate(interval_bins.to_numpy()) \ No newline at end of file diff --git a/format_beh.ipynb b/format_beh.ipynb index e33c992..731b908 100644 --- a/format_beh.ipynb +++ b/format_beh.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -15,7 +15,7 @@ " general as spike_general,\n", " analysis as spike_analysis,\n", ")\n", - "import behavioral_utils\n", + "import data_utils\n", "from constants import FEATURES\n", "\n", "species = 'nhp'\n", @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -37,11 +37,11 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "x_by_bins = 
behavioral_utils.get_X_by_bins(50, valid_beh)" + "behavior_by_bins = data_utils.get_behavior_by_bins(50, valid_beh)" ] }, { @@ -50,7 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "x_by_bins.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')" + "behavior_by_bins.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')" ] }, { @@ -67,7 +67,7 @@ "metadata": {}, "outputs": [], "source": [ - "intervals = behavioral_utils.get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" + "intervals = data_utils.get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" ] }, { @@ -89,469 +89,45 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TrialNumberIntervalStartTimeIntervalEndTimeIntervalStartBinIntervalEndBin
00754088.0757088.01508115141
11757892.0760892.01515715217
22762080.0765080.01524115301
33766144.0769144.01532215382
44771945.0774945.01543815498
..................
1745174513652545.013655545.0273050273110
1746174613656221.013659221.0273124273184
1747174713668992.013671992.0273379273439
1748174813673866.013676866.0273477273537
1749174913684869.013687869.0273697273757
\n", - "

1749 rows × 5 columns

\n", - "
" - ], - "text/plain": [ - " TrialNumber IntervalStartTime IntervalEndTime IntervalStartBin \\\n", - "0 0 754088.0 757088.0 15081 \n", - "1 1 757892.0 760892.0 15157 \n", - "2 2 762080.0 765080.0 15241 \n", - "3 3 766144.0 769144.0 15322 \n", - "4 4 771945.0 774945.0 15438 \n", - "... ... ... ... ... \n", - "1745 1745 13652545.0 13655545.0 273050 \n", - "1746 1746 13656221.0 13659221.0 273124 \n", - "1747 1747 13668992.0 13671992.0 273379 \n", - "1748 1748 13673866.0 13676866.0 273477 \n", - "1749 1749 13684869.0 13687869.0 273697 \n", - "\n", - " IntervalEndBin \n", - "0 15141 \n", - "1 15217 \n", - "2 15301 \n", - "3 15382 \n", - "4 15498 \n", - "... ... \n", - "1745 273110 \n", - "1746 273184 \n", - "1747 273439 \n", - "1748 273537 \n", - "1749 273757 \n", - "\n", - "[1749 rows x 5 columns]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "intervals" + "spikes_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_spike_counts_binsize_50.pickle')\n", + "beh_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')\n", + "intervals = pd.read_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")\n", + "\n", + "NUM_UNITS = 59\n", + "column_names_w_units = FEATURES + [\"CORRECT\", \"INCORRECT\"] + [f\"unit_{i}\" for i in range(0, NUM_UNITS)]\n", + "column_names = FEATURES + [\"CORRECT\" + \"INCORRECT\"]" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
CIRCLESQUARESTARTRIANGLECYANGREENMAGENTAYELLOWESCHERPOLKADOTRIPPLESWIRLCORRECTINCORRECTbin_idx
00.00.00.00.00.00.00.00.00.00.00.00.00.00.00
10.00.00.00.00.00.00.00.00.00.00.00.00.00.01
20.00.00.00.00.00.00.00.00.00.00.00.00.00.02
30.00.00.00.00.00.00.00.00.00.00.00.00.00.03
40.00.00.00.00.00.00.00.00.00.00.00.00.00.04
................................................
2738250.00.00.00.00.00.00.00.00.00.00.00.00.00.0273825
2738260.00.00.00.00.00.00.00.00.00.00.00.00.00.0273826
2738270.00.00.00.00.00.00.00.00.00.00.00.00.00.0273827
2738280.00.00.00.00.00.00.00.00.00.00.00.00.00.0273828
2738290.00.00.00.00.00.00.00.00.00.00.00.00.00.0273829
\n", - "

273830 rows × 15 columns

\n", - "
" - ], - "text/plain": [ - " CIRCLE SQUARE STAR TRIANGLE CYAN GREEN MAGENTA YELLOW ESCHER \\\n", - "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "... ... ... ... ... ... ... ... ... ... \n", - "273825 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "273826 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "273827 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "273828 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "273829 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", - "\n", - " POLKADOT RIPPLE SWIRL CORRECT INCORRECT bin_idx \n", - "0 0.0 0.0 0.0 0.0 0.0 0 \n", - "1 0.0 0.0 0.0 0.0 0.0 1 \n", - "2 0.0 0.0 0.0 0.0 0.0 2 \n", - "3 0.0 0.0 0.0 0.0 0.0 3 \n", - "4 0.0 0.0 0.0 0.0 0.0 4 \n", - "... ... ... ... ... ... ... \n", - "273825 0.0 0.0 0.0 0.0 0.0 273825 \n", - "273826 0.0 0.0 0.0 0.0 0.0 273826 \n", - "273827 0.0 0.0 0.0 0.0 0.0 273827 \n", - "273828 0.0 0.0 0.0 0.0 0.0 273828 \n", - "273829 0.0 0.0 0.0 0.0 0.0 273829 \n", - "\n", - "[273830 rows x 15 columns]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "x_by_bins" + "design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, column_names_w_units, 20, 0)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "# pre_interval, post_interval, bin_size\n", - "# for interval in intervals\n", - "# generate bin idx\n", - "window_size = \n", - "\n", - "for row in intervals.iterrows():\n", - " trial_bins = np.arange(row.IntervalStartBin, row.IntervalEndBin)\n" + "interval_bins = data_utils.get_interval_bins(intervals)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "mat_in_intervals = design_mat[design_mat.bin_idx.isin(interval_bins)]\n", + "assert len(mat_in_intervals) == len(interval_bins)" ] } ], From c7cc127d5274cd39a556172ea507d6bec582858c Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Wed, 5 Apr 2023 00:44:51 +0000 Subject: [PATCH 4/7] refactor into another notebook --- constants.py | 8 ++- create_design_matrix.ipynb | 99 ++++++++++++++++++++++++++++++++++++++ format_beh.ipynb | 51 -------------------- 3 files changed, 106 insertions(+), 52 deletions(-) create mode 100644 create_design_matrix.ipynb diff --git a/constants.py b/constants.py index cc9d0cd..8696000 100644 --- a/constants.py +++ b/constants.py @@ -1,6 +1,12 @@ +# useful constants during analysis FEATURES = [ 'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE', 'CYAN', 'GREEN', 'MAGENTA', 'YELLOW', 'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL' -] \ No newline at end of file +] + +NUM_UNITS = 59 + +COLUMN_NAMES_W_UNITS = FEATURES + ["CORRECT", "INCORRECT"] + [f"unit_{i}" for i in range(0, NUM_UNITS)] +COLUMN_NAMES = FEATURES + ["CORRECT", "INCORRECT"] \ No newline at end of file diff --git a/create_design_matrix.ipynb b/create_design_matrix.ipynb new file mode 100644 index 0000000..08f5cf1 --- /dev/null +++ b/create_design_matrix.ipynb @@ -0,0 +1,99 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Notebook to create and store a design matrix of behavior and spikes " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is 
already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from spike_tools import (\n", + " general as spike_general,\n", + " analysis as spike_analysis,\n", + ")\n", + "import data_utils\n", + "from constants import FEATURES, COLUMN_NAMES_W_UNITS\n", + "\n", + "species = 'nhp'\n", + "subject = 'SA'\n", + "exp = 'WCST'\n", + "session = 20180802 # this is the session for which there are spikes at the moment. \n", + "\n", + "tau_pre = 20\n", + "tau_post = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "spikes_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_spike_counts_binsize_50.pickle')\n", + "beh_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')\n", + "intervals = pd.read_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, COLUMN_NAMES_W_UNITS, tau_pre, tau_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "design_mat.to_pickle(\"/data/processed/sub-SA_sess-20180802_design_mat_taupre_20_taupost_0_binsize_50.pickle\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/format_beh.ipynb b/format_beh.ipynb index 731b908..9409d36 100644 --- a/format_beh.ipynb +++ b/format_beh.ipynb @@ -78,57 +78,6 @@ "source": [ "intervals.to_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Grab design matrix for intervals " - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "spikes_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_spike_counts_binsize_50.pickle')\n", - "beh_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')\n", - "intervals = pd.read_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")\n", - "\n", - "NUM_UNITS = 59\n", - "column_names_w_units = FEATURES + [\"CORRECT\", \"INCORRECT\"] + [f\"unit_{i}\" for i in range(0, NUM_UNITS)]\n", - "column_names = FEATURES + [\"CORRECT\" + \"INCORRECT\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, column_names_w_units, 20, 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "interval_bins = data_utils.get_interval_bins(intervals)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "mat_in_intervals = 
design_mat[design_mat.bin_idx.isin(interval_bins)]\n", - "assert len(mat_in_intervals) == len(interval_bins)" - ] } ], "metadata": { From 19f9d3d2f383378f920cabe8d07b852b562edbc3 Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Thu, 6 Apr 2023 18:11:10 +0000 Subject: [PATCH 5/7] refactor, remove num_neuron constant --- .../create_design_matrix.ipynb | 11 ++++++----- format_beh.ipynb => notebooks/format_beh.ipynb | 4 ++-- format_spikes.ipynb => notebooks/format_spikes.ipynb | 0 constants.py => wcst_encode/constants.py | 0 data_utils.py => wcst_encode/data_utils.py | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) rename create_design_matrix.ipynb => notebooks/create_design_matrix.ipynb (89%) rename format_beh.ipynb => notebooks/format_beh.ipynb (96%) rename format_spikes.ipynb => notebooks/format_spikes.ipynb (100%) rename constants.py => wcst_encode/constants.py (100%) rename data_utils.py => wcst_encode/data_utils.py (99%) diff --git a/create_design_matrix.ipynb b/notebooks/create_design_matrix.ipynb similarity index 89% rename from create_design_matrix.ipynb rename to notebooks/create_design_matrix.ipynb index 08f5cf1..b7ca2e6 100644 --- a/create_design_matrix.ipynb +++ b/notebooks/create_design_matrix.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -32,8 +32,8 @@ " general as spike_general,\n", " analysis as spike_analysis,\n", ")\n", - "import data_utils\n", - "from constants import FEATURES, COLUMN_NAMES_W_UNITS\n", + "import wcst_encode.data_utils as data_utils\n", + "from wcst_encode.constants import COLUMN_NAMES\n", "\n", "species = 'nhp'\n", "subject = 'SA'\n", @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -61,7 +61,8 @@ "metadata": {}, "outputs": [], "source": [ - "design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, COLUMN_NAMES_W_UNITS, tau_pre, tau_post)" + "column_names_w_units = COLUMN_NAMES + spikes_by_bins.columns[1:].tolist()\n", + "design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, column_names_w_units, tau_pre, tau_post)" ] }, { diff --git a/format_beh.ipynb b/notebooks/format_beh.ipynb similarity index 96% rename from format_beh.ipynb rename to notebooks/format_beh.ipynb index 9409d36..a576234 100644 --- a/format_beh.ipynb +++ b/notebooks/format_beh.ipynb @@ -15,8 +15,8 @@ " general as spike_general,\n", " analysis as spike_analysis,\n", ")\n", - "import data_utils\n", - "from constants import FEATURES\n", + "import wcst_encode.data_utils\n", + "from wcst_encode.constants import FEATURES\n", "\n", "species = 'nhp'\n", "subject = 'SA'\n", diff --git a/format_spikes.ipynb b/notebooks/format_spikes.ipynb similarity index 100% rename from format_spikes.ipynb rename to notebooks/format_spikes.ipynb diff --git a/constants.py b/wcst_encode/constants.py similarity index 100% rename from constants.py rename to wcst_encode/constants.py diff --git a/data_utils.py b/wcst_encode/data_utils.py similarity index 99% rename from data_utils.py rename to wcst_encode/data_utils.py index 57f7fa7..9cba542 100644 --- a/data_utils.py +++ b/wcst_encode/data_utils.py @@ -1,4 +1,4 @@ -from constants import FEATURES +from .constants import FEATURES import numpy as np import pandas as pd From d0c1984ddde4c6139fb5449067ba5b978c49195d Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Tue, 11 Apr 2023 19:11:46 +0000 Subject: [PATCH 6/7] address comments --- 
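A note on the NaN behavior called out for get_design_matrix in this patch: each lagged
regressor column is produced by shifting the merged bin table by tau bins, so for a lag
of tau bins into the past the first tau rows have no source bin and are left as NaN
(and likewise at the end of the session for future lags). A toy sketch of that lag
logic, using hypothetical data and a per-column Series.shift for clarity rather than
the exact assignment in data_utils.py:

    import numpy as np
    import pandas as pd

    # two bins of history (tau_pre=2), no future bins (tau_post=0)
    joint = pd.DataFrame({"bin_idx": np.arange(5), "CORRECT": [0, 0, 1, 0, 0]})
    res = pd.DataFrame()
    for tau in np.arange(-2, 0):
        # row i of CORRECT_{tau} holds the CORRECT value from bin i + tau
        res[f"CORRECT_{tau}"] = joint["CORRECT"].shift(-1 * tau)
    res["bin_idx"] = joint["bin_idx"]
    print(res)  # CORRECT_-2 and CORRECT_-1 start with NaN: those past bins do not exist
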
notebooks/create_design_matrix.ipynb | 15 ++-------- notebooks/format_beh.ipynb | 21 +++++++++---- wcst_encode/constants.py | 7 +++-- wcst_encode/data_utils.py | 44 ++++++++++++++-------------- 4 files changed, 44 insertions(+), 43 deletions(-) diff --git a/notebooks/create_design_matrix.ipynb b/notebooks/create_design_matrix.ipynb index b7ca2e6..a0f51ed 100644 --- a/notebooks/create_design_matrix.ipynb +++ b/notebooks/create_design_matrix.ipynb @@ -10,18 +10,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -46,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/notebooks/format_beh.ipynb b/notebooks/format_beh.ipynb index a576234..9c5a4f5 100644 --- a/notebooks/format_beh.ipynb +++ b/notebooks/format_beh.ipynb @@ -2,9 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -15,7 +24,7 @@ " general as spike_general,\n", " analysis as spike_analysis,\n", ")\n", - "import wcst_encode.data_utils\n", + "import wcst_encode.data_utils as data_utils\n", "from wcst_encode.constants import FEATURES\n", "\n", "species = 'nhp'\n", @@ -26,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ diff --git a/wcst_encode/constants.py b/wcst_encode/constants.py index 8696000..13c99e9 100644 --- a/wcst_encode/constants.py +++ b/wcst_encode/constants.py @@ -6,7 +6,8 @@ 'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL' ] -NUM_UNITS = 59 +COLUMN_NAMES = FEATURES + ["CORRECT", "INCORRECT"] -COLUMN_NAMES_W_UNITS = FEATURES + ["CORRECT", "INCORRECT"] + [f"unit_{i}" for i in range(0, NUM_UNITS)] -COLUMN_NAMES = FEATURES + ["CORRECT", "INCORRECT"] \ No newline at end of file +# time in miliseconds for required fixation on a card to register a choice +# also time between choice and feedback signals +CHOICE_FIXATION_TIME = 800 \ No newline at end of file diff --git a/wcst_encode/data_utils.py b/wcst_encode/data_utils.py index 9cba542..8991ca3 100644 --- a/wcst_encode/data_utils.py +++ b/wcst_encode/data_utils.py @@ -1,6 +1,7 @@ -from .constants import FEATURES +from .constants import FEATURES, CHOICE_FIXATION_TIME import numpy as np import pandas as pd +from itertools import repeat def get_behavior_by_bins(bin_size, beh): """ @@ -9,10 +10,9 @@ def get_behavior_by_bins(bin_size, beh): Returns: new dataframe with one-hot encoding of features, feedback """ max_time = np.max(beh["TrialEnd"].values) - max_bin_idx = int(max_time / bin_size) + 1 + max_bin_idx = int(np.ceil(max_time / bin_size)) columns = FEATURES + ["CORRECT", "INCORRECT"] - types = ["f4" for _ in columns] - zipped 
= list(zip(columns, types)) + zipped = list(zip(columns, repeat("f4"))) dtype = np.dtype(zipped) arr = np.zeros((max_bin_idx), dtype=dtype) @@ -23,18 +23,20 @@ def get_behavior_by_bins(bin_size, beh): shape = row[f"Item{item_chosen}Shape"] pattern = row[f"Item{item_chosen}Pattern"] - chosen_time = row["FeedbackOnset"] - 800 - chosen_bin = int(chosen_time / bin_size) + chosen_time = row["FeedbackOnset"] - CHOICE_FIXATION_TIME + chosen_bin = int(np.floor(chosen_time / bin_size)) arr[chosen_bin][color] = 1 arr[chosen_bin][shape] = 1 arr[chosen_bin][pattern] = 1 - feedback_bin = int(row["FeedbackOnset"] / bin_size) + feedback_bin = int(np.floor(row["FeedbackOnset"] / bin_size)) # print(feedback_bin) if row["Response"] == "Correct": arr[feedback_bin]["CORRECT"] = 1 - else: + elif row["Response"] == "Incorrect": arr[feedback_bin]["INCORRECT"] = 1 + else: + raise ValueError(f"{row['Response']} is undefined") df = pd.DataFrame(arr) df["bin_idx"] = np.arange(len(df)) return df @@ -50,15 +52,13 @@ def get_spikes_by_bins(bin_size, spike_times): """ units = np.unique(spike_times.UnitID.values) - time_stamp_max = int(spike_times.SpikeTime.max()) + 1 - - num_time_bins = int(time_stamp_max/bin_size) + 1 - bins = np.arange(num_time_bins) * bin_size + num_time_bins = int(spike_times.SpikeTime.max() / bin_size) + 1 + bin_edges = np.arange(num_time_bins) * bin_size df = pd.DataFrame(data={'bin_idx': np.arange(num_time_bins)[:-1]}) for unit in units: unit_spike_times = spike_times[spike_times.UnitID==unit].SpikeTime.values - unit_spike_counts, bin_edges = np.histogram(unit_spike_times, bins=bins) + unit_spike_counts, _ = np.histogram(unit_spike_times, bins=bin_edges) df[f'unit_{unit}'] = unit_spike_counts return df @@ -67,28 +67,27 @@ def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, Args: behavioral_data: Dataframe describing each trial, must contain - columns: TrialNumber, whatever 'event' param describes + columns: TrialNumber, as well as the column corresponding to the `event` parameter event: name of event to align around, must be present as a column name in behavioral_data Dataframe - pre_interval: number of miliseconds before event - post_interval: number of miliseconds after event + pre_interval: number of miliseconds before the event to include. Should be >= 0 + post_interval: number of miliseconds after the event to include. 
Should be >= 0 Returns: DataFrame with num_trials length, columns: TrialNumber, IntervalStartTime, IntervalEndTime """ + assert (pre_interval >= 0), "pre interval cannot be negative" + assert (post_interval >= 0), "post interval cannot be negative" + trial_event_times = behavioral_data[["TrialNumber", event]] - intervals = np.empty((len(trial_event_times), 3)) - intervals[:, 0] = trial_event_times["TrialNumber"] - intervals[:, 1] = trial_event_times[event] - pre_interval - intervals[:, 2] = trial_event_times[event] + post_interval intervals_df = pd.DataFrame(columns=["TrialNumber", "IntervalStartTime", "IntervalEndTime"]) intervals_df["TrialNumber"] = trial_event_times["TrialNumber"].astype(int) intervals_df["IntervalStartTime"] = trial_event_times[event] - pre_interval intervals_df["IntervalEndTime"] = trial_event_times[event] + post_interval - intervals_df["IntervalStartBin"] = (intervals_df["IntervalStartTime"] / bin_size).astype(int) - intervals_df["IntervalEndBin"] = (intervals_df["IntervalEndTime"] / bin_size).astype(int) + intervals_df["IntervalStartBin"] = np.floor(intervals_df["IntervalStartTime"] / bin_size).astype(int) + intervals_df["IntervalEndBin"] = np.floor(intervals_df["IntervalEndTime"] / bin_size).astype(int) return intervals_df @@ -104,6 +103,7 @@ def get_design_matrix(spikes_by_bins, beh_by_bins, columns, tau_pre, tau_post): tau_post: number of bins to look in the future Returns: df with bin_idx, columns for each time points between tau_pre and tau_post + missing time shift values will be filled with nans """ joint = pd.merge(spikes_by_bins, beh_by_bins, on="bin_idx", how="inner") res = pd.DataFrame() From 9ff1e121d4ccf0c0ff3a589031202a5c7a04fd02 Mon Sep 17 00:00:00 2001 From: Patrick Zhang Date: Tue, 11 Apr 2023 21:23:11 +0000 Subject: [PATCH 7/7] valueerror, math.floor/ceil --- wcst_encode/data_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/wcst_encode/data_utils.py b/wcst_encode/data_utils.py index 8991ca3..2fead6b 100644 --- a/wcst_encode/data_utils.py +++ b/wcst_encode/data_utils.py @@ -1,5 +1,6 @@ from .constants import FEATURES, CHOICE_FIXATION_TIME import numpy as np +import math import pandas as pd from itertools import repeat @@ -10,7 +11,7 @@ def get_behavior_by_bins(bin_size, beh): Returns: new dataframe with one-hot encoding of features, feedback """ max_time = np.max(beh["TrialEnd"].values) - max_bin_idx = int(np.ceil(max_time / bin_size)) + max_bin_idx = math.ceil(max_time / bin_size) columns = FEATURES + ["CORRECT", "INCORRECT"] zipped = list(zip(columns, repeat("f4"))) dtype = np.dtype(zipped) @@ -24,7 +25,7 @@ def get_behavior_by_bins(bin_size, beh): pattern = row[f"Item{item_chosen}Pattern"] chosen_time = row["FeedbackOnset"] - CHOICE_FIXATION_TIME - chosen_bin = int(np.floor(chosen_time / bin_size)) + chosen_bin = math.floor(chosen_time / bin_size) arr[chosen_bin][color] = 1 arr[chosen_bin][shape] = 1 arr[chosen_bin][pattern] = 1 @@ -77,9 +78,9 @@ def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, DataFrame with num_trials length, columns: TrialNumber, IntervalStartTime, IntervalEndTime """ - assert (pre_interval >= 0), "pre interval cannot be negative" - assert (post_interval >= 0), "post interval cannot be negative" - + if pre_interval >= 0 or post_interval >= 0: + raise ValueError("Neither pre_interval: {pre_interval} or post_interval: {post_interval} should be negative") + trial_event_times = behavioral_data[["TrialNumber", event]] intervals_df = 
pd.DataFrame(columns=["TrialNumber", "IntervalStartTime", "IntervalEndTime"])
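
A note on the guard introduced in PATCH 7/7: as written, the condition
"pre_interval >= 0 or post_interval >= 0" is true for every valid (non-negative) call,
so the function raises on normal input, and the message is missing its f-string prefix,
so the offending values are never interpolated. A minimal sketch of the presumably
intended validation; only the guard changes, and the rest follows the logic visible in
the diff (floor division stands in for np.floor):

    import pandas as pd

    def get_trial_intervals(behavioral_data, event="FeedbackOnset",
                            pre_interval=0, post_interval=0, bin_size=50):
        # reject intervals that are actually negative, and report the values
        if pre_interval < 0 or post_interval < 0:
            raise ValueError(
                f"Neither pre_interval: {pre_interval} nor "
                f"post_interval: {post_interval} should be negative"
            )
        trial_event_times = behavioral_data[["TrialNumber", event]]
        intervals_df = pd.DataFrame()
        intervals_df["TrialNumber"] = trial_event_times["TrialNumber"].astype(int)
        intervals_df["IntervalStartTime"] = trial_event_times[event] - pre_interval
        intervals_df["IntervalEndTime"] = trial_event_times[event] + post_interval
        intervals_df["IntervalStartBin"] = (intervals_df["IntervalStartTime"] // bin_size).astype(int)
        intervals_df["IntervalEndBin"] = (intervals_df["IntervalEndTime"] // bin_size).astype(int)
        return intervals_df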