Skip to content

Commit

Permalink
Fix style issues for online punctuation source files (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Aug 6, 2024
1 parent 1414e4d commit 375c055
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 121 deletions.
1 change: 1 addition & 0 deletions cmake/cmake_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_binaries():
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
"sherpa-onnx-offline-websocket-server",
"sherpa-onnx-online-punctuation",
"sherpa-onnx-online-websocket-client",
"sherpa-onnx-online-websocket-server",
"sherpa-onnx-vad-microphone",
Expand Down
19 changes: 11 additions & 8 deletions sherpa-onnx/csrc/online-cnn-bilstm-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl {
}
#endif

std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) {
std::array<Ort::Value, 3> inputs = {
std::move(token_ids), std::move(valid_ids), std::move(label_lens)};

auto ans =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
Expand Down Expand Up @@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(

OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;

std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) const {
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(
Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const {
return impl_->Forward(std::move(token_ids), std::move(valid_ids),
std::move(label_lens));
}

OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
return impl_->Allocator();
}

const OnlineCNNBiLSTMModelMetaData &
OnlineCNNBiLSTMModel::GetModelMetadata() const {
const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::GetModelMetadata()
const {
return impl_->GetModelMetadata();
}

Expand Down
9 changes: 5 additions & 4 deletions sherpa-onnx/csrc/online-cnn-bilstm-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ namespace sherpa_onnx {
*/
class OnlineCNNBiLSTMModel {
public:
explicit OnlineCNNBiLSTMModel(
const OnlinePunctuationModelConfig &config);
explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config);

#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel(AAssetManager *mgr,
const OnlinePunctuationModelConfig &config);
const OnlinePunctuationModelConfig &config);
#endif

~OnlineCNNBiLSTMModel();
Expand All @@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel {
* - case_logits: A 2-D tensor of shape (T', num_cases).
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
*/
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) const;

/** Return an allocator for allocating memory
*/
Expand Down
Loading

0 comments on commit 375c055

Please sign in to comment.