diff --git a/asteroid/metrics.py b/asteroid/metrics.py
index af42b0ca7..28bc33aa4 100644
--- a/asteroid/metrics.py
+++ b/asteroid/metrics.py
@@ -139,6 +139,7 @@ def __init__(self, model_name, trans_df):
         from espnet2.bin.asr_inference import Speech2Text
         from espnet_model_zoo.downloader import ModelDownloader
+        import jiwer
 
         self.model_name = model_name
         d = ModelDownloader()
@@ -146,12 +147,24 @@ def __init__(self, model_name, trans_df):
         self.input_txt_list = []
         self.clean_txt_list = []
         self.output_txt_list = []
+        self.transcriptions = []
+        self.true_txt_list = []
         self.sample_rate = int(d.data_frame[d.data_frame["name"] == model_name]["fs"])
         self.trans_df = trans_df
         self.trans_dic = self._df_to_dict(trans_df)
         self.mix_counter = Counter()
         self.clean_counter = Counter()
         self.est_counter = Counter()
+        self.transformation = jiwer.Compose(
+            [
+                jiwer.ToLowerCase(),
+                jiwer.RemovePunctuation(),
+                jiwer.RemoveMultipleSpaces(),
+                jiwer.Strip(),
+                jiwer.SentencesToListOfWords(),
+                jiwer.RemoveEmptyStrings(),
+            ]
+        )
 
     def __call__(
         self,
@@ -172,25 +185,53 @@ def __call__(
         local_est_counter = Counter()
         # Count the mixture output for each speaker
         txt = self.predict_hypothesis(mix)
+
+        # Dict to gather transcriptions and IDs
+        trans_dict = dict(mixture_txt={}, clean={}, estimates={}, truth={})
+        # Get mixture transcription
+        trans_dict["mixture_txt"] = txt
+        # Get ground truth transcription and IDs
+        for i, tmp_id in enumerate(wav_id):
+            trans_dict["truth"][f"utt_id_{i}"] = tmp_id
+            trans_dict["truth"][f"txt_{i}"] = self.trans_dic[tmp_id]
+            self.true_txt_list.append(dict(utt_id=tmp_id, text=self.trans_dic[tmp_id]))
+        # Mixture
         for tmp_id in wav_id:
-            out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt))
+            out_count = Counter(
+                self.hsdi(
+                    truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
+                )
+            )
             self.mix_counter += out_count
             local_mix_counter += out_count
             self.input_txt_list.append(dict(utt_id=tmp_id, text=txt))
         # Average WER for the clean pair
-        for wav, tmp_id in zip(clean, wav_id):
+        for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)):
             txt = self.predict_hypothesis(wav)
-            out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt))
+            out_count = Counter(
+                self.hsdi(
+                    truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
+                )
+            )
             self.clean_counter += out_count
             local_clean_counter += out_count
             self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt))
+            trans_dict["clean"][f"utt_id_{i}"] = tmp_id
+            trans_dict["clean"][f"txt_{i}"] = txt
         # Average WER for the estimate pair
-        for est, tmp_id in zip(estimate, wav_id):
+        for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)):
             txt = self.predict_hypothesis(est)
-            out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt))
+            out_count = Counter(
+                self.hsdi(
+                    truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
+                )
+            )
             self.est_counter += out_count
             local_est_counter += out_count
             self.output_txt_list.append(dict(utt_id=tmp_id, text=txt))
+            trans_dict["estimates"][f"utt_id_{i}"] = tmp_id
+            trans_dict["estimates"][f"txt_{i}"] = txt
+        self.transcriptions.append(trans_dict)
         return dict(
             input_wer=self.wer_from_hsdi(**dict(local_mix_counter)),
             clean_wer=self.wer_from_hsdi(**dict(local_clean_counter)),
@@ -203,11 +244,16 @@ def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0):
         return wer
 
     @staticmethod
-    def hsdi(truth, hypothesis):
+    def hsdi(truth, hypothesis, transformation):
         from jiwer import compute_measures
 
         keep = ["hits", "substitutions", "deletions", "insertions"]
-        out = compute_measures(truth=truth, hypothesis=hypothesis).items()
+        out = compute_measures(
+            truth=truth,
+            hypothesis=hypothesis,
+            truth_transform=transformation,
+            hypothesis_transform=transformation,
+        ).items()
         return {k: v for k, v in out if k in keep}
 
     def predict_hypothesis(self, wav):
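
For reviewers who have not used jiwer before, here is a minimal standalone sketch of what the new `transformation` pipeline and the updated `hsdi()` call do. It assumes a jiwer release that still provides `SentencesToListOfWords`; the sample sentences and the final WER arithmetic are illustrative additions, not part of the patch.

```python
import jiwer

# Same normalization chain as the one added to WERTracker.__init__
transformation = jiwer.Compose(
    [
        jiwer.ToLowerCase(),
        jiwer.RemovePunctuation(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.SentencesToListOfWords(),
        jiwer.RemoveEmptyStrings(),
    ]
)

# Illustrative sentences (not from the PR)
truth = "The birch canoe slid on the smooth planks."
hypothesis = "the birch canoe slid on smooth planks"

# Both sides go through the same transform, as in the updated hsdi()
measures = jiwer.compute_measures(
    truth=truth,
    hypothesis=hypothesis,
    truth_transform=transformation,
    hypothesis_transform=transformation,
)
hits, subs, dels, ins = (
    measures[k] for k in ("hits", "substitutions", "deletions", "insertions")
)

# Standard WER definition from the four counts kept by hsdi()
wer = (subs + dels + ins) / (subs + dels + hits)
print(wer)  # expected: 0.125 (one deleted word out of eight reference words)
```

Because both the reference and the hypothesis go through the same `Compose` chain, casing and punctuation differences no longer count as errors, which is the point of threading `transformation` through `hsdi()`.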
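For reference, each element appended to `self.transcriptions` should end up shaped roughly as below (two-speaker case; the ID and text values are placeholders, only the field names come from the diff above).

```python
# Shape of one entry of tracker.transcriptions after a two-speaker batch
example_entry = {
    "mixture_txt": "asr hypothesis decoded from the mixture",
    "truth": {
        "utt_id_0": "utt1", "txt_0": "ground truth for speaker 0",
        "utt_id_1": "utt2", "txt_1": "ground truth for speaker 1",
    },
    "clean": {
        "utt_id_0": "utt1", "txt_0": "asr hypothesis on clean source 0",
        "utt_id_1": "utt2", "txt_1": "asr hypothesis on clean source 1",
    },
    "estimates": {
        "utt_id_0": "utt1", "txt_0": "asr hypothesis on estimated source 0",
        "utt_id_1": "utt2", "txt_1": "asr hypothesis on estimated source 1",
    },
}
```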