-
Notifications
You must be signed in to change notification settings - Fork 0
/
trillion.py
1387 lines (1199 loc) · 47.8 KB
/
trillion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import sys
import os
import urllib
import random
import csv
import json
from itertools import combinations
from datetime import datetime
from collections import OrderedDict,Counter
import numpy as np
import matplotlib.pyplot as plt
from scipy.misc import factorial
from scipy.stats import binom
HERE = os.path.dirname(__file__)
VERBOSE = False
DILUTION = {'1/4':0.25, '1/2':0.5, 'not diluted':1.0}
CORRECT = {'right':True, 'wrong':False}
OVERLAPS = {10:[9,6,3,0],
20:[19,15,10,5,0],
30:[29,20,10,0]}
ALPHA = 0.05
ALPHAS_LIST = [0.0005,0.001,0.0025,0.005,0.01,0.025,0.05]
C = 128
C_LIST = C*2.0**np.arange(-2,13)
N_TESTS = 20
N_TESTS_LIST = np.ceil(N_TESTS*2.0**np.arange(3.5,-2.5,-0.5))
N_SUBJECTS = 26
N_SUBJECTS_LIST = np.ceil(N_SUBJECTS*2.0**np.arange(3.5,-2.5,-0.5))
def print_(*args,**kwargs):
global VERBOSE
if VERBOSE:
print(args,kwargs)
def list_combos(n,k):
"""
Returns a list of all combinations of n choose k.
"""
return list(combinations(list(range(n)),k))
def get_n_combos(n,k):
"""
Returns the number of combinations of n choose k.
Uses Stirling's approximation when numbers are very large.
"""
result = None
try:
fac_n = factorial(n)
fac_k = factorial(k)
fac_nk = factorial(n-k)
summ = fac_n + fac_k + fac_nk
if summ == np.inf or summ != summ:
raise ValueError("Values too large. Using Stirling's approximation")
result = fac_n/fac_k
result /= fac_nk
except: # Catch all large number exceptions.
pass
if not result or result == np.inf or result != result: # No result yet.
if n == k or k == 0:
result = 1
else:
x = stirling(n) - stirling(k) - stirling(n-k) # Use Stirling's approx.
result = np.exp(x)
return result
def stirling(n):
"""
Given n, returns Stirling's approximation for log(n!).
"""
return n*np.log(n) - n + 0.5*np.log(2*np.pi*n)
def sphere(N,C,R):
"""
Formula for sphere from Bushdid supplemental material.
N = Number of components in a mixture.
C = Number of components to choose from.
R = Number of differing components.
"""
if R == 0:
result = 1
else:
result = get_n_combos(N,min(R,N))*get_n_combos(C-N,min(R,C-N))
return result
def ball(N,C,R):
"""
Formula for ball from Bushdid supplemental material.
N = Number of components in a mixture.
C = Number of components to choose from.
R = Maximum number of differing components.
"""
assert R<=N
result = 0
for r in range(0,R+1):
result += sphere(N,C,r)
return result
def disc(N,C,d):
"""
Formula for number of discriminable stimuli from Bushdid supplemental material.
N = Number of components in a mixture.
C = Number of components to choose from.
d = Discriminability limen.
"""
# Formulas only makes sense with integers, so we round up and down.
low = get_n_combos(C,N) / ball(N,C,int(np.floor(d/2)))
high = get_n_combos(C,N) / ball(N,C,int(np.ceil(d/2)))
# Geometric interpolation so we can graph smooth lines.
low = np.log(low)
high = np.log(high)
result = low+(high-low)*(d/2-np.floor(d/2)) # Interpolate.
result = np.exp(result)
if result < 1:
result = 1
return result
def disc_prime(N,C,d):
"""
Corrected formula for number of discriminable stimuli from Gerkin and Castro.
N = Number of components in a mixture.
C = Number of components to choose from.
d = Discriminability limen.
"""
#d = d-1
# Formulas only makes sense with integers, so we round up and down.
low = get_n_combos(C,N) / ball(N,C,int(np.floor(d)))
high = get_n_combos(C,N) / ball(N,C,int(np.ceil(d)))
# Geometric interpolation so we can graph smooth lines.
low = np.log(low)
high = np.log(high)
result = low+(high-low)*(d-np.floor(d)) # Interpolate.
result = np.exp(result)
if result < 1:
result = 1
return result
'''
def disc_brute(N,C,d):
"""
Compute discriminable stimuli by brute force.
N = Number of components in a mixture.
C = Number of components to choose from.
d = Discriminability limen.
"""
combos = list_combos(C,N)
n_combos = len(combos)
discriminable = np.zeros((len(combos),len(combos)))
for i in range(n_combos):
for j in range(n_combos):
discriminable[i][j] = len(set(combos[i]).difference(combos[j])) >= d
result = 1+discriminable.sum(axis=1).max()
return result
'''
def fdr(alpha,p_list):
"""
Controls for false discovery rate using the Benjamin and Hochberg procedure.
Given a nominal Type 1 error rate alpha and a list of nominal p-values,
returns a corresponding list of booleans, set to True only if the null
hypothesis should still be rejected after controlling the false discovery rate.
"""
m = len(p_list)
p_list_sorted = sorted(p_list)
reject = [False for p in p_list]
for k,p in enumerate(p_list):
print_(p,k*alpha/m)
if p <= k*alpha/m:
reject[p_list.index(p)] = True
return reject
class Odorant(object):
"""
A mixture of molecules, defined by the presence of absence of the
candidate molecules in the mixture.
"""
def __init__(self, components=None):
"""
Builds odorant from a list of components.
"""
self.components = components if components else []
name = None # Name of odorant, built from a hash of component names.
C = 128 # Number of components from which to choose.
@property
def N(self):
"""
Number of components in this odorant.
"""
return len(self.components)
def r(self,other):
"""
Number of replacements (swaps) to get from self to another other odorant.
"""
if len(self.components) == len(other.components):
return self.hamming(other)/2
else:
return None
def overlap(self,other,percent=False):
"""
Overlap between self and another odorant. Complement of r.
Optionally report result as percent relative to number of components.
"""
overlap = self.N - self.r(other)
if percent:
overlap = overlap*100.0/self.N
return overlap
def hamming(self, other):
"""
Hamming distance between self and another odorant.
Synonymous with number of d, the number of total 'moves' to go from
one odorant to another.
"""
x = set(self.components)
y = set(other.components)
diff = len(x)+len(y)-2*len(x.intersection(y))
return diff
def add_component(self, component):
"""
Adds one component to an odorant.
"""
self.components.append(component)
def remove_component(self, component):
"""
Removes one component to an odorant.
"""
self.components.remove(component)
def descriptor_list(self,source):
"""
Given a data source, returns a list of descriptors about this odorant.
"""
descriptors = []
for component in self.components:
if source in component.descriptors:
desc = component.descriptors[source]
if type(desc) == list:
descriptors += desc
if type(desc) == dict:
descriptors += [key for key,value in list(desc.items()) if value > 0.0]
return list(set(descriptors)) # Remove duplicates.
def descriptor_vector(self,source,all_descriptors):
"""
Given a data source, returns a vector of descriptors about this odorant.
The vector will contain positive floats.
"""
vector = np.zeros(len(all_descriptors[source]))
for component in self.components:
if source in component.descriptors:
desc = component.descriptors[source]
if type(desc) == list:
for descriptor in desc:
index = all_descriptors[source].index(descriptor)
assert index >= 0
vector[index] += 1
if type(desc) == dict:
this_vector = np.array([value for key,value in sorted(desc.items())])
vector += this_vector
return vector
def descriptor_vector2(self,all_descriptors):
"""
Returns a vector of descriptors about this odorant, combining multiple
data sources.
"""
n_descriptors_dravnieks = len(all_descriptors['dravnieks'])
n_descriptors_sigma_ff = len(all_descriptors['sigma_ff'])
vector = np.zeros(n_descriptors_dravnieks+n_descriptors_sigma_ff)
for component in self.components:
if 'dravnieks' in component.descriptors:
desc = component.descriptors['dravnieks']
this_vector = np.array([value for key,value in sorted(desc.items())])
vector[0:n_descriptors_dravnieks] += this_vector
elif 'sigma_ff' in component.descriptors:
desc = component.descriptors['sigma_ff']
for descriptor in desc:
index = all_descriptors['sigma_ff'].index(descriptor)
assert index >= 0
vector[n_descriptors_dravnieks+index] += 1
return vector
def described_components(self,source):
"""
Given a data source, returns a list of the components which are
described by that source, i.e. those that have descriptors.
"""
return [component for component in self.components \
if source in component.descriptors]
def n_described_components(self,source):
"""
Given a data source, returns the number of components that are
described by that data source.
"""
return len(self.described_components(source))
def fraction_components_described(self,source):
"""
Given a data source, returns the fraction of components that are
described by that data source.
"""
return self.n_described_components(source) / self.N
def matrix(self,features,weights=None):
matrix = np.vstack([component.vector(features,weights=weights) \
for component in self.components \
if component.cid in features])
if 0:#matrix.shape[0] != self.N:
print('Odorant has %d components but only %d vectors were computed' % \
(self.N,matrix.shape[0]))
return matrix
def vector(self,features,weights=None,method='sum'):
matrix = self.matrix(features,weights=weights)
if method == 'sum':
vector = matrix.sum(axis=0)
else:
vector = None
return vector
def __str__(self):
"""
String representation of the odorant.
"""
return ','.join([str(x) for x in self.components])
class Component(object):
"""
A single molecule, which may or may not be present in an odorant.
"""
def __init__(self,component_id,name,cas,percent,solvent):
"""
Components are defined by a component_id from the Bushdid et al
supplemental material, a name, a CAS number, a percent dilution,
and a solvent.
"""
self.id = component_id
self.name = name
self.cas = cas
self.cid_ = None
self.percent = percent
self.solvent = solvent
self.descriptors = {} # An empty dictionary.
@property
def cid(self):
if self.cid_:
cid = self.cid_
else:
url_template = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/%s/cids/JSON"
for query in self.cas,self.name:
try:
url = url_template % query
page = urllib.request.urlopen(url)
string = page.read().decode('utf-8')
json_data = json.loads(string)
cid = json_data['IdentifierList']['CID'][0]
except urllib.error.HTTPError as e:
print(query)
else:
break
self.cid_ = cid
return cid
def set_descriptors(self,source,cas_descriptors):
"""
Given a data source, sets descriptors for this odorant using
a dictionary where CAS numbers are keys, and descriptors are values.
"""
assert type(source)==str and len(source)
if self.cas in cas_descriptors:
self.descriptors[source] = cas_descriptors[self.cas]
# For sigma_ff this will be a list.
# For dravnieks this will be a dict.
def vector(self,features,weights=None):
if self.cid in features:
feature_values = np.array(list(features[self.cid].values()))
if weights is None:
weights = np.ones(feature_values.shape)
result = feature_values * weights
else:
result = None
return result
def __str__(self):
return self.name
class Test(object):
"""
One kind of experimental test performed by Bushdid et al.
This corresponding to a 'triangle test' with two odorants, and is
defined by those odorants.
"""
def __init__(self,test_uid,odorants,dilution,correct):
"""
Tests are defined by their universal identifier (UID), the 3 odorants
used (2 should be identical), the dilution, and the identity of the
correct response, which should be the odd-ball.
"""
self.id = test_uid
self.odorants = odorants
self.dilution = dilution
self.correct = correct
def add_odorant(self,odorant):
"""
Adds one odorant to this test.
"""
self.odorants.append(odorant)
@property
def double(self):
"""
Returns the odorant present twice in this test.
"""
for odorant in self.odorants:
if self.odorants.count(odorant) == 2:
return odorant
return None
@property
def single(self):
"""
Returns the odorant present once in this test.
"""
for odorant in self.odorants:
if self.odorants.count(odorant) == 1:
return odorant
return None
@property
def pair(self):
"""
Returns the odorant pair in this test, with the odorant present
twice listed first.
"""
return (self.double,self.single)
@property
def N(self):
"""
Returns the number of components in each of the odorants.
This a single value since they should all have the same number
of components.
"""
return self.double.N
@property
def r(self):
"""
Returns the number of component replacements (swaps) separating one of
the odorants from the other.
"""
return self.double.r(self.single)
def overlap(self, percent=False):
"""
Returns the overlap (complement of r) between the two odorants.
Optionally returns this as a percentage of N.
"""
return self.double.overlap(self.single,percent=percent)
@property
def common_components(self):
"""
Returns a list of components common to the two odorants.
"""
d = set(self.double.components)
s = set(self.single.components)
return list(s.intersection(d))
@property
def unique_components(self):
"""
Returns a list of components that exactly one of the two odorants has.
"""
d = set(self.double.components)
s = set(self.single.components)
return list(s.symmetric_difference(d))
def unique_descriptors(self,source):
"""
Given a data source, returns a list of descriptors that
exactly one of the two odorants has.
"""
sl = self.single.descriptor_list(source)
dl = self.double.descriptor_list(source)
unique = set(dl).symmetric_difference(set(sl))
return list(unique)
def common_descriptors(self,source):
"""
Given a data source, returns a list of descriptors that
are common to the two odorants.
"""
sl = self.single.descriptor_list(source)
dl = self.double.descriptor_list(source)
unique = set(dl).intersection(set(sl))
return list(unique)
def descriptors_correlation(self,source,all_descriptors):
"""
Given a data source, returns the correlation between the descriptors
of the two odorants.
"""
sv = self.single.descriptor_vector(source,all_descriptors)
dv = self.double.descriptor_vector(source,all_descriptors)
return np.corrcoef(sv,dv)[1][0]
def descriptors_correlation2(self,all_descriptors):
"""
Returns the correlation between the descriptors
of the two odorants, combining multiple data sources.
"""
sv = self.single.descriptor_vector2(all_descriptors)
dv = self.double.descriptor_vector2(all_descriptors)
return np.corrcoef(sv,dv)[1][0]
def descriptors_difference(self,source,all_descriptors):
"""
Given a data source, returns the absolute difference between the descriptors
of the two odorants.
"""
sv = self.single.descriptor_vector(source,all_descriptors)
dv = self.double.descriptor_vector(source,all_descriptors)
return np.abs(sv-dv)
def n_undescribed(self,source):
"""
Given a data source, returns the number of components from among the
two odorants that are not described by that source.
"""
d = self.double.n_described_components(source)
s = self.single.n_described_components(source)
return (self.N-d,self.N-s)
@classmethod
def length(cls,v):
return np.sqrt(np.dot(v, v))
@classmethod
def find_angle(cls,v1,v2):
return np.arccos(np.dot(v1, v2) / (cls.length(v1) * cls.length(v2)))
@classmethod
def circmean(cls,angles):
return np.arctan2(np.mean(np.sin(angles)),np.mean(np.cos(angles)))
def angle(self,features,weights=None,method='sum',method_param=1.0):
if method == 'sum':
v1 = self.single.vector(features,weights=weights,method=method)
v2 = self.double.vector(features,weights=weights,method=method)
angle = self.find_angle(v1,v2)
elif method == 'nn': # Nearest-Neighbor.
m1 = self.single.matrix(features,weights=weights)
m2 = self.double.matrix(features,weights=weights)
angles = []
for i in range(m1.shape[0]):
angles_i = []
for j in range(m2.shape[0]):
one_angle = self.find_angle(m1[i,:],m2[j,:])
if np.isnan(one_angle):
one_angle = 1.0
angles_i.append(one_angle)
angles_i = np.array(sorted(angles_i))
from scipy.stats import geom
weights_i = geom.pmf(range(1,len(angles_i)+1),method_param)
angles.append(np.dot(angles_i,weights_i))
angle = np.abs(angles).mean()#circmean(angles)
return angle
def norm(self,features,order=1,weights=None,method='sum'):
v1 = self.single.vector(features,weights=weights,method=method)
v2 = self.double.vector(features,weights=weights,method=method)
dv = v1-v2
dv = np.abs(dv)**order
return np.sum(dv)
def distance(self,features,weights=None,method='sum'):
v1 = self.single.vector(features,weights=weights,method=method)
v2 = self.double.vector(features,weights=weights,method=method)
return np.sqrt(((v1-v2)**2).sum())
def fraction_correct(self,results):
num,denom = 0.0,0.0
for result in results:
if result.test.id == self.id:
num += result.correct
denom += 1
return num/denom
class Result(object):
"""
A test result, corresponding to one test given to one subject.
"""
def __init__(self, test, subject_id, correct):
"""
Results are defined by the test to which they correspond,
the id of the subject taking that test, and whether the subject
gave the correct answer.
"""
self.test = test
self.subject_id = subject_id
self.correct = correct
class Distance(object):
"""
An odorant distance, corresponding to distance between two odorants.
No particular implementation for computing distance is mandated.
"""
def __init__(self, odorant_i, odorant_j, distance):
self.odorant_i = odorant_i
self.odorant_j = odorant_j
self.distance = distance
def load_components():
"""
Loads all odorant components from Supplemental Table 1 of Bushdid et al.
"""
components = []
f = open(os.path.join(HERE,'Bushdid-tableS1.csv'),'r',encoding='latin1')
reader = csv.reader(f)
next(reader)
component_id = 0
for row in reader:
name,cas,percent,solvent = row[:4]
if len(name):
component = Component(component_id,name,cas,percent,solvent)
components.append(component)
component_id += 1
else:
break
return components
def load_odorants_tests_results(all_components):
"""
Given all odor components, loads the odorants, tests, and test results
from Supplemental Table 2 of Bushdid et al.
"""
odorants = {}
tests = {}
results = []
f = open(os.path.join(HERE,'Bushdid-tableS2.csv'),'r',encoding='latin1')
reader = csv.reader(f)
next(reader)
row_num = 0
for row in reader:
uid,n,r,percent,dilution,correct = row[:6]
component_names = [x for x in row[6:36] if len(x)]
# The next line is required to account for inconsistent naming of one
# of the components across the two supplemental tables.
component_names = [x.replace('4-Methyl-3-penten-2-one',
'4-methylpent-3-en-2-one') for x in component_names]
outcomes = row[36:62]
if uid.isdigit():
uid = int(uid)
dilution = DILUTION[dilution]
odorant_key = hash(tuple(component_names))
if odorant_key not in odorants:
components = [component for component in all_components \
if component.name in component_names]
if len(components) not in [1,10,20,30]:
# If an odorant has a number of components which is not
# either 1, 10, 20, or 30.
print_(uid,[x for x in component_names if x not in [y.name for y in components]])
odorant = Odorant(components)
odorant.name = odorant_key
elif row_num % 3 == 0:
# If any component is repeated across all the tests.
print_("Repeat of this odorant: %d" % odorant_key)
odorants[odorant_key] = odorant
if uid not in tests:
tests[uid] = Test(uid,[],dilution,correct)
test = tests[uid]
test.add_odorant(odorant)
if correct == 'right':
test.correct = tests[uid].odorants.index(odorant)
if len(outcomes[0]):
for i,outcome in enumerate(outcomes):
result = Result(test,i+1,CORRECT[outcome])
results.append(result)
row_num += 1
return odorants,tests,results
def odorant_distances(results,subject_id=None):
"""
Given the test results, returns a dictionary whose keys are odorant pairs
and whose values are psychometric distances between those pairs,
defined as the fraction of discriminations that were incorrect.
This can be limited to one subject indicated by subject_id, or else
by default it pools across all subjects.
"""
distances = {}
distance_n_subjects = {}
for result in results:
if subject_id and result.subject_id != subject_id:
continue
pair = result.test.pair
if pair not in distances:
distances[pair] = 0
distance_n_subjects[pair] = 0
distances[pair] += 0.0 if result.correct else 1.0
distance_n_subjects[pair] += 1
for pair in list(distances.keys()):
# Divided by the total number of subjects.
distances[pair] /= distance_n_subjects[pair]
return distances
def ROC(results, N):
"""
Given test results and a number of components N, returns a distribution
of the number of distinct components 'r' for correct trials (right) and
incorrect trials (wrong), in tests using odorants with N total components.
These can later be plotted or used to generated an ROC curve.
"""
right = []
wrong = []
for result in results:
if result.test.N == N:
r = result.test.r
if result.correct:
right.append(r)
else:
wrong.append(r)
right = np.array(right) # Distribution of r for correct trials.
wrong = np.array(wrong) # Distribution of r for incorrect trials.
return (right,wrong)
def correct_matrix(results,N,overlap):
"""
Given test results, a number of components N, and a level of overlap
between odorants, returns a num_subjects by num_test matrix of booleans
corresponding to the correctness of that subject's response on that test.
"""
results = [r for r in results if (N is None or r.test.N==N) and (overlap is None or r.test.overlap()==overlap)]
subjects = [r.subject_id for r in results]
subjects = list(set(subjects))
tests = [r.test for r in results]
tests = list(set(tests))
correct = np.zeros((len(subjects),len(tests)))
correct -= 1 # Set to make sure every point gets set to 0 or 1 later.
for result in results:
i = subjects.index(result.subject_id)
j = tests.index(result.test)
correct[i,j] = result.correct
return correct, subjects, tests
def fraction_disc(results,N,overlap,fig,alpha=None,multiple_correction=False,n_replicates=None):
"""
Given test results, a number of components N, a level of overlap between
odorants, a reference figure panel ('a' or 'b'), an optional choice of
significance threshold alpha, whether or not to do multiple comparisons
correction (false discovery rate method), and an optional new number of
replicates (subjects or tests), returns an array containing either the
fraction of correct responses (if alpha is None) or whether or not that
fraction is significantly above chance (if alpha is a number).
This function assists with generating variants of Figs. 2B, 2C, 3A,
and 3B in Bushdid et al.
"""
assert fig in ['a','b']
correct,_,_ = correct_matrix(results,N,overlap)
if fig == 'a':
dim = 1
elif fig == 'b':
dim = 0
fract_correct = np.mean(correct,dim)
if alpha is not None:
if not n_replicates:
n_replicates = correct.shape[dim] # n_subjects or n_tests.
ps = 1.0 - binom.cdf(fract_correct*n_replicates,n_replicates,1.0/3)
if multiple_correction == 'bonferroni':
alpha = alpha/len(ps)
if multiple_correction == 'fdr':
ps = np.array([p*len(ps)/(k+1) for k,p in enumerate(sorted(ps))])
fract_sig = ps < alpha/2
return fract_sig
else:
return fract_correct
def fig2x(results,fig='b',plot=True):
"""
Given test results, a reference figure panel ('b' or 'c'), plots the data
summary corresponding to Fig. 2 from Bushdid et al.
"""
assert fig in ('b','c')
overlap_dict = {10:[9,6,3,0],
20:[19,15,10,5,0],
30:[29,20,10,0]}
f, axarr = plt.subplots(1, 3, sharey=True)
for i,(N,overlaps) in enumerate(overlap_dict.items()):
X = []
Y = []
for j,overlap in enumerate(overlaps):
fract = fraction_disc(results,N,overlap,chr(ord(fig)-1),alpha=None)
Y += list(fract)
counts = Counter(fract)
observed = {_:0 for _ in fract}
for value in fract:
inc = (observed[value] + 1)/(counts[value] + 1)-0.5
X += [j+inc]
observed[value] += 1
axarr[i].scatter(X,Y)
axarr[i].set_xticks([0,1,2,3,4])
overlaps = np.array(overlaps)*100.0/N
axarr[i].set_xticklabels(['%.2g'%_ for _ in overlaps])
axarr[0].set_ylabel("% correct")
axarr[1].set_xlabel("% mixture overlap")
plt.tight_layout()
plt.subplots_adjust(wspace=0.4)
plt.savefig('fig2%s.eps' % fig)
#plt.show()
def fig3x(results,fig='a',alpha=ALPHA,multiple_correction=False,n_replicates=None,threshold=50.0,plot=True):
"""
Given test results, a reference figure panel ('a' or 'b'), an optional
choice of significance threshold alpha, whether or not to do multiple comparisons
correction (false discovery rate method), and an optional new number of
replicates (subjects or tests), returns the percent overlap of components
at which 50 percent of subjects/tests discriminate/can be discriminated
significantly above chance levels. It does this by generating an analogue
of Fig. 3A or 3B from Bushdid et al, and computing the linear regression as
in those figures, and identifying the point at which the line intersects
with a horizontal line going through 50 percent on the ordinate.
By default, it plots the corresponding figure.
"""
tens_overlap = 100.0*np.array((9,6,3,0))/10.0
twenties_overlap = 100.0*np.array([19,15,10,5,0])/20.0
thirties_overlap = 100.0*np.array([29,20,10,0])/30.0
def do(N,x,results):
y = np.zeros(len(x))
for i,overlap in enumerate(x):
f = fraction_disc(results,N,int(overlap*N/100.0),fig,alpha=alpha,multiple_correction=multiple_correction,n_replicates=n_replicates)
y[i] = np.mean(f)
return y*100.0
tens = do(10,tens_overlap,results)
twenties = do(20,twenties_overlap,results)
thirties = do(30,thirties_overlap,results)
if plot:
plt.figure()
plt.scatter(tens_overlap,tens,s=40,c='w',marker='D')
plt.scatter(twenties_overlap,twenties,s=40,c='magenta')
plt.scatter(thirties_overlap,thirties,s=40,c='g',marker='s')
plt.xlim(101,-1)
plt.ylim(-1,101)
for name,x,y in [('tens',tens_overlap,tens),
('twenties',twenties_overlap,twenties),
('thirties',thirties_overlap,thirties)]:
np.savetxt("fig3%s_%s.csv" % (fig,name),
np.array((x,y)).transpose(), delimiter=",")
overlap = np.concatenate((tens_overlap,twenties_overlap,thirties_overlap))
percent_disc = np.concatenate((tens,twenties,thirties))
A = np.array([np.ones(len(overlap)), overlap])
w = np.linalg.lstsq(A.T,percent_disc)[0] # obtaining the parameters
xi = np.arange(0,100)
line = w[0]+w[1]*xi # Regression line
if plot:
plt.plot(xi,line,'k-') # Plotting the line
plt.xlabel('% mixture overlap')
if fig == 'a':
plt.ylabel('% subjects that can discriminate')
elif fig == 'b':
plt.ylabel('% mixtures that are discriminable')
if w[0] == 0.0 and w[1] == 0.0:
overlap = 0.0
else:
overlap = (threshold - w[0])/w[1]
print_(('%f%% discrimination at %.3g%% overlap for alpha = '+('%g' if alpha else '%s')) \
% (threshold,overlap,alpha))
if overlap > 100.0:
overlap = 100.0
if overlap < 0.0:
overlap = 0.0
if plot:
plt.plot([100,0],[threshold,threshold],'k--')
plt.plot([overlap,overlap],[0,threshold],'k:')
return overlap
def fig4x(results,fig='c',alpha=ALPHA,gamma=2,
axes=None,Ns=[10,20,30],xlabel=True,ylabel=True,lines=True):
if fig=='c': # 4c corresponds to...
fig=='a' # ...3a.
elif fig=='d': # 4d corresponds to...
fig=='b' # ...3b.
alpha_ = alpha if fig=='a' else alpha/2
n_replicates = N_TESTS if fig=='a' else N_SUBJECTS
result = []
if axes is None:
plt.figure()
axes = plt.gca()
for color,N_ in [('blue',10),('magenta',20),('green',30)]:
if N_ not in Ns:
continue
x = np.arange(0,N_+1)
z = [disc(N_,C,_*2/gamma) for _ in x]
O = 100*(1-x/N_) # Overlap, but expressed as a percent of N_.
D = fig3x(results,fig=fig,alpha=alpha_,plot=False)
D_f = (100-D)/100 # Express as fraction distinct components.
axes.plot(O,z,color=color,linewidth=2)
#np.savetxt("z_vs_overlap_%s_%s.csv" % (disc.__name__,color), np.array((o,y)).transpose(), delimiter=",")
z_ = disc(N_,C,N_*D_f*2/gamma)
if lines:
axes.plot([0,D],[z_,z_],'k')
axes.plot([D,D],[1,z_],'k')
axes.set_yscale('log')
if xlabel:
axes.set_xlabel("% mixture overlap \n allowing discrimination",
{'size':12})
if ylabel:
axes.set_ylabel('Estimated number of \n discriminable stimuli $\hat{z}$',
{'size':12})
result.append((N_,z_))
np.savetxt('%d_%d.dat' % (N_,gamma),np.vstack((np.array(O),np.array(z))).transpose())
return result
def overlap(results,fig='a',alphas=ALPHA*10.0**np.arange(-2,0.25,0.25),multiple_correction=False,n_replicates=None):
"""
Given test results, a reference figure panel ('a' or 'b'), a range of
significance thresholds alpha, whether or not to do multiple comparisons