diff --git a/README.md b/README.md
index 777ca1cde..27db32b54 100644
--- a/README.md
+++ b/README.md
@@ -112,6 +112,7 @@ More information in [egs/README.md](./egs).
 * [x] [DCCRNet](./asteroid/models/dccrnet.py) ([Hu et al.](https://arxiv.org/abs/2008.00264))
 * [x] [DCUNet](./asteroid/models/dcunet.py) ([Choi et al.](https://arxiv.org/abs/1903.03107))
 * [x] [CrossNet-Open-Unmix](./asteroid/models/x_umx.py) ([Sawata et al.](https://arxiv.org/abs/2010.04228))
+* [x] [Multi-Decoder DPRNN](./egs/wsj0-mix-var/Multi-Decoder-DPRNN) ([Zhu et al.](http://www.isle.illinois.edu/speech_web_lg/pubs/2021/zhu2021multi.pdf))
 * [ ] Open-Unmix (coming) ([Stöter et al.](https://sigsep.github.io/open-unmix/))
 * [ ] Wavesplit (coming) ([Zeghidour et al.](https://arxiv.org/abs/2002.08933))
diff --git a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/.vscode/settings.json b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/.vscode/settings.json
new file mode 100644
index 000000000..5d7d7306b
--- /dev/null
+++ b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "ros.distro": "noetic"
+}
\ No newline at end of file
diff --git a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/README.md b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/README.md
index cadec95ff..1eed1bce8 100644
--- a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/README.md
+++ b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/README.md
@@ -1,7 +1,29 @@
-## This is the repository for Multi-Decoder DPRNN, published at ICASSP 2021.
-Summary: Multi-Decoder DPRNN deals with source separation with variable number of speakers. It has 98.5% accuracy in speaker number classification, which is much higher than all previous SOTA methods. It also has similar SNR as models trained separately on different number of speakers, but its runtime is constant and independent of the number of speakers.
+## This is the official repository for Multi-Decoder DPRNN, published at ICASSP 2021.
+**Summary**: Multi-Decoder DPRNN deals with source separation with a variable number of speakers. It has 98.5% accuracy in speaker number classification, which is much higher than all previous SOTA methods. It also has similar SNR to models trained separately on different numbers of speakers, but **its runtime is constant and independent of the number of speakers.**
-paper link: https://arxiv.org/abs/2011.12022
+**Abstract**: We propose an end-to-end trainable approach to single-channel speech separation with an unknown number of speakers, **training only a single model for an arbitrary number of speakers**. Our approach extends the MulCat source separation backbone with additional output heads: a count-head to infer the number of speakers, and decoder-heads for reconstructing the original signals. Beyond the model, we also propose a metric for evaluating source separation with a variable number of speakers. Specifically, we clear up how to evaluate separation quality when the ground truth has more or fewer speakers than the ones predicted by the model. We evaluate our approach on the WSJ0-mix datasets, with mixtures of up to five speakers. **We demonstrate that our approach outperforms the state of the art in counting the number of speakers and remains competitive in the quality of the reconstructed signals.**
+
+Paper arXiv link: https://arxiv.org/abs/2011.12022
+
+## Project Page & Demo
+Project page & example output can be found [here](https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/)
+
+## Getting Started
+Install asteroid by running ```pip install -e .``` in the asteroid directory
+To install the requirements, run ```pip install -r requirements.txt```
+
+To run a pre-trained model on your own .wav mixture files, run ```python eval.py --wav_file {file_name.wav} --use_gpu {1/0}```. The script should automatically download a pre-trained model (link below).
+
+You can use wildcard (glob) patterns for file names. For example, you can run ```python eval.py --wav_file "local/*.wav" --use_gpu 0```
+
+The default output directory is ./output, but you can override it with the ```--output_dir``` option
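+
+For example, given a hypothetical mixture file ```local/mix.wav```, a minimal run looks like this (the output names follow the ```_spkr{i}``` convention used by eval.py):
+```bash
+python eval.py --wav_file local/mix.wav --use_gpu 0
+# one wav per estimated speaker is written to the output directory,
+# e.g. output/mix_spkr0.wav, output/mix_spkr1.wav, ...
+```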
+
+If you want to download an alternative pre-trained model, create a folder and save the pretrained model as ```{folder_name}/checkpoints/best-model.ckpt```, then run ```python eval.py --wav_file {file_name.wav} --use_gpu {1/0} --exp_dir {folder_name}```
+
+## Train your own model
+To train the model, edit the file paths in run.sh and execute ```./run.sh --stage 0```, then follow the instructions to generate the dataset and train the model.
+
+After training the model, execute ```./run.sh --stage 4``` to evaluate it. Some examples will be saved in exp/tmp_uuid/examples
 ## Kindly cite this paper
 ```
@@ -16,16 +38,18 @@ paper link: https://arxiv.org/abs/2011.12022
   doi={10.1109/ICASSP39728.2021.9414205}}
 ```
-
-
+## Resources
 Pretrained mini model and config can be found at: https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN \
-Project page & example output can be found at: https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/
-#### This is the refactored version of the code for ease of production use. If you want to reproduce the paper results, original experiment code & config can be found at https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN
-Original Paper Results(Confusion Matrix)
+This is the refactored version of the code, with some hyperparameter changes. If you want to reproduce the paper results, the original experiment code & config can be found at https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN
+
+**Original Paper Results** (Confusion Matrix)
 2 | 3 | 4 |5
 -----|------|------|--
 2998 | 17 | 1 |0
 2 | 2977 | 27 |0
 0 | 6 | 2928 |80
 0 | 0 | 44 |2920
+
+## Contact the author
+If you have any questions, you can reach me at josefzhu@stanford.edu
diff --git a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/eval.py b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/eval.py
index ee031f020..c92a33a5c 100644
--- a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/eval.py
+++ b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/eval.py
@@ -23,11 +23,13 @@ from pprint import pprint
 
 from asteroid.utils import tensors_to_device
-from asteroid.metrics import get_metrics
+from asteroid import torch_utils
 from model import load_best_model, make_model_and_optimizer
-from wsj0_mix_variable import Wsj0mixVariable, _collate_fn
-
+from wsj0_mix_variable import Wsj0mixVariable
+import glob
+import requests
+import librosa
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -37,10 +39,25 @@
     help="One of `enh_single`, `enh_both`, " "`sep_clean` or `sep_noisy`",
 )
 parser.add_argument(
-    "--test_dir", type=str, required=True, help="Test directory including the json files"
+    "--wav_file",
+    type=str,
+    default="",
+    help="Path to the wav file to run model inference on. Can be a glob pattern such as {folder_name}/*.wav",
 )
 parser.add_argument(
-    "--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution"
+    "--output_dir", type=str, default="output", help="Output folder for inference results"
+)
+parser.add_argument(
+    "--test_dir",
+    type=str,
+    default="",
+    help="Test directory including the WSJ0-mix (variable #speakers) test set json files",
+)
+parser.add_argument(
+    "--use_gpu",
+    type=int,
+    default=0,
+    help="Whether to use the GPU for model execution. Enter 1 or 0",
 )
 parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
 parser.add_argument(
@@ -49,7 +66,7 @@ def main(conf):
-    best_model_path = os.path.join(conf["exp_dir"], "best_model.pth")
+    best_model_path = os.path.join(conf["exp_dir"], "checkpoints", "best-model.ckpt")
     if not os.path.exists(best_model_path):  # make pth from checkpoint
         model = load_best_model(
@@ -59,84 +76,119 @@ def main(conf):
     else:
         model, _ = make_model_and_optimizer(conf["train_conf"], sample_rate=conf["sample_rate"])
     model.eval()
-    model.load_state_dict(torch.load(best_model_path))
+    checkpoint = torch.load(best_model_path, map_location="cpu")
+    model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
     # Handle device placement
-    if conf["use_gpu"]:
+    if conf["use_gpu"] and torch.cuda.is_available():
         model.cuda()
     model_device = next(model.parameters()).device
     test_dirs = [
         conf["test_dir"].format(n_src) for n_src in conf["train_conf"]["masknet"]["n_srcs"]
     ]
-    test_set = Wsj0mixVariable(
-        json_dirs=test_dirs,
-        n_srcs=conf["train_conf"]["masknet"]["n_srcs"],
-        sample_rate=conf["train_conf"]["data"]["sample_rate"],
-        seglen=None,
-        minlen=None,
-    )
-
-    # Randomly choose the indexes of sentences to save.
-    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
-    if conf["n_save_ex"] == -1:
-        conf["n_save_ex"] = len(test_set)
-    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
-    series_list = []
-    torch.no_grad().__enter__()
-    for idx in tqdm(range(len(test_set))):
-        # Forward the network on the mixture.
-        mix, sources = [
-            torch.Tensor(x) for x in tensors_to_device(test_set[idx], device=model_device)
-        ]
-        est_sources = model.separate(mix[None])
-        p_si_snr = Penalized_PIT_Wrapper(pairwise_neg_sisdr_loss)(est_sources, sources)
-        utt_metrics = {
-            "P-Si-SNR": p_si_snr.item(),
-            "counting_accuracy": float(sources.size(0) == est_sources.size(0)),
-        }
-        utt_metrics["mix_path"] = test_set.data[idx][0]
-        series_list.append(pd.Series(utt_metrics))
-
-        # Save some examples in a folder. Wav files and metrics as text.
-        if idx in save_idx:
-            mix_np = mix[None].cpu().data.numpy()
-            sources_np = sources.cpu().data.numpy()
-            est_sources_np = est_sources.cpu().data.numpy()
-            local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
-            os.makedirs(local_save_dir, exist_ok=True)
-            sf.write(local_save_dir + "mixture.wav", mix_np[0], conf["sample_rate"])
-            # Loop over the sources and estimates
-            for src_idx, src in enumerate(sources_np):
-                sf.write(local_save_dir + "s{}.wav".format(src_idx + 1), src, conf["sample_rate"])
-            for src_idx, est_src in enumerate(est_sources_np):
+    if conf["wav_file"]:
+        mix_files = glob.glob(conf["wav_file"])
+        if not os.path.exists(conf["output_dir"]):
+            os.makedirs(conf["output_dir"])
+        for mix_file in mix_files:
+            mix, _ = librosa.load(mix_file, sr=conf["sample_rate"])
+            mix = tensors_to_device(torch.Tensor(mix), device=model_device)
+            est_sources = model.separate(mix[None])
+            est_sources = est_sources.cpu().numpy()
+            for i, est_src in enumerate(est_sources):
                 sf.write(
-                    local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
+                    os.path.join(
+                        conf["output_dir"],
+                        os.path.basename(mix_file).replace(".wav", f"_spkr{i}.wav"),
+                    ),
                     est_src,
                     conf["sample_rate"],
                 )
-            # Write local metrics to the example folder.
-            with open(local_save_dir + "metrics.json", "w") as f:
-                json.dump(utt_metrics, f, indent=0)
-
-    # Save all metrics to the experiment folder.
-    all_metrics_df = pd.DataFrame(series_list)
-    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))
+    # evaluate metrics
+    if conf["test_dir"]:
+        test_set = Wsj0mixVariable(
+            json_dirs=test_dirs,
+            n_srcs=conf["train_conf"]["masknet"]["n_srcs"],
+            sample_rate=conf["train_conf"]["data"]["sample_rate"],
+            seglen=None,
+            minlen=None,
+        )
+
+        # Randomly choose the indexes of sentences to save.
+        ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
+        if conf["n_save_ex"] == -1:
+            conf["n_save_ex"] = len(test_set)
+        save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
+        series_list = []
+        torch.no_grad().__enter__()
+        for idx in tqdm(range(len(test_set))):
+            # Forward the network on the mixture.
+            mix, sources = [
+                torch.Tensor(x) for x in tensors_to_device(test_set[idx], device=model_device)
+            ]
+            est_sources = model.separate(mix[None])
+            p_si_snr = Penalized_PIT_Wrapper(pairwise_neg_sisdr_loss)(est_sources, sources)
+            utt_metrics = {
+                "P-Si-SNR": p_si_snr.item(),
+                "counting_accuracy": float(sources.size(0) == est_sources.size(0)),
+            }
+            utt_metrics["mix_path"] = test_set.data[idx][0]
+            series_list.append(pd.Series(utt_metrics))
+
+            # Save some examples in a folder. Wav files and metrics as text.
+            if idx in save_idx:
+                mix_np = mix[None].cpu().data.numpy()
+                sources_np = sources.cpu().data.numpy()
+                est_sources_np = est_sources.cpu().data.numpy()
+                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
+                os.makedirs(local_save_dir, exist_ok=True)
+                sf.write(local_save_dir + "mixture.wav", mix_np[0], conf["sample_rate"])
+                # Loop over the sources and estimates
+                for src_idx, src in enumerate(sources_np):
+                    sf.write(
+                        local_save_dir + "s{}.wav".format(src_idx + 1), src, conf["sample_rate"]
+                    )
+                for src_idx, est_src in enumerate(est_sources_np):
+                    sf.write(
+                        local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
+                        est_src,
+                        conf["sample_rate"],
+                    )
+                # Write local metrics to the example folder.
+                with open(local_save_dir + "metrics.json", "w") as f:
+                    json.dump(utt_metrics, f, indent=0)
+
+        # Save all metrics to the experiment folder.
+        all_metrics_df = pd.DataFrame(series_list)
+        all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))
-    # Print and save summary metrics
-    final_results = {}
-    for metric_name in ["P-Si-SNR", "counting_accuracy"]:
-        final_results[metric_name] = all_metrics_df[metric_name].mean()
-    print("Overall metrics :")
-    pprint(final_results)
-    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
-        json.dump(final_results, f, indent=0)
+        # Print and save summary metrics
+        final_results = {}
+        for metric_name in ["P-Si-SNR", "counting_accuracy"]:
+            final_results[metric_name] = all_metrics_df[metric_name].mean()
+        print("Overall metrics :")
+        pprint(final_results)
+        with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
+            json.dump(final_results, f, indent=0)
 if __name__ == "__main__":
     args = parser.parse_args()
     arg_dic = dict(vars(args))
-
-    # Load training config
+    # create an exp and checkpoints folder if none exist
+    os.makedirs(os.path.join(args.exp_dir, "checkpoints"), exist_ok=True)
+    # Download a checkpoint if none exists
+    if len(glob.glob(os.path.join(args.exp_dir, "checkpoints", "*.ckpt"))) == 0:
+        r = requests.get(
+            "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt"
+        )
+        with open(os.path.join(args.exp_dir, "checkpoints", "best-model.ckpt"), "wb") as handle:
+            handle.write(r.content)
+    # If no conf.yml exists in exp_dir, fall back to the default local/conf.yml
     conf_path = os.path.join(args.exp_dir, "conf.yml")
+    if not os.path.exists(conf_path):
+        conf_path = "local/conf.yml"
+    # Load training config
     with open(conf_path) as f:
         train_conf = yaml.safe_load(f)
     arg_dic["sample_rate"] = train_conf["data"]["sample_rate"]
diff --git a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/model.py b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/model.py
index ae30e99b2..690998907 100644
--- a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/model.py
+++ b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/model.py
@@ -201,7 +201,6 @@ def forward_wav(self, wav, slice_size=32000, *args, **kwargs):
         output_cat[:, :slice_size] = output_wavs[0]
         start = slice_stride
         for i in range(1, slice_nb):
-            end = start + slice_size
             overlap_prev = output_cat[:, start : start + slice_stride].unsqueeze(0)
             overlap_next = output_wavs[i : i + 1, :, :slice_stride]
             pw_losses = pairwise_neg_sisdr(overlap_next, overlap_prev)
diff --git a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/requirements.txt b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/requirements.txt
new file mode 100644
index 000000000..038dfbc14
--- /dev/null
+++ b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/requirements.txt
@@ -0,0 +1,3 @@
+asteroid
+numpy
+librosa
\ No newline at end of file
diff --git a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/run.sh b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/run.sh
index a563015c7..e53ad0d02 100755
--- a/egs/wsj0-mix-var/Multi-Decoder-DPRNN/run.sh
+++ b/egs/wsj0-mix-var/Multi-Decoder-DPRNN/run.sh
@@ -72,7 +72,7 @@ if [[ $stage -le 1 ]]; then
   echo "Stage 1 : Downloading wsj0-mix mixing scripts"
   # Link + WHAM is ok for 2 source.
   # wget https://www.merl.com/demos/deep-clustering/create-speaker-mixtures.zip -O ./local/
-  wget https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN/raw/master/create-speaker-mixtures-2345.zip -P ./local
+  wget https://github.com/JunzheJosephZhu/MDDPRNN-deprecated/raw/master/create-speaker-mixtures-2345.zip -P ./local
   unzip ./local/create-speaker-mixtures-2345.zip -d ./local/create-speaker-mixtures-2345
   mv ./local/create-speaker-mixtures-2345.zip ./local/create-speaker-mixtures-2345
@@ -106,6 +106,7 @@ if [[ -z ${tag} ]]; then
 fi
 expdir=exp/tmp_${tag}
 mkdir -p $expdir && echo $uuid >> $expdir/run_uuid.txt
+mkdir -p logs
 echo "Results from the following experiment will be stored in $expdir"
 if [[ $stage -le 3 ]]; then
@@ -129,7 +130,6 @@ if [[ $stage -le 3 ]]; then
 fi
 if [[ $stage -le 4 ]]; then
-  expdir=exp/tmp
   echo "Stage 4 : Evaluation"
   echo "If you want to change n_srcs, please change the config file"
   CUDA_VISIBLE_DEVICES=$id $python_path eval.py \