EER (equal error rate)

作者: 04282aba96e3 | 来源:发表于2018-02-08 10:12 被阅读1次
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 02 11:36:03 2018

@author: zhangzhihui
"""
import os
import sys
import getopt

'''
Common Functions
'''
def is_number(s):
    """Tell whether ``s`` can be converted with ``float()``.

    Only ``ValueError`` counts as "not a number"; any other exception raised
    by ``float()`` (for example ``TypeError`` for ``None``) propagates.
    """
    try:
        float(s)
    except ValueError:
        return False
    return True
def load_threhold_by_filepath(filepath, err_file_name='err.txt'):
    """Read the score column (third comma-separated field) of ``filepath``.

    Each non-blank line is stripped and split on ','; field index 2 is parsed
    with ``float()``. Rejected lines are copied to ``err_file_name``:
    a non-numeric third field is logged with the prefix ``is Str:`` and a line
    with fewer than three fields with the prefix ``bad format:``. Blank lines
    are skipped.

    Note: ``err_file_name`` is deleted and recreated on every call, so a
    second call with the same name discards the first call's report.

    :param filepath: path of the score file to read.
    :param err_file_name: file that receives rejected lines (default
        ``err.txt`` in the current directory).
    :return: list of parsed scores, in file order.
    """
    threhold_buf = []
    if os.path.isfile(err_file_name):
        os.remove(err_file_name)
    # Context managers close both files even if reading or parsing fails.
    with open(err_file_name, 'a') as err_writer, open(filepath, 'r') as fd:
        for cur_line in fd:
            stripped = cur_line.strip()
            if not stripped:
                continue
            comp_buf = stripped.split(',')
            # Terminate every logged line, including a final line without '\n'.
            logged_line = cur_line.rstrip('\n') + '\n'
            if len(comp_buf) < 3:
                err_writer.write('bad format:' + logged_line)
                continue
            try:
                threhold_buf.append(float(comp_buf[2]))
            except ValueError:
                err_writer.write('is Str:' + logged_line)
    return threhold_buf
# NOTE(review): the string literal below is disabled code (get_average /
# get_variance). It is evaluated as a no-op string expression, so these
# helpers are never defined, and nothing visible in this file calls them.
# Consider deleting it, or restoring the helpers as real functions if needed.
'''UnUsed
def get_average(arr):
    return sum(arr)/len(arr)
def get_variance(arr):
    average = get_average(arr)
    sum_tmp=0.0
    for item in arr:
        cur_tmp=item-average
        sum_tmp=sum_tmp+(cur_tmp*cur_tmp)
    return sum_tmp/len(arr)
'''
def cal_threhold_by_fr(arr, percentage, ascending):
    """Return the element of ``arr`` at position ``int(len(arr) * percentage)``.

    ``arr`` is expected to be sorted already. ``main`` sorts genuine scores
    ascending and impostor scores descending, so the selected element has
    ``int(len(arr) * percentage)`` entries before it.

    :param arr: sorted list of scores.
    :param percentage: fraction in [0, 1].
    :param ascending: unused; kept for interface compatibility. The direction
        is whatever order ``arr`` already has.
    :return: the selected score, or ``0`` when ``arr`` is empty,
        ``percentage`` is outside [0, 1], or the index falls past the end
        (``percentage == 1``).
    """
    if not arr or not 0 <= percentage <= 1:
        # Scalar 0 (not a tuple) so callers can always do arithmetic on it.
        return 0
    index = int(len(arr) * percentage)
    if index >= len(arr):
        return 0
    return arr[index]
def cal_far_frr_threhold(arr1, arr2, percentage, ascending):
    """Pick a threshold from ``arr1`` and measure how much of ``arr2`` lies beyond it.

    The threshold is ``arr1[int(len(arr1) * percentage)]``. The returned rate
    is ``position / len(arr2)``, where ``position`` is the index of the first
    ``arr2`` entry that is ``<=`` the threshold (``ascending=True``) or
    ``>=`` it (``ascending=False``). With ``arr2`` sorted in the matching
    direction, this is the fraction of ``arr2`` before the threshold.

    In this script ``cal_far_by_frr`` calls it with ``ascending=True``.
    ``cal_frr_by_far`` swaps the lists and uses ``ascending=False``. Judging
    by those names, the rate is presumably a FAR in the first case and an FRR
    in the second.

    :return: ``(rate, threshold)``. ``rate`` is ``1.0`` when no entry of
        ``arr2`` reaches the threshold. Returns ``(0, 0)`` when either list is
        empty, ``percentage`` is outside [0, 1], or the index falls past the
        end of ``arr1``.
    """
    if not arr1 or not arr2 or not 0 <= percentage <= 1:
        return 0, 0
    index = int(len(arr1) * percentage)
    if index >= len(arr1):
        return 0, 0

    threshold = arr1[index]
    count = len(arr2)
    for position, score in enumerate(arr2):
        reached = score <= threshold if ascending else score >= threshold
        if reached:
            return float(position) / float(count), threshold
    # No entry of arr2 reaches the threshold: all of it lies before it.
    # Previously this path fell through and returned None.
    return 1.0, threshold
'''
Business Logic Functions
'''
def cal_far_by_frr(truespeaker_arr, imposter_arr, frr):
    """Delegate to ``cal_far_frr_threhold`` with ``ascending=True``.

    The threshold is taken from ``truespeaker_arr`` at the ``frr`` position
    and the rate is measured on ``imposter_arr``. The helper's result is
    returned unchanged.
    """
    result = cal_far_frr_threhold(truespeaker_arr, imposter_arr,
                                  percentage=frr, ascending=True)
    return result
def cal_frr_by_far(imposter_arr, truespeaker_arr, far):
    """Delegate to ``cal_far_frr_threhold`` with ``ascending=False``.

    The threshold is taken from ``imposter_arr`` at the ``far`` position and
    the rate is measured on ``truespeaker_arr``. The helper's result is
    returned unchanged.
    """
    result = cal_far_frr_threhold(imposter_arr, truespeaker_arr,
                                  percentage=far, ascending=False)
    return result
def cal_threhold_by_frr(truespeaker_arr, frr):
    """Delegate to ``cal_threhold_by_fr`` on the genuine scores with ``ascending=True``."""
    threshold = cal_threhold_by_fr(truespeaker_arr, percentage=frr, ascending=True)
    return threshold
def cal_threhold_by_far(imposter_arr, far):
    """Delegate to ``cal_threhold_by_fr`` on the impostor scores with ``ascending=False``."""
    threshold = cal_threhold_by_fr(imposter_arr, percentage=far, ascending=False)
    return threshold
def cal_eer(truespeaker_arr, imposter_arr):
    """Estimate the equal error rate (EER) and the threshold at which it occurs.

    For every candidate rate ``index / len(truespeaker_arr)``:
    - the genuine threshold is ``truespeaker_arr[index]``;
    - the impostor threshold is whatever ``cal_threhold_by_far`` returns for
      the same rate.
    The rate where the two thresholds are closest is taken as the EER.

    Expects the lists sorted as ``main`` does: ``truespeaker_arr`` ascending,
    ``imposter_arr`` descending.

    :return: ``(eer, threshold)``, where ``threshold`` is the midpoint of the
        two closest thresholds. Returns ``(0, 0)`` when either list is empty.
    """
    if not truespeaker_arr or not imposter_arr:
        # Nothing to compare. With an empty impostor list the per-rate
        # threshold lookup has no valid value to return.
        return 0, 0
    # inf (rather than a fixed large constant) so the first candidate is
    # always accepted, whatever the score scale.
    min_gap = float('inf')
    eer_result = 0
    threhold_result = 0
    count = len(truespeaker_arr)
    for index, threhold_frr in enumerate(truespeaker_arr):
        frr = float(index) / float(count)
        threhold_far = cal_threhold_by_far(imposter_arr, frr)
        gap = abs(threhold_frr - threhold_far)
        # The 1e-8 tolerance also accepts near-ties, so the later candidate wins.
        if gap - min_gap < 0.00000001:
            min_gap = gap
            eer_result = frr
            threhold_result = (threhold_frr + threhold_far) / 2
    return eer_result, threhold_result
def _parse_score_map(text):
    """Parse a ``-s`` value such as ``"0.025,0.015,0.005"`` into main's score map.

    Raises ValueError unless ``text`` holds exactly three comma-separated numbers.
    """
    levels = [float(part) for part in text.split(',')]
    if len(levels) != 3:
        raise ValueError('expected three comma-separated FAR levels')
    return {'level1': levels[0], 'level2': levels[1], 'level3': levels[2]}


def main(argv):
    """Command-line entry point: load both score files and print EER and FRR levels.

    Options: ``-i/--imposter_path``, ``-t/--truespeaker_path``,
    ``-s/--score_map`` (three comma-separated FAR levels, used in order for
    score 60, 70 and 80), and ``-h/--help``. Exits with status 2 on bad
    options or a malformed score map.

    :param argv: argument list without the program name (``sys.argv[1:]``).
    """
    helpstr = ('err_tool.py -i <imposter_path> -t <truespeaker_path> -s <score_map>\n'
               '  score_map: three comma-separated FAR levels for score 60,70,80, '
               'e.g. 0.025,0.015,0.005')
    try:
        opts, _ = getopt.getopt(
            argv, "hi:t:s:",
            ["help", "imposter_path=", "truespeaker_path=", "score_map="])
    except getopt.GetoptError as err:
        print(str(err))
        print(helpstr)
        sys.exit(2)

    # os.path.join keeps the defaults portable; backslash literals only work on Windows.
    truespeaker_path = os.path.join('.', 'data', 'truespeaker.txt')
    imposter_path = os.path.join('.', 'data', 'imposter.txt')
    score_map = {'level1': 0.025, 'level2': 0.015, 'level3': 0.005}
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(helpstr)
            sys.exit()
        elif opt in ('-i', '--imposter_path'):
            imposter_path = arg
        elif opt in ('-t', '--truespeaker_path'):
            truespeaker_path = arg
        elif opt in ('-s', '--score_map'):
            try:
                score_map = _parse_score_map(arg)
            except ValueError as err:
                print('invalid score_map ' + repr(arg) + ': ' + str(err))
                print(helpstr)
                sys.exit(2)

    truespeaker_threhold_buf = load_threhold_by_filepath(truespeaker_path)
    imposter_threhold_buf = load_threhold_by_filepath(imposter_path)

    # The calculators expect genuine scores ascending and impostor scores descending.
    truespeaker_threhold_buf.sort()
    imposter_threhold_buf.sort(reverse=True)

    err, err_threhold = cal_eer(truespeaker_threhold_buf, imposter_threhold_buf)

    score60_frr, score60_threhold_value = cal_frr_by_far(imposter_threhold_buf, truespeaker_threhold_buf, score_map['level1'])
    score70_frr, score70_threhold_value = cal_frr_by_far(imposter_threhold_buf, truespeaker_threhold_buf, score_map['level2'])
    score80_frr, score80_threhold_value = cal_frr_by_far(imposter_threhold_buf, truespeaker_threhold_buf, score_map['level3'])

    print('ERR: ' + str(err) + ' threhold is:' + str(err_threhold))
    print('score = 80 FAR = ' + str(score_map['level3']) + ' FRR:' + str(score80_frr) +' threhold value is:' + str(score80_threhold_value))
    print('score = 70 FAR = ' + str(score_map['level2']) + ' FRR:' + str(score70_frr) +' threhold value is:' + str(score70_threhold_value))
    print('score = 60 FAR = ' + str(score_map['level1']) + ' FRR:' + str(score60_frr) +' threhold value is:' + str(score60_threhold_value))
if __name__ == '__main__':
    # Script entry: forward everything after the program name to main().
    cli_arguments = sys.argv[1:]
    main(argv=cli_arguments)
    print('script is finished!!!')

相关文章

网友评论

      本文标题:eer

      本文链接:https://www.haomeiwen.com/subject/schczxtx.html