Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WER tracker #414

Merged
merged 7 commits into from
Feb 2, 2021
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions asteroid/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,23 @@ def __init__(self, model_name, trans_df):

from espnet2.bin.asr_inference import Speech2Text
from espnet_model_zoo.downloader import ModelDownloader
import jiwer

self.model_name = model_name
d = ModelDownloader()
self.asr_model = Speech2Text(**d.download_and_unpack(model_name))
self.input_txt_list = []
self.clean_txt_list = []
self.output_txt_list = []
self.transcriptions = []
self.true_txt_list = []
self.sample_rate = int(d.data_frame[d.data_frame["name"] == model_name]["fs"])
self.trans_df = trans_df
self.trans_dic = self._df_to_dict(trans_df)
self.mix_counter = Counter()
self.clean_counter = Counter()
self.est_counter = Counter()
self.transformation = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation()])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this transformation enough?
The default is:

[<jiwer.transforms.RemoveMultipleSpaces at 0x7fbc79a75df0>,
 <jiwer.transforms.Strip at 0x7fbc79a75e20>,
 <jiwer.transforms.SentencesToListOfWords at 0x7fbc79a75f10>,
 <jiwer.transforms.RemoveEmptyStrings at 0x7fbc7aa17bb0>]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I tested on CHiME-4, these were the two that made a difference, but you are right: let's add the others. It doesn't cost much anyway.


def __call__(
self,
Expand All @@ -172,25 +176,53 @@ def __call__(
local_est_counter = Counter()
# Count the mixture output for each speaker
txt = self.predict_hypothesis(mix)

# Dict to gather transcriptions and IDs
trans_dict = dict(mixture_txt={}, clean={}, estimates={}, truth={})
# Get mixture transcription
trans_dict["mixture_txt"] = txt
# Get ground truth transcription and IDs
for i, tmp_id in enumerate(wav_id):
trans_dict["truth"][f"utt_id_{i}"] = tmp_id
trans_dict["truth"][f"txt_{i}"] = self.trans_dic[tmp_id]
self.true_txt_list.append(dict(utt_id=tmp_id, text=self.trans_dic[tmp_id]))
# Mixture
for tmp_id in wav_id:
out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt))
out_count = Counter(
self.hsdi(
truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
)
)
self.mix_counter += out_count
local_mix_counter += out_count
self.input_txt_list.append(dict(utt_id=tmp_id, text=txt))
# Average WER for the clean pair
for wav, tmp_id in zip(clean, wav_id):
for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)):
mpariente marked this conversation as resolved.
Show resolved Hide resolved
txt = self.predict_hypothesis(wav)
out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt))
out_count = Counter(
self.hsdi(
truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
)
)
self.clean_counter += out_count
local_clean_counter += out_count
self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt))
trans_dict["clean"][f"utt_id_{i}"] = tmp_id
trans_dict["clean"][f"txt_{i}"] = txt
# Average WER for the estimate pair
for est, tmp_id in zip(estimate, wav_id):
for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)):
txt = self.predict_hypothesis(est)
out_count = Counter(self.hsdi(truth=self.trans_dic[tmp_id], hypothesis=txt))
out_count = Counter(
self.hsdi(
truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
)
)
self.est_counter += out_count
local_est_counter += out_count
self.output_txt_list.append(dict(utt_id=tmp_id, text=txt))
trans_dict["estimates"][f"utt_id_{i}"] = tmp_id
trans_dict["estimates"][f"txt_{i}"] = txt
self.transcriptions.append(trans_dict)
return dict(
input_wer=self.wer_from_hsdi(**dict(local_mix_counter)),
clean_wer=self.wer_from_hsdi(**dict(local_clean_counter)),
Expand All @@ -203,11 +235,16 @@ def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0):
return wer

@staticmethod
def hsdi(truth, hypothesis):
def hsdi(truth, hypothesis, transformation):
from jiwer import compute_measures

keep = ["hits", "substitutions", "deletions", "insertions"]
out = compute_measures(truth=truth, hypothesis=hypothesis).items()
out = compute_measures(
truth=truth,
hypothesis=hypothesis,
truth_transform=transformation,
hypothesis_transform=transformation,
).items()
return {k: v for k, v in out if k in keep}

def predict_hypothesis(self, wav):
Expand Down Expand Up @@ -258,5 +295,8 @@ def final_df(self):
)
return df

def all_transcriptions(self):
return dict(transcriptions=self.transcriptions)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really see the point of returning a dict with a single field that just wraps the list.
I'd remove this method entirely.


def final_report_as_markdown(self):
return self.final_df().to_markdown(index=False, tablefmt="github")