Skip to content

Commit

Permalink
back-end metric
Browse files Browse the repository at this point in the history
  • Loading branch information
Snowdar committed Jun 29, 2020
1 parent 4d795ac commit fb19cfc
Show file tree
Hide file tree
Showing 5 changed files with 373 additions and 11 deletions.
127 changes: 127 additions & 0 deletions computeEER-like-Bosaris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright xmuspeech (Author:Snowdar 2019-01-10)

# This differs slightly from the Kaldi method.
# Here, EER is estimated by averaging the error rates of the two operating points nearest the FAR/FRR crossing.

import sys
import argparse

def get_args():
    """Parse the command line and return the resulting namespace.

    Two positional arguments are expected, in this order: ``trials_path``
    and ``score_path``. The full invocation is echoed to stdout before
    parsing.
    """
    cli = argparse.ArgumentParser(
        description="""Compute EER.""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        conflict_handler='resolve')

    # Positional arguments, registered in the order they must be given.
    positionals = (
        ("trials_path", "The path of trials."),
        ("score_path", "The path of the scores."),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, metavar=arg_name, type=str, help=arg_help)

    # Echo the command line so it shows up in logs.
    print(' '.join(sys.argv))
    return cli.parse_args()

def load_data(data_path,n):
    """Read a whitespace-delimited text file into a list of field lists.

    Args:
        data_path: path of the text file to read.
        n: number of whitespace-separated fields every data line must have.

    Returns:
        A list with one entry per non-blank line, each entry being the list
        of that line's fields (strings).

    Prints an error and exits with status 1 when a non-blank line does not
    have exactly ``n`` fields.
    """
    rows = []
    print("Load data from "+data_path+"...")
    with open(data_path, 'r') as f:
        # Stream the file line by line instead of loading it all at once.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line (e.g. a trailing newline at
                # the end of the file): nothing to parse, so skip it rather
                # than report a bogus field-count error.
                continue
            if len(fields) != n:
                print('[Error] The %s file has no %s fields'%(data_path,n))
                sys.exit(1)
            rows.append(fields)
    return rows

def abs(x):
    """Return the magnitude of x: its negation when x is below zero, else x.

    Note: this definition shadows the built-in ``abs`` inside this module.
    """
    return -x if x < 0 else x

def compute_eer(allScores):
    """Estimate the equal error rate (EER) and its threshold.

    Scores are swept in ascending order. After each score, that trial and all
    trials before it count as rejected: FAR is the fraction of nontarget
    trials still accepted, FRR the fraction of target trials rejected. At the
    first point where FAR <= FRR, that point is compared with the one just
    before it; whichever has the smaller |FAR - FRR| (the current point on a
    tie) gives EER = (FAR + FRR) / 2 and its score as the threshold.

    Args:
        allScores: iterable of [score, label] pairs, where label is the
            string "target" or "nontarget". The input is left unmodified.

    Returns:
        (eer, threshold), with eer a fraction between 0 and 1.

    Prints an error and exits with status 1 when a label is neither "target"
    nor "nontarget", or when either class has no trials (the rates would be
    undefined).
    """
    numP = 0
    numN = 0
    # Build a private (score, is_target) list so the caller's rows are not
    # rewritten in place and the function can be called again on the same data.
    labelled = []
    for x in allScores:
        if x[1] == "target":
            labelled.append((x[0], 1))
            numP = numP + 1
        elif x[1] == "nontarget":
            labelled.append((x[0], 0))
            numN = numN + 1
        else:
            print("[Error in compute_eer()] %s is not target or nontarget in score"%(x[1]))
            sys.exit(1)

    if numP == 0 or numN == 0:
        # FAR or FRR would divide by zero.
        print("[Error in compute_eer()] need at least one target and one nontarget score")
        sys.exit(1)

    # Ascending by score; equal scores put nontarget (0) before target (1).
    labelled.sort()

    numFA = numN
    numFR = 0

    # Operating point before any trial is rejected: every nontarget accepted
    # (FAR = 1) and no target rejected (FRR = 0). Seeding the "previous point"
    # with it keeps the comparison below valid even when the crossing happens
    # on the very first score. Its threshold is never returned, because
    # |1 - 0| is the largest possible gap.
    prev_far, prev_frr, prev_threshold = 1.0, 0.0, labelled[0][0]

    for score, is_target in labelled:
        if is_target == 1:
            numFR = numFR + 1
        else:
            numFA = numFA - 1

        far = numFA * 1.0 / numN
        frr = numFR * 1.0 / numP

        if far <= frr:
            # The curves have crossed: keep whichever neighbouring point has
            # the more balanced error rates.
            if abs(far - frr) <= abs(prev_far - prev_frr):
                return (far + frr) / 2, score
            return (prev_far + prev_frr) / 2, prev_threshold

        prev_far, prev_frr, prev_threshold = far, frr, score

def main():
    """Load trials and scores, compute the EER and print it.

    Both files are read with three whitespace-separated fields per line. The
    label of each score line is looked up from the trials file using the
    line's first two fields as the key and its third field as the value.

    Any failure inside the processing section ends the process with exit
    status 1; a traceback is printed unless the failure is a
    KeyboardInterrupt or a SystemExit raised by a helper that has already
    reported its own error.
    """
    # Imported here so the error handler below has it in scope.
    import traceback

    args = get_args()

    try:
        trials = load_data(args.trials_path, 3)
        scores = load_data(args.score_path, 3)

        # Key on the (field 1, field 2) pair itself; concatenating the two
        # strings would let distinct pairs such as ("ab", "c") and
        # ("a", "bc") collide.
        label_dict = {(x[0], x[1]): x[2] for x in trials}

        allScores = [[float(x[2]), label_dict[(x[0], x[1])]] for x in scores]

        eer, threshold = compute_eer(allScores)

        print("EER% {:.3f} (threshold = {:.5f})".format(eer*100, threshold))
    except BaseException as e:
        # Look for BaseException so we catch KeyboardInterrupt, which is
        # what we get when a background thread dies.
        if not isinstance(e, (KeyboardInterrupt, SystemExit)):
            traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    main()





17 changes: 8 additions & 9 deletions computeEER.sh
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
#!/bin/bash

# Copyright xmuspeech (Author:Snowdar 2018-7-27)
# Copyright xmuspeech (Author:Snowdar 2018-7-27 2020-06-30)

write_file=""
first=3 # <target/nontarget-field 1-based>
second=3 # <score-field 1-based>

. subtools/path.sh
. subtools/parse_options.sh

if [[ $# != 4 ]];then
if [[ $# != 2 ]];then
echo "[exit] Num of parameters is not equal to 4"
echo "usage:$0 [--write-file \"\" | filepath] <score-file> <score-field 1-based> <trials> <target/nontarget-field 1-based>"
echo "[note] You should specify field of score in score-file and field of target/nontarget in trials"
echo "usage:$0 [--write-file \"\" | filepath] <trials> <score-file>"
exit 1
fi

score=$1
first=$2
trials=$3
second=$4
trials=$1
score=$2

workout=`awk -v first=$first '{print $first}' $score | paste - <(awk -v second=$second '{print $second}' $trials ) | \
workout=`awk -v second=$second '{print $second}' $score | paste - <(awk -v first=$first '{print $first}' $trials ) | \
awk '{if(NF==2){print $0}}' | compute-eer - `

if [ "$write_file" != "" ];then
Expand Down
236 changes: 236 additions & 0 deletions computeMin-t-DCF.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright xmuspeech (Author:Snowdar 2019-01-10)

# This metric is for ASVspoof Challenge in 2019.

import sys

# Num of data
# Counters for the CM score file: the main script increments them once per
# "bonafide" / "spoof" line and uses them as denominators of the CM rates.
N_bona_cm=0
N_spoof_cm=0

# Priors
# Prior probabilities of the target, nontarget and spoof classes;
# check() verifies that they sum to 1.
Pi_tar=0.9405
Pi_non=0.0095
Pi_spoof=0.05

# ASV costs
# Costs of an ASV miss and an ASV false acceptance; both enter the C1 term
# of the t-DCF computed by the main script.
C_miss_asv=1
C_fa_asv=10

# CM costs
# Costs of a CM miss (enters C1) and a CM false acceptance (enters C2).
C_miss_cm=1
C_fa_cm=10

def load_data(data_path,n):
    """Read a whitespace-delimited text file into a list of field lists.

    Args:
        data_path: path of the text file to read.
        n: number of whitespace-separated fields every data line must have.

    Returns:
        A list with one entry per non-blank line, each entry being the list
        of that line's fields (strings).

    Prints an error and exits with status 1 when a non-blank line does not
    have exactly ``n`` fields.
    """
    rows = []
    print("Load data from "+data_path+"...")
    with open(data_path, 'r') as f:
        # Stream the file line by line instead of loading it all at once.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line (e.g. a trailing newline at
                # the end of the file): nothing to parse, so skip it rather
                # than report a bogus field-count error.
                continue
            if len(fields) != n:
                print('[Error] The %s file has no %s fields'%(data_path,n))
                sys.exit(1)
            rows.append(fields)
    return rows

def abs(x):
    """Return the magnitude of x: its negation when x is below zero, else x.

    Note: this definition shadows the built-in ``abs`` inside this module.
    """
    return -x if x < 0 else x

def compute_eer(allScores):
    """Estimate the equal error rate (EER) and its threshold.

    Scores are swept in ascending order. After each score, that trial and all
    trials before it count as rejected: FAR is the fraction of nontarget
    trials still accepted, FRR the fraction of target trials rejected. At the
    first point where FAR <= FRR, that point is compared with the one just
    before it; whichever has the smaller |FAR - FRR| (the current point on a
    tie) gives EER = (FAR + FRR) / 2 and its score as the threshold.

    Args:
        allScores: iterable of [score, label] pairs, where label is the
            string "target" or "nontarget". The input is left unmodified.

    Returns:
        (eer, threshold), with eer a fraction between 0 and 1.

    Prints an error and exits with status 1 when a label is neither "target"
    nor "nontarget", or when either class has no trials (the rates would be
    undefined).
    """
    numP = 0
    numN = 0
    # Build a private (score, is_target) list so the caller's rows are not
    # rewritten in place and the function can be called again on the same data.
    labelled = []
    for x in allScores:
        if x[1] == "target":
            labelled.append((x[0], 1))
            numP = numP + 1
        elif x[1] == "nontarget":
            labelled.append((x[0], 0))
            numN = numN + 1
        else:
            print("[Error in compute_eer()] %s is not target or nontarget in score"%(x[1]))
            sys.exit(1)

    if numP == 0 or numN == 0:
        # FAR or FRR would divide by zero.
        print("[Error in compute_eer()] need at least one target and one nontarget score")
        sys.exit(1)

    # Ascending by score; equal scores put nontarget (0) before target (1).
    labelled.sort()

    numFA = numN
    numFR = 0

    # Operating point before any trial is rejected: every nontarget accepted
    # (FAR = 1) and no target rejected (FRR = 0). Seeding the "previous point"
    # with it keeps the comparison below valid even when the crossing happens
    # on the very first score. Its threshold is never returned, because
    # |1 - 0| is the largest possible gap.
    prev_far, prev_frr, prev_threshold = 1.0, 0.0, labelled[0][0]

    for score, is_target in labelled:
        if is_target == 1:
            numFR = numFR + 1
        else:
            numFA = numFA - 1

        far = numFA * 1.0 / numN
        frr = numFR * 1.0 / numP

        if far <= frr:
            # The curves have crossed: keep whichever neighbouring point has
            # the more balanced error rates.
            if abs(far - frr) <= abs(prev_far - prev_frr):
                return (far + frr) / 2, score
            return (prev_far + prev_frr) / 2, prev_threshold

        prev_far, prev_frr, prev_threshold = far, frr, score


def t_DCF_min(dcf):
    """Return the smallest value found in the sequence ``dcf``."""
    lowest = min(dcf)
    return lowest

def t_DCF_norm(beta,P_miss_cm,P_fa_cm):
    """Return ``beta * P_miss_cm + P_fa_cm`` for one CM operating point.

    Args:
        beta: weight applied to the CM miss rate.
        P_miss_cm: CM miss rate at the operating point.
        P_fa_cm: CM false-acceptance rate at the operating point.
    """
    weighted_miss = beta * P_miss_cm
    return weighted_miss + P_fa_cm

def get_rate(x,y):
    """Return the ratio x / y as a float, or 0 when the denominator y is zero."""
    if y != 0:
        return x*1.0/y
    # Empty denominator: report a rate of 0 instead of dividing by zero.
    return 0

def obtain_asv_error_rates(asv_score,asv_threshold):
    """Compute the ASV error rates at a fixed decision threshold.

    Args:
        asv_score: rows whose second field is the label ("target",
            "nontarget" or "spoof") and whose third field is the score.
        asv_threshold: decision threshold; a score below it counts as
            rejected, a score at or above it as accepted.

    Returns:
        (P_miss_asv, P_fa_asv, P_miss_spoof_asv): fraction of target trials
        scoring below the threshold, fraction of nontarget trials scoring at
        or above it, and fraction of spoof trials scoring below it. A class
        with no trials gets rate 0.

    Rows with any other label are reported on stdout and otherwise ignored.
    """
    totals = {"target": 0, "nontarget": 0, "spoof": 0}
    errors = {"target": 0, "nontarget": 0, "spoof": 0}

    for row in asv_score:
        kind = row[1]
        if kind not in totals:
            print("[Error in obtain_asv_error_rates()] %s is not target or nontarget or spoof in score"%(kind))
            continue

        totals[kind] += 1
        value = float(row[2])
        if kind == "nontarget":
            # A nontarget trial is an error when it is accepted.
            wrong = value >= asv_threshold
        else:
            # Target and spoof trials are counted when they are rejected.
            wrong = value < asv_threshold
        if wrong:
            errors[kind] += 1

    P_miss_asv = get_rate(errors["target"], totals["target"])
    P_fa_asv = get_rate(errors["nontarget"], totals["nontarget"])
    P_miss_spoof_asv = get_rate(errors["spoof"], totals["spoof"])

    return P_miss_asv,P_fa_asv,P_miss_spoof_asv

def check():
    """Verify that the three class priors sum to 1.

    Reads the module-level ``Pi_tar``, ``Pi_non`` and ``Pi_spoof``. Prints an
    error and exits with status 1 when their sum is not 1.
    """
    total = Pi_tar + Pi_non + Pi_spoof
    # Compare with a small tolerance: the priors are decimal fractions, so
    # their binary floating-point sum may differ from 1 by rounding error
    # even when the configured values are correct.
    if abs(total - 1) > 1e-9:
        print("[Error in check()] Pi_tar+Pi_non+Pi_spoof != 1 ")
        sys.exit(1)

## main ##
# Usage: <script> <asv-score> <cm-score>
#
# asv-score format with every line:
#     attack-way target/nontarget/spoof score
# example:
#     - target 4.23
#     - nontarget 1.24
#     VC_1 spoof 2.55
# cm-score format with every line:
#     bonafide/spoof score
# example:
#     bonafide 2.34
#     spoof -1.2
if len(sys.argv)-1 != 2 :
    # print() call form, like every other print in this file, so the script
    # also parses under Python 3.
    print('usage: '+sys.argv[0]+' [options] <asv-score> <cm-score>')
    sys.exit(1)

asv_score_path=sys.argv[1]
cm_score_path=sys.argv[2]

check()

#- start -#
asv_score_file=load_data(asv_score_path,3)
cm_score_file=load_data(cm_score_path,2)

#- asv -#
# The ASV EER only uses target/nontarget trials; spoof trials are left out here.
asv_score_for_eer=[]
for x in asv_score_file:
    if(x[1]=="target" or x[1]=="nontarget"):
        asv_score_for_eer.append([float(x[2]),x[1]])

asv_eer,asv_threshold=compute_eer(asv_score_for_eer)
# All ASV error rates are measured at the ASV EER threshold.
P_miss_asv,P_fa_asv,P_miss_spoof_asv=obtain_asv_error_rates(asv_score_file,asv_threshold)

#- cm -#
# cm_score keeps numeric labels for the t-DCF sweep; cm_score_for_eer maps
# bonafide/spoof onto the target/nontarget labels that compute_eer() expects.
cm_score=[]
cm_score_for_eer=[]
for x in cm_score_file:
    if(x[0]=="bonafide"):
        lable=1
        text="target"
        N_bona_cm=N_bona_cm+1
    elif(x[0]=="spoof"):
        lable=0
        text="nontarget"
        N_spoof_cm=N_spoof_cm+1
    else:
        print("[Error in main-cm] the lable %s is not bonafide or spoof"%(x[0]))
        sys.exit(1)

    cm_score.append([float(x[1]),lable])
    cm_score_for_eer.append([float(x[1]),text])

cm_eer,cm_threshold=compute_eer(cm_score_for_eer)

#- t-DCF -#
C1=Pi_tar * (C_miss_cm - C_miss_asv * P_miss_asv) - Pi_non * C_fa_asv * P_fa_asv
C2=C_fa_cm * Pi_spoof * (1 - P_miss_spoof_asv)
if(C2==0):
    # beta would divide by zero; report it instead of a bare ZeroDivisionError.
    print("[Error in main-tDCF] C2 is zero, so beta=C1/C2 is undefined")
    sys.exit(1)
beta=C1/C2

cm_score=sorted(cm_score,reverse=False)

# Sweep the CM threshold upwards through the sorted scores. Start with
# nothing rejected: no bonafide missed, every spoof accepted.
count_bona=0
count_spoof=N_spoof_cm
dcf=[]

P_miss_cm=count_bona*1.0/N_bona_cm
P_fa_cm=count_spoof*1.0/N_spoof_cm

dcf.append(t_DCF_norm(beta,P_miss_cm,P_fa_cm))

for cm_item in cm_score:
    if(cm_item[1]==1):
        # A bonafide trial falls below the threshold: one more miss.
        count_bona=count_bona+1
    else:
        # A spoof trial falls below the threshold: one fewer false acceptance.
        count_spoof=count_spoof-1

    P_miss_cm=count_bona*1.0/N_bona_cm
    P_fa_cm=count_spoof*1.0/N_spoof_cm
    dcf.append(t_DCF_norm(beta,P_miss_cm,P_fa_cm))

min_tDCF=t_DCF_min(dcf)

#- print -#
print("\n[Report]")
print("ASV EER=%f%%, threshold=%f"%(asv_eer*100,asv_threshold))
print("ASV Pfa=%f%%, Pmiss=%f%%, 1-Pmiss,spoof=%f%%"%(P_fa_asv*100,P_miss_asv*100,(1-P_miss_spoof_asv)*100))
print("CM EER=%f%%, threshold=%f"%(cm_eer*100,cm_threshold))
print("Final min-tDCF=%f"%(min_tDCF))




Loading

0 comments on commit fb19cfc

Please sign in to comment.