From f908ae98da5e2d68e27b02bcbc479a58efb3a170 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Thu, 21 Jan 2021 16:43:34 +0100 Subject: [PATCH 1/7] add all_transciptions as json --- asteroid/metrics.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index af42b0ca7..3018382db 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -146,6 +146,8 @@ def __init__(self, model_name, trans_df): self.input_txt_list = [] self.clean_txt_list = [] self.output_txt_list = [] + self.transcription = [] + self.true_txt_list = [] self.sample_rate = int(d.data_frame[d.data_frame["name"] == model_name]["fs"]) self.trans_df = trans_df self.trans_dic = self._df_to_dict(trans_df) @@ -172,25 +174,42 @@ def __call__( local_est_counter = Counter() # Count the mixture output for each speaker txt = self.predict_hypothesis(mix) + + # Gather true transcription and IDs + trans_dict = dict(mixture_txt={},clean={},estimates={},truth={}) + + # Mixture + trans_dict["mixture_txt"] = txt + # Truth + for i,tmp_id in enumerate(wav_id): + trans_dict["truth"][f"ID_{i}"] = tmp_id + trans_dict["truth"][f"Txt_{i}"] = self.trans_dic[tmp_id] + self.true_txt_list.append(dict(utt_id=tmp_id, text=self.trans_dic[tmp_id])) + # Mixture for tmp_id in wav_id: out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt)) self.mix_counter += out_count local_mix_counter += out_count self.input_txt_list.append(dict(utt_id=tmp_id, text=txt)) # Average WER for the clean pair - for wav, tmp_id in zip(clean, wav_id): + for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)): txt = self.predict_hypothesis(wav) out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt)) self.clean_counter += out_count local_clean_counter += out_count self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt)) + trans_dict["clean"][f"ID_{i}"] = tmp_id + trans_dict["clean"][f"Txt_{i}"] = txt # Average WER for the estimate pair - for est, tmp_id in zip(estimate, wav_id): + for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)): txt = self.predict_hypothesis(est) out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt)) self.est_counter += out_count local_est_counter += out_count self.output_txt_list.append(dict(utt_id=tmp_id, text=txt)) + trans_dict["estimates"][f"ID_{i}"] = tmp_id + trans_dict["estimates"][f"Txt_{i}"] = txt + self.transcription.append(trans_dict) return dict( input_wer=self.wer_from_hsdi(**dict(local_mix_counter)), clean_wer=self.wer_from_hsdi(**dict(local_clean_counter)), @@ -204,10 +223,16 @@ def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0): @staticmethod def hsdi(truth, hypothesis): + import jiwer from jiwer import compute_measures - + transformation = jiwer.Compose([ + jiwer.ToLowerCase(), + jiwer.RemovePunctuation() + ]) keep = ["hits", "substitutions", "deletions", "insertions"] - out = compute_measures(truth=truth, hypothesis=hypothesis).items() + out = compute_measures(truth=truth, hypothesis=hypothesis, + truth_transform=transformation, + hypothesis_transform=transformation).items() return {k: v for k, v in out if k in keep} def predict_hypothesis(self, wav): @@ -258,5 +283,8 @@ def final_df(self): ) return df + def all_transcriptions(self): + return dict(transcriptions=self.transcription) + def final_report_as_markdown(self): return self.final_df().to_markdown(index=False, tablefmt="github") From aed14a9dacad32086813b1e0ee5485f55b3ffcf0 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Thu, 21 Jan 2021 17:09:25 +0100 Subject: [PATCH 2/7] update doc --- asteroid/metrics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index 3018382db..7bf1b1647 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -175,12 +175,11 @@ def __call__( # Count the mixture output for each speaker txt = self.predict_hypothesis(mix) - # Gather true transcription and IDs + # Dict to gather transcriptions and IDs trans_dict = dict(mixture_txt={},clean={},estimates={},truth={}) - - # Mixture + # Get mixture transcription trans_dict["mixture_txt"] = txt - # Truth + # Get ground truth transcription and IDs for i,tmp_id in enumerate(wav_id): trans_dict["truth"][f"ID_{i}"] = tmp_id trans_dict["truth"][f"Txt_{i}"] = self.trans_dic[tmp_id] From 0b5036d5cc540a5c771b1cf1692eb639240c7718 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Thu, 21 Jan 2021 17:17:53 +0100 Subject: [PATCH 3/7] black reformated --- asteroid/metrics.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index 7bf1b1647..9276098c1 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -176,11 +176,11 @@ def __call__( txt = self.predict_hypothesis(mix) # Dict to gather transcriptions and IDs - trans_dict = dict(mixture_txt={},clean={},estimates={},truth={}) + trans_dict = dict(mixture_txt={}, clean={}, estimates={}, truth={}) # Get mixture transcription trans_dict["mixture_txt"] = txt # Get ground truth transcription and IDs - for i,tmp_id in enumerate(wav_id): + for i, tmp_id in enumerate(wav_id): trans_dict["truth"][f"ID_{i}"] = tmp_id trans_dict["truth"][f"Txt_{i}"] = self.trans_dic[tmp_id] self.true_txt_list.append(dict(utt_id=tmp_id, text=self.trans_dic[tmp_id])) @@ -224,14 +224,15 @@ def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0): def hsdi(truth, hypothesis): import jiwer from jiwer import compute_measures - transformation = jiwer.Compose([ - jiwer.ToLowerCase(), - jiwer.RemovePunctuation() - ]) + + transformation = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation()]) keep = ["hits", "substitutions", "deletions", "insertions"] - out = compute_measures(truth=truth, hypothesis=hypothesis, - truth_transform=transformation, - hypothesis_transform=transformation).items() + out = compute_measures( + truth=truth, + hypothesis=hypothesis, + truth_transform=transformation, + hypothesis_transform=transformation, + ).items() return {k: v for k, v in out if k in keep} def predict_hypothesis(self, wav): From 46a36ffec7ff8a650b7fb2cf1c1813d4c0f7d205 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Fri, 22 Jan 2021 10:29:30 +0100 Subject: [PATCH 4/7] move transformation in init --- asteroid/metrics.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index 9276098c1..08f45660b 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -139,6 +139,7 @@ def __init__(self, model_name, trans_df): from espnet2.bin.asr_inference import Speech2Text from espnet_model_zoo.downloader import ModelDownloader + import jiwer self.model_name = model_name d = ModelDownloader() @@ -146,7 +147,7 @@ def __init__(self, model_name, trans_df): self.input_txt_list = [] self.clean_txt_list = [] self.output_txt_list = [] - self.transcription = [] + self.transcriptions = [] self.true_txt_list = [] self.sample_rate = int(d.data_frame[d.data_frame["name"] == model_name]["fs"]) self.trans_df = trans_df @@ -154,6 +155,7 @@ def __init__(self, model_name, trans_df): self.mix_counter = Counter() self.clean_counter = Counter() self.est_counter = Counter() + self.transformation = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation()]) def __call__( self, @@ -186,14 +188,22 @@ def __call__( self.true_txt_list.append(dict(utt_id=tmp_id, text=self.trans_dic[tmp_id])) # Mixture for tmp_id in wav_id: - out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt)) + out_count = Counter( + self.hsdi( + truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation + ) + ) self.mix_counter += out_count local_mix_counter += out_count self.input_txt_list.append(dict(utt_id=tmp_id, text=txt)) # Average WER for the clean pair for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)): txt = self.predict_hypothesis(wav) - out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt)) + out_count = Counter( + self.hsdi( + truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation + ) + ) self.clean_counter += out_count local_clean_counter += out_count self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt)) @@ -202,13 +212,17 @@ def __call__( # Average WER for the estimate pair for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)): txt = self.predict_hypothesis(est) - out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt)) + out_count = Counter( + self.hsdi( + truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation + ) + ) self.est_counter += out_count local_est_counter += out_count self.output_txt_list.append(dict(utt_id=tmp_id, text=txt)) trans_dict["estimates"][f"ID_{i}"] = tmp_id trans_dict["estimates"][f"Txt_{i}"] = txt - self.transcription.append(trans_dict) + self.transcriptions.append(trans_dict) return dict( input_wer=self.wer_from_hsdi(**dict(local_mix_counter)), clean_wer=self.wer_from_hsdi(**dict(local_clean_counter)), @@ -221,11 +235,9 @@ def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0): return wer @staticmethod - def hsdi(truth, hypothesis): - import jiwer + def hsdi(truth, hypothesis, transformation): from jiwer import compute_measures - transformation = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation()]) keep = ["hits", "substitutions", "deletions", "insertions"] out = compute_measures( truth=truth, @@ -284,7 +296,7 @@ def final_df(self): return df def all_transcriptions(self): - return dict(transcriptions=self.transcription) + return dict(transcriptions=self.transcriptions) def final_report_as_markdown(self): return self.final_df().to_markdown(index=False, tablefmt="github") From 7de1eee8520315aae173cb075ab6c19ea0dc8374 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Thu, 28 Jan 2021 14:51:17 +0100 Subject: [PATCH 5/7] lowercase --- asteroid/metrics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index 08f45660b..96f785db8 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -183,8 +183,8 @@ def __call__( trans_dict["mixture_txt"] = txt # Get ground truth transcription and IDs for i, tmp_id in enumerate(wav_id): - trans_dict["truth"][f"ID_{i}"] = tmp_id - trans_dict["truth"][f"Txt_{i}"] = self.trans_dic[tmp_id] + trans_dict["truth"][f"utt_id_{i}"] = tmp_id + trans_dict["truth"][f"txt_{i}"] = self.trans_dic[tmp_id] self.true_txt_list.append(dict(utt_id=tmp_id, text=self.trans_dic[tmp_id])) # Mixture for tmp_id in wav_id: @@ -207,8 +207,8 @@ def __call__( self.clean_counter += out_count local_clean_counter += out_count self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt)) - trans_dict["clean"][f"ID_{i}"] = tmp_id - trans_dict["clean"][f"Txt_{i}"] = txt + trans_dict["clean"][f"utt_id_{i}"] = tmp_id + trans_dict["clean"][f"txt_{i}"] = txt # Average WER for the estimate pair for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)): txt = self.predict_hypothesis(est) @@ -220,8 +220,8 @@ def __call__( self.est_counter += out_count local_est_counter += out_count self.output_txt_list.append(dict(utt_id=tmp_id, text=txt)) - trans_dict["estimates"][f"ID_{i}"] = tmp_id - trans_dict["estimates"][f"Txt_{i}"] = txt + trans_dict["estimates"][f"utt_id_{i}"] = tmp_id + trans_dict["estimates"][f"txt_{i}"] = txt self.transcriptions.append(trans_dict) return dict( input_wer=self.wer_from_hsdi(**dict(local_mix_counter)), From 6cd1c9004f18ac571231d248f4aef2f8d2c06c43 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Fri, 29 Jan 2021 15:50:29 +0100 Subject: [PATCH 6/7] add transformations remove all_transcriptions method --- asteroid/metrics.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index 96f785db8..fb37c4b01 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -155,7 +155,14 @@ def __init__(self, model_name, trans_df): self.mix_counter = Counter() self.clean_counter = Counter() self.est_counter = Counter() - self.transformation = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation()]) + self.transformation = jiwer.Compose( + [jiwer.ToLowerCase(), + jiwer.RemovePunctuation(), + jiwer.RemoveMultipleSpaces(), + jiwer.Strip(), + jiwer.SentencesToListOfWords(), + jiwer.RemoveEmptyStrings() + ]) def __call__( self, @@ -295,8 +302,5 @@ def final_df(self): ) return df - def all_transcriptions(self): - return dict(transcriptions=self.transcriptions) - def final_report_as_markdown(self): return self.final_df().to_markdown(index=False, tablefmt="github") From f7d9c6bd6cf0b665bf2855244d33e271fa01e8b3 Mon Sep 17 00:00:00 2001 From: JorisCos Date: Tue, 2 Feb 2021 15:43:48 +0100 Subject: [PATCH 7/7] /lint --- asteroid/metrics.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/asteroid/metrics.py b/asteroid/metrics.py index fb37c4b01..28bc33aa4 100644 --- a/asteroid/metrics.py +++ b/asteroid/metrics.py @@ -156,13 +156,15 @@ def __init__(self, model_name, trans_df): self.clean_counter = Counter() self.est_counter = Counter() self.transformation = jiwer.Compose( - [jiwer.ToLowerCase(), - jiwer.RemovePunctuation(), - jiwer.RemoveMultipleSpaces(), - jiwer.Strip(), - jiwer.SentencesToListOfWords(), - jiwer.RemoveEmptyStrings() - ]) + [ + jiwer.ToLowerCase(), + jiwer.RemovePunctuation(), + jiwer.RemoveMultipleSpaces(), + jiwer.Strip(), + jiwer.SentencesToListOfWords(), + jiwer.RemoveEmptyStrings(), + ] + ) def __call__( self,