Commit
Fix Python APIs
csukuangfj committed Aug 12, 2023
1 parent 41cf705 commit 196fcd7
Showing 4 changed files with 161 additions and 24 deletions.
39 changes: 39 additions & 0 deletions python-api-examples/non_streaming_server.py
@@ -71,6 +71,20 @@
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
(5) Use a tdnn model of the yesno recipe from icefall
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
cd sherpa-onnx-tdnn-yesno
git lfs pull --include "*.onnx"
python3 ./python-api-examples/non_streaming_server.py \
--sample-rate=8000 \
--feat-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
----
To use a certificate so that you can use https, please use
@@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
)


def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)


def add_whisper_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--whisper-encoder",
@@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
add_paraformer_model_args(parser)
add_nemo_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser)

parser.add_argument(
@@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
@@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.paraformer)

@@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.nemo_ctc)

@@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)

@@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)

recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
tokens=args.tokens,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
else:
raise ValueError("Please specify at least one model")

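A note on the pattern extended above: create_recognizer keeps the model options mutually exclusive by having every branch assert that all of the other model arguments are empty, so supporting the new tdnn model means adding one such assert to each existing branch plus the new elif args.tdnn_model branch. A hypothetical helper (not part of this commit; the name and signature are illustrative only) that expresses the same check in one place could look like this:

def check_exactly_one_model(options: dict) -> str:
    """Return the name of the single non-empty model option, or raise.

    `options` maps argument names (e.g. "paraformer", "tdnn_model")
    to their command-line values; an empty string means "not given".
    """
    chosen = [name for name, value in options.items() if value]
    if len(chosen) != 1:
        raise ValueError(f"Please specify exactly one model, got: {chosen or 'none'}")
    return chosen[0]

With such a helper, each branch would only keep its own assert_file_exists calls.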
40 changes: 39 additions & 1 deletion python-api-examples/offline-decode-files.py
@@ -8,6 +8,7 @@
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
@@ -20,6 +21,7 @@
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
@@ -56,9 +58,20 @@
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
@@ -159,6 +172,13 @@ def get_args():
help="Path to the model.onnx from NeMo CTC",
)

parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)

parser.add_argument(
"--num-threads",
type=int,
@@ -285,6 +305,7 @@ def main():
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
@@ -311,6 +332,7 @@ def main():
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.paraformer)

@@ -326,6 +348,7 @@ def main():
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.nemo_ctc)

@@ -339,6 +362,7 @@ def main():
debug=args.debug,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)

@@ -347,6 +371,20 @@
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)

recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
tokens=args.tokens,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
23 changes: 10 additions & 13 deletions python-api-examples/web/js/upload.js
@@ -97,20 +97,18 @@ function onFileChange() {
console.log('file.type ' + file.type);
console.log('file.size ' + file.size);

let audioCtx = new AudioContext({sampleRate: 16000});

let reader = new FileReader();
reader.onload = function() {
console.log('reading file!');
let view = new Int16Array(reader.result);
// we assume the input file is a wav file.
// TODO: add some checks here.
let int16_samples = view.subarray(22); // header has 44 bytes == 22 shorts
let num_samples = int16_samples.length;
let float32_samples = new Float32Array(num_samples);
console.log('num_samples ' + num_samples)

for (let i = 0; i < num_samples; ++i) {
float32_samples[i] = int16_samples[i] / 32768.
}
audioCtx.decodeAudioData(reader.result, decodedDone);
};

function decodedDone(decoded) {
let typedArray = new Float32Array(decoded.length);
let float32_samples = decoded.getChannelData(0);
let buf = float32_samples.buffer

// Send 1024 audio samples per request.
//
@@ -119,14 +117,13 @@
// (2) There is a limit on the number of bytes in the payload that can be
// sent by websocket, which is 1MB, I think. We can send a large
// audio file for decoding in this approach.
let buf = float32_samples.buffer
let n = 1024 * 4; // send this number of bytes per request.
console.log('buf length, ' + buf.byteLength);
send_header(buf.byteLength);
for (let start = 0; start < buf.byteLength; start += n) {
socket.send(buf.slice(start, start + n));
}
};
}

reader.readAsArrayBuffer(file);
}
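The upload.js rewrite above drops the hand-rolled WAV parsing (skipping a 44-byte header and dividing Int16 samples by 32768) in favour of AudioContext.decodeAudioData, which decodes the uploaded file and resamples it to the 16 kHz context rate, then streams the resulting Float32 samples over the websocket in 4096-byte payloads (1024 samples of 4 bytes each). For illustration only, a rough Python counterpart of that chunking loop (the helper name is an assumption and the actual websocket send is omitted):

import numpy as np

def iter_float32_chunks(samples, bytes_per_chunk: int = 1024 * 4):
    # Mirror the `for (let start = 0; ...)` loop in upload.js:
    # yield at most 1024 float32 samples (4096 bytes) per payload.
    buf = np.asarray(samples, dtype=np.float32).tobytes()
    for start in range(0, len(buf), bytes_per_chunk):
        yield buf[start : start + bytes_per_chunk]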
83 changes: 73 additions & 10 deletions sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -8,6 +8,7 @@
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineTdnnModelConfig,
OfflineWhisperModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
@@ -37,7 +38,7 @@ def from_transducer(
decoder: str,
joiner: str,
tokens: str,
num_threads: int,
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
@@ -48,7 +49,7 @@
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
@@ -115,7 +116,7 @@ def from_paraformer(
cls,
paraformer: str,
tokens: str,
num_threads: int,
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
@@ -124,9 +125,8 @@
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html>`_
to download pre-trained models.
Args:
tokens:
Expand Down Expand Up @@ -179,7 +179,7 @@ def from_nemo_ctc(
cls,
model: str,
tokens: str,
num_threads: int,
num_threads: int = 1,
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
@@ -188,7 +188,7 @@
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
@@ -244,14 +244,14 @@ def from_whisper(
encoder: str,
decoder: str,
tokens: str,
num_threads: int,
num_threads: int = 1,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html>`_
to download pre-trained models for different kinds of whisper models,
e.g., tiny, tiny.en, base, base.en, etc.
@@ -301,6 +301,69 @@ def from_whisper(
self.config = recognizer_config
return self

@classmethod
def from_tdnn_ctc(
cls,
model: str,
tokens: str,
num_threads: int = 1,
sample_rate: int = 8000,
feature_dim: int = 23,
decoding_method: str = "greedy_search",
debug: bool = False,
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html>`_
to download pre-trained models.
Args:
model:
Path to ``model.onnx``.
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search.
debug:
True to show debug messages.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
tdnn=OfflineTdnnModelConfig(model=model),
tokens=tokens,
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="tdnn",
)

feat_config = OfflineFeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)

recognizer_config = OfflineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
return self

def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
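Taken together, the new OfflineRecognizer.from_tdnn_ctc classmethod is used like the other offline recognizers. A minimal end-to-end sketch, assuming the sherpa-onnx-tdnn-yesno model from the examples above has been downloaded and that the usual offline decoding calls (create_stream, accept_waveform, decode_stream) behave as in offline-decode-files.py:

import wave

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
    model="./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx",
    tokens="./sherpa-onnx-tdnn-yesno/tokens.txt",
    sample_rate=8000,
    feature_dim=23,
)

# Read a 16-bit mono wav file and scale the samples to float32 in [-1, 1].
with wave.open("./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav") as f:
    wave_sample_rate = f.getframerate()
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768

stream = recognizer.create_stream()
stream.accept_waveform(wave_sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)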
