phase1_analysis.py
import os
import multiprocessing
import datetime
import pymworks
import matplotlib.pyplot as plt

def get_animals_and_their_session_filenames(path):
    '''
    Returns a dict with animal names as keys (the names come from the folder
    names in the 'input' folder--each animal should have its own folder of
    .mwk session files) and a list of .mwk filename strings as values.
    e.g. {'V1': ['V1_140501.mwk', 'V1_140502.mwk']}

    :param path: a string of the directory name containing animals' folders
    '''
    #TODO: this could be faster and cleaner
    result = {}
    dirs_list = [each for each in os.walk(path)]
    for each in dirs_list[1:]:  #skip the root directory itself
        files_list = each[2]
        animal_name = os.path.basename(each[0])  #folder name = animal name
        result[animal_name] = []  #list of session filenames
        for filename in files_list:
            if not filename.startswith('.'):  #don't want hidden files
                result[animal_name].append(filename)
    print("Starting analysis for animals:")
    for each in result.keys():
        print(each)
    return result
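
#Illustrative note (not part of the original script): the function above
#expects one sub-folder per animal, with that animal's .mwk files inside.
#For a hypothetical layout such as
#    input/phase1/V1/V1_140501.mwk
#    input/phase1/V1/V1_140502.mwk
#the call get_animals_and_their_session_filenames('input/phase1') would
#return {'V1': ['V1_140501.mwk', 'V1_140502.mwk']}.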

def analyze_sessions(animals_and_sessions, graph_as_group=False):
    '''
    Starts analysis for each animal's sessions in a new process to use all
    CPU cores. We don't want to wait all day for this, y'all.

    :param animals_and_sessions: a dict with animal names as keys and
        a list of their session filenames as values.
    '''
    #use all CPU cores to process data
    pool = multiprocessing.Pool(None)
    results = []  #list of multiprocessing.AsyncResult objects
    for animal, sessions in animals_and_sessions.items():
        result = pool.apply_async(analyze_animal_sessions,
                                  args=(animal, sessions))
        results.append(result)
    pool.close()
    pool.join()  #block until all the data has been processed

    if graph_as_group:
        raise NotImplementedError("Group graphing coming soon...")

    print("Graphing session data...")
    for each in results:
        data_for_animal = each.get()  #result of analyze_animal_sessions()
        make_a_figure(data_for_animal)
    print("Finished")
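
#Hedged sketch (an assumption, not in the original): for debugging without the
#multiprocessing pool, the same per-animal analysis could be run serially using
#only functions defined in this file, e.g.:
#    for animal, sessions in animals_and_sessions.items():
#        make_a_figure(analyze_animal_sessions(animal, sessions))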

def make_a_figure(data):
    '''
    Shows a graph of an animal's performance and trial info.

    :param data: a dict with x and y value lists returned by
        analyze_animal_sessions()
    '''
    f, ax_arr = plt.subplots(2, 2)  #make 4 subplots for the figure
    f.suptitle(data["animal_name"])  #set figure title to the animal's name
    f.subplots_adjust(bottom=0.08, hspace=0.4)  #fix overlapping labels
    ax_arr[0, 0].plot(data["x_vals"], data["total_pct_correct_y_vals"], "bo")
    ax_arr[0, 0].set_title("% correct - all trials")
    ax_arr[0, 0].axis([0, len(data["x_vals"]), 0.0, 100.0])
    ax_arr[0, 0].set_xlabel("Session number")
    ax_arr[0, 1].plot(data["x_vals"], data["pct_corr_in_center_y_vals"], "bo")
    ax_arr[0, 1].set_title("% correct - trials with stim in center")
    ax_arr[0, 1].axis([0, len(data["x_vals"]), 0.0, 100.0])
    ax_arr[0, 1].set_xlabel("Session number")
    ax_arr[1, 0].plot(data["x_vals"], data["total_trials_y_vals"], "bo")
    ax_arr[1, 0].set_title("Total trials in session")
    ax_arr[1, 0].axis([0, len(data["x_vals"]), 0,
                       max(data["total_trials_y_vals"])])
    #largest y axis tick is the largest number of trials in a session
    ax_arr[1, 0].set_xlabel("Session number")
    ax_arr[1, 1].plot(data["x_vals"], data["num_trials_stim_in_center_y_vals"],
                      "bo")
    ax_arr[1, 1].set_title("Total trials with stim in center of the screen")
    ax_arr[1, 1].axis([0, len(data["x_vals"]), 0,
                       max(data["total_trials_y_vals"])])
    #same y axis limit as the total-trials plot so it's easier to compare
    #total trials and total trials with stim in center
    ax_arr[1, 1].set_xlabel("Session number")
    plt.show()  #show each figure, user can save if he/she wants

    #make a separate plot of the % of trials with the stim in center
    plt.close("all")
    plt.plot(data["x_vals"], data["pct_trials_stim_in_center"], "bo")
    plt.axis([0, len(data["x_vals"]), 0.0, 100.0])
    plt.title("% trials with stim in center " + data["animal_name"])
    plt.xlabel("Session number")
    plt.show()

def analyze_animal_sessions(animal_name, sessions):
    '''
    Analyzes one animal's sessions and outputs a dict with x and y value lists
    for different types of graphs, e.g. percent correct, total trials, etc.
    See the return dict below.
    This is wrapped by analyze_sessions() so it can run in a process on
    each CPU core.

    :param animal_name: name of the animal (string)
    :param sessions: the animal's session filenames (list of strings)
    '''
    list_of_session_stats = get_stats_for_each_session(animal_name, sessions)
    x_vals = [each["session_number"] for each in list_of_session_stats]
    pct_corr_whole_session_y = [each["pct_correct_whole_session"]
                                for each in list_of_session_stats]
    pct_corr_in_center_y = [each["pct_correct_stim_in_center"]
                            for each in list_of_session_stats]
    total_num_trials_y = [each["total_trials"]
                          for each in list_of_session_stats]
    total_trials_stim_in_center_y = [each["trials_with_stim_in_center"]
                                     for each in list_of_session_stats]
    pct_trials_stim_in_center = [each["pct_trials_stim_in_center"]
                                 for each in list_of_session_stats]
    return {"x_vals": x_vals,  #x axis is the session number for all graphs
            "total_pct_correct_y_vals": pct_corr_whole_session_y,
            "pct_corr_in_center_y_vals": pct_corr_in_center_y,
            "total_trials_y_vals": total_num_trials_y,
            "num_trials_stim_in_center_y_vals": total_trials_stim_in_center_y,
            "pct_trials_stim_in_center": pct_trials_stim_in_center,
            "animal_name": animal_name}

def get_stats_for_each_session(animal_name, sessions):
    '''
    Returns a list of dicts with statistics about each session for an
    animal. e.g.
    result = [{
        'session_number': 1,
        'ignores': 2,
        'successes': 2,
        'failures': 0,
        'pct_correct_whole_session': 50.0,
        'pct_correct_stim_in_center': 50.0,
        'total_trials': 4,
        'trials_with_stim_in_center': 4,
        'pct_trials_stim_in_center': 100.0
    },
    #Note the NoneType values in this session
    {   'session_number': 2,
        'ignores': 0,
        'successes': 0,
        'failures': 0,
        'pct_correct_whole_session': None,
        'pct_correct_stim_in_center': None,
        'total_trials': 0,
        'trials_with_stim_in_center': 0,
        'pct_trials_stim_in_center': None}]

    NOTE: if there are no trials for the denominator of a percentage key
    (e.g. pct_correct_stim_in_center), the key's value is set to None.
    Behavior outcomes (e.g. ignores, successes, etc.) with no occurrences
    are left with a value of 0.
    '''
    result = []
    session_num = 1
    for session in sessions:
        all_trials = get_session_statistics(animal_name, session)
        #make a dict to store this session's data
        session_result = {"session_number": session_num,
                          "total_trials": len(all_trials),
                          "filename": session}
        #go through each trial to tally the outcomes
        all_success = 0
        all_failure = 0
        all_ignore = 0
        success_in_center = 0
        failure_in_center = 0
        ignore_in_center = 0
        for trial in all_trials:
            if trial["behavior_outcome"] == "success":
                if trial["stm_pos_x"] == 0.0:
                    success_in_center += 1
                all_success += 1
            elif trial["behavior_outcome"] == "failure":
                if trial["stm_pos_x"] == 0.0:
                    failure_in_center += 1
                all_failure += 1
            elif trial["behavior_outcome"] == "ignore":
                if trial["stm_pos_x"] == 0.0:
                    ignore_in_center += 1
                all_ignore += 1
            else:
                print("No behavior_outcome for %s %s, trial number %s" %
                      (animal_name, session, trial["trial_num"]))
        #add session data to the session result dict
        session_result["successes"] = all_success
        session_result["failures"] = all_failure
        session_result["ignores"] = all_ignore
        try:
            session_result["pct_correct_whole_session"] = \
                (float(all_success) /
                 float(all_success + all_ignore + all_failure)) * 100.0
        except ZeroDivisionError:
            session_result["pct_correct_whole_session"] = None
        try:
            session_result["pct_correct_stim_in_center"] = \
                (float(success_in_center) /
                 float(success_in_center + failure_in_center +
                       ignore_in_center)) * 100.0
        except ZeroDivisionError:
            session_result["pct_correct_stim_in_center"] = None
        session_result["trials_with_stim_in_center"] = \
            success_in_center + failure_in_center + ignore_in_center
        try:
            session_result["pct_trials_stim_in_center"] = \
                (float(session_result["trials_with_stim_in_center"]) /
                 float(len(all_trials))) * 100.0
        except ZeroDivisionError:
            session_result["pct_trials_stim_in_center"] = None
        #add each session's result dict to the list of session result dicts
        result.append(session_result)
        session_num += 1
    return result
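
#Worked example of the percentage math above, using the numbers from the
#docstring example (illustrative only): a session with 2 successes, 0 failures,
#and 2 ignores gives pct_correct_whole_session = 2 / (2 + 0 + 2) * 100.0 = 50.0,
#and if all 4 of those trials had stm_pos_x == 0.0, then
#pct_trials_stim_in_center = 4 / 4 * 100.0 = 100.0.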

def get_session_statistics(animal_name, session_filename):
    '''
    Returns a time-ordered list of dicts, where each dict is info about a trial.
    e.g. [{"trial_num": 1,
           "behavior_outcome": "failure",
           "stm_pos_x": 7.5},
          {"trial_num": 2,
           "behavior_outcome": "success",
           "stm_pos_x": -7.5}]
    NOTE: trial_num 1 corresponds to the FIRST trial in the session,
    and trials occur when Announce_TrialStart and Announce_TrialEnd
    events have success, failure, or ignore events between them with
    value=1.

    :param animal_name: name of the animal (string)
    :param session_filename: filename for the session (string)
    '''
    #TODO: hard-coded paths are not ideal for code reuse
    path = 'input/' + 'phase1/' + animal_name + '/' + session_filename
    df = pymworks.open_file(path)
    events = df.get_events(["Announce_TrialStart", "Announce_TrialEnd",
                            "success", "failure", "ignore", "stm_pos_x"])
    result = []
    index = 0
    temp_events = []
    last_announce = None
    trial_num = 0
    while index < len(events):
        if events[index].name == "Announce_TrialStart":
            temp_events = []
            last_announce = "Announce_TrialStart"
        elif events[index].name == "Announce_TrialEnd":
            if last_announce == "Announce_TrialStart":
                trial_result = {}
                for ev in temp_events:
                    if ev.name == "success" and ev.value == 1:
                        trial_result["behavior_outcome"] = "success"
                    elif ev.name == "failure" and ev.value == 1:
                        trial_result["behavior_outcome"] = "failure"
                    elif ev.name == "ignore" and ev.value == 1:
                        trial_result["behavior_outcome"] = "ignore"
                    elif ev.name == "stm_pos_x":
                        trial_result["stm_pos_x"] = ev.value
                    else:
                        pass
                if "behavior_outcome" in trial_result:
                    trial_num += 1
                    trial_result["trial_num"] = trial_num
                    result.append(trial_result)
            last_announce = "Announce_TrialEnd"
        else:
            temp_events.append(events[index])
        index += 1
    #FYI, testing showed some good filtering of weird events here...
    #blah = df.get_events(["success", "failure", "ignore"])
    #print "EVENTS EQUAL? ", len(result) == len(blah) - 6, session_filename
    #subtract 6 because session initialization emits 2 behavior outcomes per
    #outcome type
    #print len(result), len(blah) - 6
    #lines above were unequal in 6/77 sessions for AB3&7 because of random
    #behavior outcome events firing in rapid succession. They happen within a
    #couple microseconds of one another, so filtering these out is probably good
    return result
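
#Illustrative event ordering that the parsing loop above assumes (values here
#are made up for the example): Announce_TrialStart, stm_pos_x=7.5, failure=1,
#Announce_TrialEnd. Everything between a Start/End pair is buffered in
#temp_events, and the pair only yields a trial dict when one of
#success/failure/ignore fired with value=1 in between.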

if __name__ == "__main__":
    animals_and_sessions = get_animals_and_their_session_filenames('input/phase1')
    analyze_sessions(animals_and_sessions)
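
#Typical invocation, assuming .mwk sessions live under input/phase1/<animal>/
#relative to the working directory (the path hard-coded above):
#    python phase1_analysis.py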