"""
Author: Joseph(Junzhe) Zhu, 2021/5. Email: [email protected] / [email protected]
For the original code for the paper[1], please refer to https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN
Demo Page: https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/
Multi-Decoder DPRNN is a method for source separation when the number of speakers is unknown.
Our contribution is using multiple output heads, with each head modelling a distinct number of source outputs.
In addition, we design a selector network which determines which output head to use, i.e. estimates the number of sources.
The "DPRNN" part of the architecture is orthogonal to our contribution, and can be replaced with any other separator, e.g. Conv/LSTM-TasNet.
References:
[1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation",
Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022
"""
import argparse
import glob
import json
import os
import random
from pprint import pprint

import pandas as pd
import requests
import soundfile as sf
import torch
import yaml
from tqdm import tqdm

from asteroid import torch_utils
from asteroid.utils import tensors_to_device

from metrics import Penalized_PIT_Wrapper, pairwise_neg_sisdr_loss
from model import load_best_model, make_model_and_optimizer
from wsj0_mix_variable import Wsj0mixVariable

parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    default="sep_count",
    type=str,
    help="Task to evaluate; should match the training task (`sep_count` for this recipe)",
)
parser.add_argument(
    "--output_dir", type=str, default="output", help="Output folder for inference results"
)
parser.add_argument(
    "--test_dir",
    type=str,
    default="",
    help="Test directory including the WSJ0-mix (variable #speakers) test set json files",
)
parser.add_argument(
    "--use_gpu",
    type=int,
    default=0,
    help="Whether to use the GPU for model execution. Enter 1 or 0",
)
parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
parser.add_argument(
    "--n_save_ex", type=int, default=50, help="Number of audio examples to save, -1 means all"
)
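
# Example invocation (the dataset path is purely illustrative): `--test_dir`
# must contain a `{}` placeholder, which is filled in with each source count.
#   python eval.py --exp_dir exp/tmp --use_gpu 1 --test_dir data/{}speakers/json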


def main(conf):
    best_model_path = os.path.join(conf["exp_dir"], "checkpoints", "best-model.ckpt")
    if not os.path.exists(best_model_path):
        # Build the model from the best training checkpoint, then cache it in the
        # same {"state_dict": ...} layout the `else` branch expects on later runs.
        model = load_best_model(
            conf["train_conf"], conf["exp_dir"], sample_rate=conf["sample_rate"]
        )
        torch.save({"state_dict": model.state_dict()}, best_model_path)
    else:
        model, _ = make_model_and_optimizer(conf["train_conf"], sample_rate=conf["sample_rate"])
        checkpoint = torch.load(best_model_path, map_location="cpu")
        model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
    # Inference mode: disable dropout etc. for both loading paths.
    model.eval()
    # Handle device placement
    if conf["use_gpu"] and torch.cuda.is_available():
        model.cuda()
    model_device = next(model.parameters()).device
    test_dirs = [
        conf["test_dir"].format(n_src) for n_src in conf["train_conf"]["masknet"]["n_srcs"]
    ]
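    # For example (values are illustrative): with n_srcs=[2, 3, 4, 5] and
    # --test_dir "data/{}speakers/json", test_dirs expands to
    # ["data/2speakers/json", "data/3speakers/json", ...].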
    # Evaluate on the WSJ0-mix (variable #speakers) test set.
    if not conf["test_dir"]:
        raise ValueError("--test_dir is required to build the test set")
    test_set = Wsj0mixVariable(
        json_dirs=test_dirs,
        n_srcs=conf["train_conf"]["masknet"]["n_srcs"],
        sample_rate=conf["train_conf"]["data"]["sample_rate"],
        seglen=None,
        minlen=None,
    )
    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = [
                torch.Tensor(x) for x in tensors_to_device(test_set[idx], device=model_device)
            ]
            est_sources = model.separate(mix[None])
            p_si_snr = Penalized_PIT_Wrapper(pairwise_neg_sisdr_loss)(est_sources, sources)
            utt_metrics = {
                "P-Si-SNR": p_si_snr.item(),
                "counting_accuracy": float(sources.size(0) == est_sources.size(0)),
            }
            utt_metrics["mix_path"] = test_set.data[idx][0]
            series_list.append(pd.Series(utt_metrics))
            # Save some examples in a folder: wav files and metrics as text.
            if idx in save_idx:
                mix_np = mix[None].cpu().numpy()
                sources_np = sources.cpu().numpy()
                est_sources_np = est_sources.cpu().numpy()
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np[0], conf["sample_rate"])
                # Loop over the sources and estimates.
                for src_idx, src in enumerate(sources_np):
                    sf.write(
                        local_save_dir + "s{}.wav".format(src_idx + 1), src, conf["sample_rate"]
                    )
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)
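    # Note: "P-Si-SNR" is the penalized SI-SNR from [1], computed on the best
    # PIT permutation with a penalty when the estimated and true source counts
    # differ; counting_accuracy averages to the fraction of utterances whose
    # source count was estimated exactly.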
    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))
    # Print and save summary metrics.
    final_results = {}
    for metric_name in ["P-Si-SNR", "counting_accuracy"]:
        final_results[metric_name] = all_metrics_df[metric_name].mean()
    print("Overall metrics:")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)


if __name__ == "__main__":
    args = parser.parse_args()
    arg_dic = dict(vars(args))
    # Create exp and checkpoints folders if none exist.
    os.makedirs(os.path.join(args.exp_dir, "checkpoints"), exist_ok=True)
    # Download a pretrained checkpoint if none exists.
    if len(glob.glob(os.path.join(args.exp_dir, "checkpoints", "*.ckpt"))) == 0:
        r = requests.get(
            "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt"
        )
        r.raise_for_status()
        with open(os.path.join(args.exp_dir, "checkpoints", "best-model.ckpt"), "wb") as handle:
            handle.write(r.content)
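        # Design note: for very large files, requests' streaming mode
        # (requests.get(url, stream=True) and iterating over
        # r.iter_content(chunk_size=...)) would avoid holding the whole
        # download in memory; the simple form above is fine for one checkpoint.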
    # If conf.yml doesn't exist in the experiment folder, fall back to the default one.
    conf_path = os.path.join(args.exp_dir, "conf.yml")
    if not os.path.exists(conf_path):
        conf_path = "local/conf.yml"
    # Load the training config.
    with open(conf_path) as f:
        train_conf = yaml.safe_load(f)
    arg_dic["sample_rate"] = train_conf["data"]["sample_rate"]
    arg_dic["train_conf"] = train_conf
    if args.task != arg_dic["train_conf"]["data"]["task"]:
        print(
            "Warning: the task used for testing is different from "
            "the one used for training; make sure this is what you want."
        )
    main(arg_dic)