Skip to content

Commit

Permalink
Support paraformer on Android (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Aug 14, 2023
1 parent 6038e2a commit 35526e2
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,14 @@ class MainActivity : AppCompatActivity() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type = 3
val type = 5
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
lmConfig = getOnlineLMConfig(type = type),
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
decodingMethod = "modified_beam_search",
maxActivePaths = 4,
)

model = SherpaOnnx(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,19 @@ data class EndpointConfig(
)

data class OnlineTransducerModelConfig(
var encoder: String,
var decoder: String,
var joiner: String,
var encoder: String = "",
var decoder: String = "",
var joiner: String = "",
)

data class OnlineParaformerModelConfig(
var encoder: String = "",
var decoder: String = "",
)

data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
Expand All @@ -37,8 +47,8 @@ data class FeatureConfig(

data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var lmConfig : OnlineLMConfig,
var modelConfig: OnlineModelConfig,
var lmConfig: OnlineLMConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
Expand Down Expand Up @@ -115,74 +125,102 @@ to add your own. (It should be straightforward to add a new model
by following the code)
@param type
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3 - int8 encoder
4 - float32 encoder
5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
*/
fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
fun getModelConfig(type: Int): OnlineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
return OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
1 -> {
val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
return OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
}

2 -> {
val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
return OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
}

3 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
}

4 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
}

5 -> {
val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
return OnlineModelConfig(
paraformer = OnlineParaformerModelConfig(
encoder = "$modelDir/encoder.int8.onnx",
decoder = "$modelDir/decoder.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "paraformer",
)
}
}
return null;
}
Expand All @@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
*/
fun getOnlineLMConfig(type : Int): OnlineLMConfig {
fun getOnlineLMConfig(type: Int): OnlineLMConfig {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
Expand Down
6 changes: 5 additions & 1 deletion sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
OnlineParaformerDecoderResult r;
s->SetParaformerResult(r);

// the internal model caches are not reset
s->GetStates().clear();
s->GetParaformerEncoderOutCache().clear();
s->GetParaformerAlphaCache().clear();

// s->GetParaformerFeatCache().clear();

// Note: We only update counters. The underlying audio samples
// are not discarded.
Expand Down
48 changes: 37 additions & 11 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SherpaOnnx {
}

void InputFinished() const {
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size());
stream_->InputFinished();
Expand Down Expand Up @@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {

//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);

// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(transducer_config);
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);

fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);

// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);

fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(transducer_config, fid);
ans.model_config.num_threads = env->GetIntField(model_config, fid);

fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(transducer_config, fid);
ans.model_config.debug = env->GetBooleanField(model_config, fid);

fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
Expand Down

0 comments on commit 35526e2

Please sign in to comment.