Skip to content

Commit

Permalink
Merge pull request #194 from kuke/improve_tuning_dev
Browse files Browse the repository at this point in the history
Improve params tuning strategy for CTC beam search decoder
  • Loading branch information
kuke authored Sep 25, 2017
2 parents 09ff206 + c348997 commit 5e74b46
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 49 deletions.
24 changes: 13 additions & 11 deletions deep_speech_2/examples/librispeech/run_tune.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,31 @@
pushd ../.. > /dev/null

# grid-search for hyper-parameters in language model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -u tools/tune.py \
--num_samples=100 \
--num_batches=-1 \
--batch_size=256 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_alphas=14 \
--num_betas=20 \
--alpha_from=0.1 \
--alpha_to=0.36 \
--beta_from=0.05 \
--beta_to=1.0 \
--cutoff_prob=0.99 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--tune_manifest='data/librispeech/manifest.dev-clean' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--model_path='checkpoints/libri/params.latest.tar.gz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--error_rate_type='wer' \
--specgram_type='linear'
Expand Down
18 changes: 10 additions & 8 deletions deep_speech_2/examples/tiny/run_tune.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@ pushd ../.. > /dev/null
# grid-search for hyper-parameters in language model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u tools/tune.py \
--num_samples=100 \
--num_batches=1 \
--batch_size=24 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_alphas=14 \
--num_betas=20 \
--alpha_from=0.1 \
--alpha_to=0.36 \
--beta_from=0.05 \
--beta_to=1.0 \
--cutoff_prob=0.99 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
Expand Down
109 changes: 79 additions & 30 deletions deep_speech_2/tools/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,43 @@
from __future__ import division
from __future__ import print_function

import sys
import numpy as np
import argparse
import functools
import paddle.v2 as paddle
import _init_paths
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import wer
from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_samples', int, 100, "# of samples to infer.")
add_arg('num_batches', int, -1, "# of batches tuning on. "
"Default -1, on whole dev set.")
add_arg('batch_size', int, 256, "# of samples per batch.")
add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('num_alphas', int, 14, "# of alpha candidates for tuning.")
add_arg('num_betas', int, 20, "# of beta candidates for tuning.")
add_arg('alpha_from', float, 0.1, "Where alpha starts tuning from.")
add_arg('alpha_to', float, 0.36, "Where alpha ends tuning with.")
add_arg('beta_from', float, 0.05, "Where beta starts tuning from.")
add_arg('beta_to', float, 1.0, "Where beta ends tuning with.")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.")
add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.")
add_arg('num_betas', int, 8, "# of beta candidates for tuning.")
add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.")
add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.")
add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
"bi-directional RNNs. Not for GRU.")
add_arg('tune_manifest', str,
'data/librispeech/manifest.dev',
'data/librispeech/manifest.dev-clean',
"Filepath of manifest to tune.")
add_arg('mean_std_path', str,
'data/librispeech/mean_std.npz',
Expand Down Expand Up @@ -63,7 +67,7 @@


def tune():
"""Tune parameters alpha and beta on one minibatch."""
"""Tune parameters alpha and beta incrementally."""
if not args.num_alphas >= 0:
raise ValueError("num_alphas must be non-negative!")
if not args.num_betas >= 0:
Expand All @@ -77,7 +81,7 @@ def tune():
num_threads=1)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.tune_manifest,
batch_size=args.num_samples,
batch_size=args.batch_size,
sortagrad=False,
shuffle_method=None)
tune_data = batch_reader().next()
Expand All @@ -95,30 +99,75 @@ def tune():
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)

# decoders only accept string encoded in utf-8
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

error_rate_func = cer if args.error_rate_type == 'cer' else wer
# create grid for search
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
params_grid = [(alpha, beta) for alpha in cand_alphas
for beta in cand_betas]

## tune parameters in loop
for alpha, beta in params_grid:
result_transcripts = ds2_model.infer_batch(
infer_data=tune_data,
decoding_method='ctc_beam_search',
beam_alpha=alpha,
beam_beta=beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
vocab_list=data_generator.vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
wer_sum, num_ins = 0.0, 0
for target, result in zip(target_transcripts, result_transcripts):
wer_sum += wer(target, result)
num_ins += 1
print("alpha = %f\tbeta = %f\tWER = %f" %
(alpha, beta, wer_sum / num_ins))
err_sum = [0.0 for i in xrange(len(params_grid))]
err_ave = [0.0 for i in xrange(len(params_grid))]
num_ins, cur_batch = 0, 0
## incremental tuning parameters over multiple batches
for infer_data in batch_reader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break

target_transcripts = [
''.join([data_generator.vocab_list[token] for token in transcript])
for _, transcript in infer_data
]

num_ins += len(target_transcripts)
# grid search
for index, (alpha, beta) in enumerate(params_grid):
result_transcripts = ds2_model.infer_batch(
infer_data=infer_data,
decoding_method='ctc_beam_search',
beam_alpha=alpha,
beam_beta=beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)

for target, result in zip(target_transcripts, result_transcripts):
err_sum[index] += error_rate_func(target, result)
err_ave[index] = err_sum[index] / num_ins
if index % 2 == 0:
sys.stdout.write('.')
sys.stdout.flush()

# output on-line tuning result at the end of current batch
err_ave_min = min(err_ave)
min_index = err_ave.index(err_ave_min)
print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
" min [%s] = %f" %(cur_batch, num_ins,
"%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1],
args.error_rate_type, err_ave_min))
cur_batch += 1

# output WER/CER at every (alpha, beta)
print("\nFinal %s:\n" % args.error_rate_type)
for index in xrange(len(params_grid)):
print("(alpha, beta) = (%s, %s), [%s] = %f"
% ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
args.error_rate_type, err_ave[index]))

err_ave_min = min(err_ave)
min_index = err_ave.index(err_ave_min)
print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
% (args.num_batches, "%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1]))

ds2_model.logger.info("finish inference")


def main():
Expand Down

0 comments on commit 5e74b46

Please sign in to comment.