Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Encode keywords in C++ side #967

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
endif()
endif()

include(cppinyin)
include(kaldi-native-fbank)
include(kaldi-decoder)
include(onnxruntime)
Expand Down
63 changes: 63 additions & 0 deletions cmake/cppinyin.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
function(download_cppinyin)
include(FetchContent)

set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz")
set(cppinyin_URL2 "https://hub.nuaa.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz")
set(cppinyin_HASH "SHA256=3659bc0c28d17d41ce932807c1cdc1da8c861e6acee969b5844d0d0a3c5ef78b")

# If you don't have access to the Internet,
# please pre-download cppinyin
set(possible_file_locations
$ENV{HOME}/Downloads/cppinyin-0.1.tar.gz
${CMAKE_SOURCE_DIR}/cppinyin-0.1.tar.gz
${CMAKE_BINARY_DIR}/cppinyin-0.1.tar.gz
/tmp/cppinyin-0.1.tar.gz
/star-fj/fangjun/download/github/cppinyin-0.1.tar.gz
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(cppinyin_URL "${f}")
file(TO_CMAKE_PATH "${cppinyin_URL}" cppinyin_URL)
message(STATUS "Found local downloaded cppinyin: ${cppinyin_URL}")
set(cppinyin_URL2)
break()
endif()
endforeach()

set(CPPINYIN_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
set(CPPINYIN_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

FetchContent_Declare(cppinyin
URL
${cppinyin_URL}
${cppinyin_URL2}
URL_HASH
${cppinyin_HASH}
)

FetchContent_GetProperties(cppinyin)
if(NOT cppinyin_POPULATED)
message(STATUS "Downloading cppinyin ${cppinyin_URL}")
FetchContent_Populate(cppinyin)
endif()
message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}")
add_subdirectory(${cppinyin_SOURCE_DIR} ${cppinyin_BINARY_DIR} EXCLUDE_FROM_ALL)

target_include_directories(cppinyin_core
PUBLIC
${cppinyin_SOURCE_DIR}/
)

if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
install(TARGETS cppinyin_core DESTINATION ..)
else()
install(TARGETS cppinyin_core DESTINATION lib)
endif()

if(WIN32 AND BUILD_SHARED_LIBS)
install(TARGETS cppinyin_core DESTINATION bin)
endif()
endfunction()

download_cppinyin()
35 changes: 31 additions & 4 deletions python-api-examples/keyword-spotter-from-microphone.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,39 @@ def get_args():
""",
)

parser.add_argument(
"--modeling-unit",
type=str,
help="""The modeling unit of the model, valid values are bpe (for English model)
and ppinyin (For Chinese model).
""",
)

parser.add_argument(
"--bpe-vocab",
type=str,
help="""A simple format of bpe model, you can get it from the sentencepiece
generated folder. Used to tokenize the keywords into token ids. Used only
when modeling unit is bpe.
""",
)

parser.add_argument(
"--lexicon",
type=str,
help="""The lexicon used to tokenize the keywords into token ids. Used
only when modeling unit is ppinyin.
""",
)

parser.add_argument(
"--keywords-file",
type=str,
help="""
The file containing keywords, one words/phrases per line, and for each
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
The file containing keywords, one words/phrases per line. For example:

▁HE LL O ▁WORLD
x iǎo ài t óng x ué
HELLO WORLD
小爱同学
""",
)

Expand Down Expand Up @@ -164,6 +188,9 @@ def main():
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_trailing_blanks=args.num_trailing_blanks,
modeling_unit=args.modeling_unit,
bpe_vocab=args.bpe_vocab,
lexicon=args.lexicon,
provider=args.provider,
)

Expand Down
35 changes: 31 additions & 4 deletions python-api-examples/keyword-spotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,39 @@ def get_args():
""",
)

parser.add_argument(
"--modeling-unit",
type=str,
help="""The modeling unit of the model, valid values are bpe (for English model)
and ppinyin (For Chinese model).
""",
)

parser.add_argument(
"--bpe-vocab",
type=str,
help="""A simple format of bpe model, you can get it from the sentencepiece
generated folder. Used to tokenize the keywords into token ids. Used only
when modeling unit is bpe.
""",
)

parser.add_argument(
"--lexicon",
type=str,
help="""The lexicon used to tokenize the keywords into token ids. Used
only when modeling unit is ppinyin.
""",
)

parser.add_argument(
"--keywords-file",
type=str,
help="""
The file containing keywords, one words/phrases per line, and for each
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
The file containing keywords, one words/phrases per line. For example:

▁HE LL O ▁WORLD
x iǎo ài t óng x ué
HELLO WORLD
小爱同学
""",
)

Expand Down Expand Up @@ -183,6 +207,9 @@ def main():
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_trailing_blanks=args.num_trailing_blanks,
modeling_unit=args.modeling_unit,
bpe_vocab=args.bpe_vocab,
lexicon=args.lexicon,
provider=args.provider,
)

Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ if(ANDROID_NDK)
endif()

