Skip to content

Commit

Permalink
Support paraformer on iOS (#265)
Browse files Browse the repository at this point in the history
* Fix C API to support streaming paraformer

* Fix Swift API

* Support paraformer in iOS
  • Loading branch information
csukuangfj authored Aug 14, 2023
1 parent 35526e2 commit a8bdb4b
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 85 deletions.
6 changes: 3 additions & 3 deletions c-api-examples/decode-file-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ int32_t main(int32_t argc, char *argv[]) {
config.model_config.tokens = value;
break;
case 'e':
config.model_config.encoder = value;
config.model_config.transducer.encoder = value;
break;
case 'd':
config.model_config.decoder = value;
config.model_config.transducer.decoder = value;
break;
case 'j':
config.model_config.joiner = value;
config.model_config.transducer.joiner = value;
break;
case 'n':
config.model_config.num_threads = atoi(value);
Expand Down
16 changes: 12 additions & 4 deletions ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
objects = {

/* Begin PBXBuildFile section */
C93989AE2A89FE13009AB859 /* sherpa-onnx.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = C984A81B29AA11C500D74C52 /* sherpa-onnx.xcframework */; };
C93989B02A89FE33009AB859 /* onnxruntime.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = C93989AF2A89FE33009AB859 /* onnxruntime.xcframework */; };
C984A7E829A9EEB700D74C52 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = C984A7E729A9EEB700D74C52 /* AppDelegate.swift */; };
C984A7EA29A9EEB700D74C52 /* SceneDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = C984A7E929A9EEB700D74C52 /* SceneDelegate.swift */; };
C984A7F129A9EEB900D74C52 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C984A7F029A9EEB900D74C52 /* Assets.xcassets */; };
Expand All @@ -18,8 +20,6 @@
C984A82829AA196100D74C52 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = C984A82629AA196100D74C52 /* Main.storyboard */; };
C984A82A29AA19AC00D74C52 /* Model.swift in Sources */ = {isa = PBXBuildFile; fileRef = C984A82929AA19AC00D74C52 /* Model.swift */; };
C984A83C29AA430B00D74C52 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = C984A83B29AA430B00D74C52 /* ViewController.swift */; };
C984A83D29AA43D900D74C52 /* sherpa-onnx.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = C984A81B29AA11C500D74C52 /* sherpa-onnx.xcframework */; };
C984A83F29AA43EE00D74C52 /* onnxruntime.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = C984A83E29AA43EE00D74C52 /* onnxruntime.xcframework */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand All @@ -40,6 +40,10 @@
/* End PBXContainerItemProxy section */

