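The code below computes evaluation metrics for transliteration and speech recognition. It uses the jiwer package from PyPI, which can be installed directly with pip.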

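This third-party library provides five metrics: Character Error Rate (CER), Word Error Rate (WER), Match Error Rate (MER), Word Information Lost (WIL), and Word Information Preserved (WIP). In practice only four of them are independent, because WIL and WIP are complementary (WIL = 1 - WIP).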

For a detailed explanation of these metrics, see the paper:
From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition
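To make the definitions concrete, here is a minimal pure-Python sketch of how WER, MER, WIP, and WIL follow from word-alignment counts: hits (H), substitutions (S), deletions (D), and insertions (I). The function name `metrics_from_counts` is hypothetical and is not part of jiwer; it only illustrates the formulas, and it assumes the reference is non-empty. CER uses the same formula as WER, with characters in place of words.

```python
def metrics_from_counts(hits, subs, dels, ins):
    """Illustrative only: compute WER, MER, WIP and WIL from alignment counts.

    Assumes the reference contains at least one word (hits + subs + dels > 0).
    """
    ref_len = hits + subs + dels   # number of words in the reference
    hyp_len = hits + subs + ins    # number of words in the hypothesis
    errors = subs + dels + ins     # total number of edit operations

    wer = errors / ref_len                 # can exceed 1 when there are many insertions
    mer = errors / (hits + errors)         # always within [0, 1]
    # WIP is the product of the hit rate seen from each side of the alignment
    wip = (hits / ref_len) * (hits / hyp_len) if hyp_len > 0 else 0.0
    wil = 1.0 - wip                        # WIL and WIP are complementary

    return {"wer": wer, "mer": mer, "wip": wip, "wil": wil}


# "hello world" vs "hello duck": 1 hit, 1 substitution
# "i like monthy python" vs "i like python": 3 hits, 1 deletion
scores = metrics_from_counts(hits=4, subs=1, dels=1, ins=0)
print({name: round(value, 4) for name, value in scores.items()})
# {'wer': 0.3333, 'mer': 0.3333, 'wip': 0.5333, 'wil': 0.4667}
```

Because WIL is defined as 1 - WIP, reporting both adds no information, which is why the five metrics reduce to four.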

import jiwer

def compute_single_metric(gt,pred,metric):
    if metric == 'cer':
        return jiwer.cer(gt,pred)
    elif metric == 'wer':
        return jiwer.wer(gt,pred)
    elif metric == 'mer':
        return jiwer.mer(gt,pred)
    elif metric == 'wil':
        return jiwer.wil(gt,pred)
    elif metric == 'wip':
        return jiwer.wip(gt,pred)
    else:
        raise KeyError("invalid metric: {} !".format(metric))

def compute_metrics(ground_truth:list,prediction:list,metrics:list)->dict:
    """Compute automatic speech recognition (ASR) metrics, including:
    Character Error Rate (CER),
    Word Error Rate (WER),
    Match Error Rate (MER),
    Word Information Lost (WIL) and Word Information Preserved (WIP)

    Args:
        ground_truth (list): list of ground-truth strings, e.g., ['apple','marry','mark twain']
        prediction (list): list of predicted strings, e.g., ['appl','malli','mark twen']
        metrics (list): metrics to compute, any subset of ['cer','wer','mer','wil','wip']

    Returns:
        dict: maps each requested metric name to its score as a percentage
    """
    choices = ['cer','wer','mer','wil','wip']

    assert len(ground_truth) == len(prediction), 'length mis-match!'
    assert all([c in choices for c in metrics]), "metrics out of the pre-definition, i.e., ['cer','wer','mer','wil','wip']"

    results = dict([(c,0.0) for c in metrics])

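    # Compute each requested metric once over the whole corpus and scale it
    # to a percentage. jiwer pools the edit counts of all sentence pairs, so
    # this is a corpus-level score, not an average of per-sentence scores.
    for metric in metrics:
        score = compute_single_metric(ground_truth,prediction,metric)
        score = score * 100
        results[metric] = score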
    
    return results

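if __name__ == "__main__":
    ground_truth = ["hello world", "i like monthy python"]
    hypothesis = ["hello duck", "i like python"]
    metrics_1 = ['cer','wer','mer','wil','wip']  # all supported metrics
    metrics_2 = []                               # no metrics requested: returns {}
    metrics_3 = ['cer']                          # a single metric
    metrics_4 = ['ccc']                          # unsupported name: fails the assertion

    print(compute_metrics(ground_truth,hypothesis,metrics_1))
    print(compute_metrics(ground_truth,hypothesis,metrics_2))
    print(compute_metrics(ground_truth,hypothesis,metrics_3))
    try:
        print(compute_metrics(ground_truth,hypothesis,metrics_4))
    except AssertionError as err:
        # the invalid metric name is rejected before any score is computed
        print("rejected:", err)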
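One point worth checking when comparing numbers across tools: a corpus-level score (total edits divided by total reference words) is generally not the same as the mean of per-sentence scores. The sketch below illustrates the difference for WER in pure Python, without jiwer. The helper names `edit_distance`, `corpus_wer`, and `mean_sentence_wer` are hypothetical, and the sketch assumes whitespace tokenization and non-empty references.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences (row-by-row DP)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1]


def corpus_wer(refs, hyps):
    """Pool edits and reference lengths over all pairs, then divide once."""
    edits = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    words = sum(len(r.split()) for r in refs)
    return edits / words


def mean_sentence_wer(refs, hyps):
    """Compute WER per sentence pair, then average the per-pair scores."""
    scores = [edit_distance(r.split(), h.split()) / len(r.split())
              for r, h in zip(refs, hyps)]
    return sum(scores) / len(scores)


refs = ["hello world", "i like monthy python"]
hyps = ["hello duck", "i like python"]
print(round(corpus_wer(refs, hyps), 4))         # 0.3333  (2 edits / 6 words)
print(round(mean_sentence_wer(refs, hyps), 4))  # 0.375   (mean of 0.5 and 0.25)
```

The per-sentence average gives short sentences the same weight as long ones, so the two numbers diverge whenever sentence lengths differ.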


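In addition, the following repositories also compute transliteration metrics, although they cover fewer metrics than jiwer:

  1. FastWER
  2. A personal repository; the CER code is in src/main.py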

References:

  • https://github.com/jitsi/jiwer
  • https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
  • https://github.com/kahne/fastwer
  • https://github.com/rr250/Arabic-Handwritten-Text-Detection-and-Recognization/blob/master/src/main.py