target_link_libraries(sherpa-onnx-core
cppinyin_core
kaldi-native-fbank-core
kaldi-decoder-core
ssentencepiece_core
Expand Down
37 changes: 34 additions & 3 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "android/asset_manager_jni.h"
#endif

#include "cppinyin/csrc/cppinyin.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
Expand All @@ -27,6 +28,7 @@
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"

namespace sherpa_onnx {

Expand Down Expand Up @@ -78,6 +80,17 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}

if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}

if (config_.model_config.modeling_unit == "ppinyin" &&
!config_.model_config.lexicon.empty()) {
pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(
config_.model_config.lexicon);
}

model_->SetFeatureDim(config.feat_config.feature_dim);

if (config.keywords_buf.empty()) {
Expand All @@ -103,6 +116,19 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

model_->SetFeatureDim(config.feat_config.feature_dim);

if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}

if (config_.model_config.modeling_unit == "ppinyin" &&
!config_.model_config.lexicon.empty()) {
auto buf = ReadFile(mgr, config_.model_config.lexicon);
std::istringstream iss(std::string(buf.begin(), buf.end()));
pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(iss);
}

InitKeywords(mgr);

decoder_ = std::make_unique<TransducerKeywordDecoder>(
Expand All @@ -128,8 +154,9 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
std::vector<float> current_scores;
std::vector<float> current_thresholds;

if (!EncodeKeywords(is, sym_, &current_ids, &current_kws, &current_scores,
&current_thresholds)) {
if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), pinyin_encoder_.get(), &current_ids,
&current_kws, &current_scores, &current_thresholds)) {
SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
return nullptr;
}
Expand Down Expand Up @@ -269,7 +296,9 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

private:
void InitKeywords(std::istream &is) {
if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_,
if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), pinyin_encoder_.get(),
&keywords_id_, &keywords_, &boost_scores_,
&thresholds_)) {
SHERPA_ONNX_LOGE("Encode keywords failed.");
exit(-1);
Expand Down Expand Up @@ -339,6 +368,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
std::vector<float> thresholds_;
std::vector<std::string> keywords_;
ContextGraphPtr keywords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<cppinyin::PinyinEncoder> pinyin_encoder_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<TransducerKeywordDecoder> decoder_;
SymbolTable sym_;
Expand Down
7 changes: 3 additions & 4 deletions sherpa-onnx/csrc/keyword-spotter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,9 @@ void KeywordSpotterConfig::Register(ParseOptions *po) {
"The acoustic threshold (probability) to trigger the keywords.");
po->Register(
"keywords-file", &keywords_file,
"The file containing keywords, one word/phrase per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
"The file containing keywords, one word/phrase per line. For example: "
"HELLO WORLD"
"你好世界");
}

bool KeywordSpotterConfig::Validate() const {
Expand Down
24 changes: 19 additions & 5 deletions sherpa-onnx/csrc/online-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("debug", &debug,
"true to print model information while loading it.");

po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");
po->Register(
"modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, ppinyin, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");

po->Register("bpe-vocab", &bpe_vocab,
"The vocabulary generated by google's sentencepiece program. "
Expand All @@ -43,6 +44,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
"your bpe model is generated. Only used when hotwords provided "
"and the modeling unit is bpe or cjkchar+bpe");

po->Register("lexicon", &lexicon,
"The lexicon used to encode words into tokens."
"Only used for keyword spotting now");

po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, "
Expand Down Expand Up @@ -80,6 +85,14 @@ bool OnlineModelConfig::Validate() const {
}
}

if (!modeling_unit.empty() &&
(modeling_unit == "fpinyin" || modeling_unit == "ppinyin")) {
if (!FileExists(lexicon)) {
SHERPA_ONNX_LOGE("lexicon: %s does not exist", lexicon.c_str());
return false;
}
}

if (!paraformer.encoder.empty()) {
return paraformer.Validate();
}
Expand Down Expand Up @@ -119,6 +132,7 @@ std::string OnlineModelConfig::ToString() const {
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")";

return os.str();
Expand Down
13 changes: 11 additions & 2 deletions sherpa-onnx/csrc/online-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,17 @@ struct OnlineModelConfig {
// - cjkchar
// - bpe
// - cjkchar+bpe
// - fpinyin
// - ppinyin
std::string modeling_unit = "cjkchar";
// For encoding words into tokens
// Used only for models trained with bpe
std::string bpe_vocab;

// For encoding words into tokens
// Used for models trained with pinyin or phone
std::string lexicon;

/// if tokens_buf is non-empty,
/// the tokens will be loaded from the buffer instead of from the
/// "tokens" file
Expand All @@ -60,7 +68,7 @@ struct OnlineModelConfig {
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
const std::string &bpe_vocab, const std::string &lexicon)
: transducer(transducer),
paraformer(paraformer),
wenet_ctc(wenet_ctc),
Expand All @@ -73,7 +81,8 @@ struct OnlineModelConfig {
debug(debug),
model_type(model_type),
modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {}
bpe_vocab(bpe_vocab),
lexicon(lexicon) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
Loading
Loading