/* Begin PBXFileReference section */
C93989AF2A89FE33009AB859 /* onnxruntime.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = onnxruntime.xcframework; path = "../../build-ios/ios-onnxruntime/1.15.1/onnxruntime.xcframework"; sourceTree = "<group>"; };
C93989B12A89FF78009AB859 /* decoder.int8.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; name = decoder.int8.onnx; path = "../../../icefall-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"; sourceTree = "<group>"; };
C93989B22A89FF78009AB859 /* encoder.int8.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; name = encoder.int8.onnx; path = "../../../icefall-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"; sourceTree = "<group>"; };
C93989B32A89FF78009AB859 /* tokens.txt */ = {isa = PBXFileReference; lastKnownFileType = text; name = tokens.txt; path = "../../../icefall-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"; sourceTree = "<group>"; };
C984A7E429A9EEB700D74C52 /* SherpaOnnx.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = SherpaOnnx.app; sourceTree = BUILT_PRODUCTS_DIR; };
C984A7E729A9EEB700D74C52 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
C984A7E929A9EEB700D74C52 /* SceneDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneDelegate.swift; sourceTree = "<group>"; };
Expand All @@ -66,8 +70,8 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
C984A83F29AA43EE00D74C52 /* onnxruntime.xcframework in Frameworks */,
C984A83D29AA43D900D74C52 /* sherpa-onnx.xcframework in Frameworks */,
C93989B02A89FE33009AB859 /* onnxruntime.xcframework in Frameworks */,
C93989AE2A89FE13009AB859 /* sherpa-onnx.xcframework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -146,8 +150,12 @@
C984A81A29AA11C500D74C52 /* Frameworks */ = {
isa = PBXGroup;
children = (
C93989B12A89FF78009AB859 /* decoder.int8.onnx */,
C93989B22A89FF78009AB859 /* encoder.int8.onnx */,
C93989B32A89FF78009AB859 /* tokens.txt */,
C984A82029AA139600D74C52 /* onnxruntime.xcframework */,
C984A83E29AA43EE00D74C52 /* onnxruntime.xcframework */,
C93989AF2A89FE33009AB859 /* onnxruntime.xcframework */,
C984A81B29AA11C500D74C52 /* sherpa-onnx.xcframework */,
);
name = Frameworks;
Expand Down
Binary file not shown.
69 changes: 45 additions & 24 deletions ios-swift/SherpaOnnx/SherpaOnnx/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,70 +15,91 @@ func getResource(_ forResource: String, _ ofType: String) -> String {

/// sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html
func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerModelConfig {
func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder-epoch-99-avg-1", "onnx")
let decoder = getResource("decoder-epoch-99-avg-1", "onnx")
let joiner = getResource("joiner-epoch-99-avg-1", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner,
return sherpaOnnxOnlineModelConfig(
tokens: tokens,
numThreads: 2,
transducer: sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner
),
numThreads: 1,
modelType: "zipformer"
)
}

func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig {
func getZhZipformer20230615() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder-epoch-12-avg-4-chunk-16-left-128", "onnx")
let decoder = getResource("decoder-epoch-12-avg-4-chunk-16-left-128", "onnx")
let joiner = getResource("joiner-epoch-12-avg-4-chunk-16-left-128", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner,
return sherpaOnnxOnlineModelConfig(
tokens: tokens,
numThreads: 2,
transducer: sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner
),
numThreads: 1,
modelType: "zipformer2"
)
}

func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig {
func getZhZipformer20230615Int8() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder-epoch-12-avg-4-chunk-16-left-128.int8", "onnx")
let decoder = getResource("decoder-epoch-12-avg-4-chunk-16-left-128", "onnx")
let joiner = getResource("joiner-epoch-12-avg-4-chunk-16-left-128", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner,
return sherpaOnnxOnlineModelConfig(
tokens: tokens,
numThreads: 2,
transducer: sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner),
numThreads: 1,
modelType: "zipformer2"
)
}

func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig {
func getEnZipformer20230626() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder-epoch-99-avg-1-chunk-16-left-128", "onnx")
let decoder = getResource("decoder-epoch-99-avg-1-chunk-16-left-128", "onnx")
let joiner = getResource("joiner-epoch-99-avg-1-chunk-16-left-128", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner,
return sherpaOnnxOnlineModelConfig(
tokens: tokens,
numThreads: 2,
transducer: sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner),
numThreads: 1,
modelType: "zipformer2"
)
}

func getBilingualStreamingZhEnParaformer() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder.int8", "onnx")
let decoder = getResource("decoder.int8", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineModelConfig(
tokens: tokens,
paraformer: sherpaOnnxOnlineParaformerModelConfig(
encoder: encoder,
decoder: decoder),
numThreads: 1,
modelType: "paraformer"
)
}

/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
/// to add more models if you need
3 changes: 2 additions & 1 deletion ios-swift/SherpaOnnx/SherpaOnnx/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class ViewController: UIViewController {

// let modelConfig = getBilingualStreamZhEnZipformer20230220()
// let modelConfig = getZhZipformer20230615()
let modelConfig = getEnZipformer20230626()
// let modelConfig = getEnZipformer20230626()
let modelConfig = getBilingualStreamingZhEnParaformer()

let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000,
Expand Down
Binary file not shown.
27 changes: 22 additions & 5 deletions ios-swiftui/SherpaOnnx/SherpaOnnx/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,39 @@ func getResource(_ forResource: String, _ ofType: String) -> String {

/// sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html
func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerModelConfig {
func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder-epoch-99-avg-1", "onnx")
let decoder = getResource("decoder-epoch-99-avg-1", "onnx")
let joiner = getResource("joiner-epoch-99-avg-1", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner,
return sherpaOnnxOnlineModelConfig(
tokens: tokens,
transducer: sherpaOnnxOnlineTransducerModelConfig(
encoder: encoder,
decoder: decoder,
joiner: joiner),
numThreads: 2,
modelType: "zipformer"
)
}

// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html
func getBilingualStreamingZhEnParaformer() -> SherpaOnnxOnlineModelConfig {
let encoder = getResource("encoder.int8", "onnx")
let decoder = getResource("decoder.int8", "onnx")
let tokens = getResource("tokens", "txt")

return sherpaOnnxOnlineModelConfig(
tokens: tokens,
paraformer: sherpaOnnxOnlineParaformerModelConfig(
encoder: encoder,
decoder: decoder),
numThreads: 1,
modelType: "paraformer"
)
}

/// Please refer to
/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
/// to add more models if you need
21 changes: 11 additions & 10 deletions ios-swiftui/SherpaOnnx/SherpaOnnx/SherpaOnnxViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ enum Status {
class SherpaOnnxViewModel: ObservableObject {
@Published var status: Status = .stop
@Published var subtitles: String = ""

var sentences: [String] = []

var audioEngine: AVAudioEngine? = nil
var recognizer: SherpaOnnxRecognizer! = nil

var lastSentence: String = ""
let maxSentence: Int = 20

var results: String {
if sentences.isEmpty && lastSentence.isEmpty {
return ""
Expand All @@ -42,24 +42,25 @@ class SherpaOnnxViewModel: ObservableObject {
.joined(separator: "\n") + "\n\(sentences.count): \(lastSentence.lowercased())"
}
}

func updateLabel() {
DispatchQueue.main.async {
self.subtitles = self.results
}
}

init() {
initRecognizer()
initRecorder()
}

private func initRecognizer() {
// Please select one model that is best suitable for you.
//
// You can also modify Model.swift to add new pre-trained models from
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
let modelConfig = getBilingualStreamZhEnZipformer20230220()
// let modelConfig = getBilingualStreamZhEnZipformer20230220()
let modelConfig = getBilingualStreamingZhEnParaformer()

let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000,
Expand All @@ -77,7 +78,7 @@ class SherpaOnnxViewModel: ObservableObject {
)
recognizer = SherpaOnnxRecognizer(config: &config)
}

private func initRecorder() {
print("init recorder")
audioEngine = AVAudioEngine()
Expand Down Expand Up @@ -152,7 +153,7 @@ class SherpaOnnxViewModel: ObservableObject {
}
}
}

public func toggleRecorder() {
if status == .stop {
startRecorder()
Expand Down
26 changes: 20 additions & 6 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,17 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);

recognizer_config.model_config.transducer.encoder =
SHERPA_ONNX_OR(config->model_config.encoder, "");
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
recognizer_config.model_config.transducer.decoder =
SHERPA_ONNX_OR(config->model_config.decoder, "");
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
recognizer_config.model_config.transducer.joiner =
SHERPA_ONNX_OR(config->model_config.joiner, "");
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");

recognizer_config.model_config.paraformer.encoder =
SHERPA_ONNX_OR(config->model_config.paraformer.encoder, "");
recognizer_config.model_config.paraformer.decoder =
SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");

recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
Expand Down Expand Up @@ -128,6 +134,8 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
const auto &text = result.text;

auto r = new SherpaOnnxOnlineRecognizerResult;
memset(r, 0, sizeof(SherpaOnnxOnlineRecognizerResult));

// copy text
r->text = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
Expand All @@ -153,7 +161,6 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
r->tokens = new char[total_length];
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
r->timestamps = new float[r->count];
char **tokens_temp = new char *[r->count];
int32_t pos = 0;
for (int32_t i = 0; i < r->count; ++i) {
Expand All @@ -162,10 +169,17 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character
pos += result.tokens[i].size() + 1;
r->timestamps[i] = result.timestamps[i];
}

r->tokens_arr = tokens_temp;

if (!result.timestamps.empty()) {
r->timestamps = new float[r->count];
std::copy(result.timestamps.begin(), result.timestamps.end(),
r->timestamps);
} else {
r->timestamps = nullptr;
}

} else {
r->count = 0;
r->timestamps = nullptr;
Expand Down
Loading

0 comments on commit a8bdb4b

Please sign in to comment.