From 711a2cfa699a768721f6bbb83c92e79b27df376f Mon Sep 17 00:00:00 2001
From: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Date: Mon, 19 Aug 2024 16:44:07 -0700
Subject: [PATCH] add a convert_token_string_to_an_id API for the prompt ids
 (#794)

* add a convert token string to an id API for the prompt ids

* fix the build issues on Linux
---
 include/ortx_tokenizer.h              | 11 ++++++++
 operators/tokenizer/bpe_kernels.cc    | 11 +++++++-
 operators/tokenizer/bpe_kernels.h     |  3 ++-
 operators/tokenizer/bpe_tokenizer.hpp | 39 +++++++++++++--------------
 shared/api/c_api_tokenizer.cc         | 15 +++++++++++
 shared/api/tokenizer_impl.cc          | 19 +++-----------
 shared/api/tokenizer_impl.h           |  6 ++++-
 test/pp_api_test/test_tokenizer.cc    |  9 ++++---
 8 files changed, 71 insertions(+), 42 deletions(-)

diff --git a/include/ortx_tokenizer.h b/include/ortx_tokenizer.h
index dbbfac9d..ca2feba6 100644
--- a/include/ortx_tokenizer.h
+++ b/include/ortx_tokenizer.h
@@ -37,6 +37,17 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
 extError_t ORTX_API_CALL OrtxTokenize(
     const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);
 
+
+/**
+ * Converts a token to its corresponding ID.
+ *
+ * @param tokenizer The tokenizer object.
+ * @param token The input token to be converted.
+ * @param id Pointer to store the converted token ID.
+ * @return The error code indicating the success or failure of the conversion.
+ */
+extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, const char* token, extTokenId_t* id);
+
 /**
  * @brief Retrieves the decoder prompt IDs from the tokenizer.
  *
diff --git a/operators/tokenizer/bpe_kernels.cc b/operators/tokenizer/bpe_kernels.cc
index b19cd2c3..80397fd6 100644
--- a/operators/tokenizer/bpe_kernels.cc
+++ b/operators/tokenizer/bpe_kernels.cc
@@ -171,6 +171,15 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
   return {};
 }
 
+uint32_t KernelBpeTokenizer::GetTokenId(const std::string& token) const {
+  auto id = bbpe_tokenizer_->GetAddedTokenId(token);
+  if (id != bpe::kInvalidTokenId) {
+    return id;
+  }
+
+  return bbpe_tokenizer_->GetTokenId(token);
+}
+
 std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
                                                   int64_t max_length,
                                                   bool compute_offset_mapping,
@@ -778,4 +787,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
                                       std::optional<ortc::Tensor<int64_t>*> attention_mask,
                                       std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
   return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
-}
\ No newline at end of file
+}
diff --git a/operators/tokenizer/bpe_kernels.h b/operators/tokenizer/bpe_kernels.h
index b6e38629..bf4b0cd3 100644
--- a/operators/tokenizer/bpe_kernels.h
+++ b/operators/tokenizer/bpe_kernels.h
@@ -33,6 +33,7 @@ struct KernelBpeTokenizer {
                     std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
 
   const std::string& ModelName() const { return model_name_; }
+  uint32_t GetTokenId(const std::string& token) const;
 
  protected:
   using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
@@ -104,7 +105,7 @@ struct SpmTokenizer : KernelBpeTokenizer {
   }
 };
 
-class JsonFastTokenizer : KernelBpeTokenizer {
+class JsonFastTokenizer : public KernelBpeTokenizer {
  public:
   JsonFastTokenizer();
   bool tiktoken_ = false;
diff --git a/operators/tokenizer/bpe_tokenizer.hpp b/operators/tokenizer/bpe_tokenizer.hpp
index f70c5e41..ff5282a3 100644
--- a/operators/tokenizer/bpe_tokenizer.hpp
+++ b/operators/tokenizer/bpe_tokenizer.hpp
@@ -44,11 +44,8 @@ class BpeModel {
     }
   }
 
-  OrtxStatus Load(std::istream& vocab_stream,
-                  std::istream& merges_stream,
-                  const char* unk_token,
-                  const char* special_tokens,
-                  bool spm_converted) {
+  OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
+                  const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
     vocab_stream >> tok_json;
     tok_json.get_to(vocab_map_);
@@ -125,9 +122,7 @@ class BpeModel {
     return {};
   }
 
-  OrtxStatus Load(const json& bpe_model,
-                  const char* /* special_tokens */,
-                  bool spm_converted) {
+  OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -195,8 +190,7 @@ class BpeModel {
   }
 
   OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
-                  std::vector<std::pair<std::string, std::string>>& merges,
-                  const char* /* special_tokens */,
+                  std::vector<std::pair<std::string, std::string>>& merges, const char* /* special_tokens */,
                   bool spm_converted) {
     vocab_map_ = vocab;
 
@@ -207,7 +201,7 @@ class BpeModel {
     }
 
     uint32_t index = 0;
-    for (auto& tuple : merges){
+    for (auto& tuple : merges) {
       std::string w1 = tuple.first;
       std::string w2 = tuple.second;
       int token_length = ort_extensions::narrow<int32_t>(w1.length() + w2.length());
@@ -269,11 +263,10 @@ class BpeModel {
     return {};
   }
 
-  std::vector<std::string> BuildDecoder() const {
-    return id2token_map_;
-  }
+  std::vector<std::string> BuildDecoder() const { return id2token_map_; }
 
-  // REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
+  // REF:
+  // https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
   bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
     // split by added tokens
     bpe::TokenPairs added_result;
@@ -343,9 +336,7 @@ class BpeModel {
     }
   }
 
-  const auto& ByteEncoder() const {
-    return byte_encoder_;
-  }
+  const auto& ByteEncoder() const { return byte_encoder_; }
 
   uint32_t GetTokenId(const std::string& key) const {
     auto it = vocab_map_.find(key);
@@ -356,10 +347,18 @@ class BpeModel {
     }
   }
 
-  const std::string& GetEndOfWordSuffix() const {
-    return end_of_word_suffix_;
+  uint32_t GetAddedTokenId(const std::string& key) const {
+    size_t idx = 0;
+    int id = added_tokens_.FindLongest(ustring(key), idx);
+    if (idx == 0) {
+      return bpe::kInvalidTokenId;
+    }
+
+    return static_cast<uint32_t>(id);
   }
 
+  const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
+
  private:
   struct BpeNode {
     uint32_t id;
diff --git a/shared/api/c_api_tokenizer.cc b/shared/api/c_api_tokenizer.cc
index 4ffc2b4e..dd7fc316 100644
--- a/shared/api/c_api_tokenizer.cc
+++ b/shared/api/c_api_tokenizer.cc
@@ -76,6 +76,21 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char
   return extError_t();
 }
 
+extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, const char* token, extTokenId_t* id) {
+  if (tokenizer == nullptr || token == nullptr || id == nullptr) {
+    ReturnableStatus::last_error_message_ = "Invalid argument";
+    return kOrtxErrorInvalidArgument;
+  }
+  auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
+  ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
+  if (!status.IsOk()) {
+    return status.Code();
+  }
+
+  status = token_ptr->Token2Id(token, *id);
+  return status.Code();
+}
+
 extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang,
                                                 const char* task, int no_timestamps, OrtxTokenId2DArray** output) {
   if (tokenizer == nullptr || output == nullptr) {
diff --git a/shared/api/tokenizer_impl.cc b/shared/api/tokenizer_impl.cc
index e0b5c52f..7a0cbff5 100644
--- a/shared/api/tokenizer_impl.cc
+++ b/shared/api/tokenizer_impl.cc
@@ -110,20 +110,9 @@ OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lan
   }
 
   // since it was only supported by Whisper model, should we check it here?
-  auto convert_tokens_to_ids = [this](const std::string& token) -> extTokenId_t {
-    ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
-    ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(token)});
-    auto status = this->tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
-    if (!status.IsOk()) {
-      return static_cast<extTokenId_t>(-1);
-    }
-    auto num = ts_output.NumberOfElement();
-    return static_cast<extTokenId_t>(ts_output.Data()[num / 2]);  // get the middle token
-  };
-
-  auto translate_token_id = convert_tokens_to_ids("<|translate|>");
-  auto transcribe_token_id = convert_tokens_to_ids("<|transcribe|>");
-  auto notimestamps_token_id = convert_tokens_to_ids("<|notimestamps|>");
+  auto translate_token_id = tokenizer_->GetTokenId("<|translate|>");
+  auto transcribe_token_id = tokenizer_->GetTokenId("<|transcribe|>");
+  auto notimestamps_token_id = tokenizer_->GetTokenId("<|notimestamps|>");
   std::vector<extTokenId_t> ids;
   ids.reserve(4);
   if (lang != nullptr) {
@@ -133,7 +122,7 @@ OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lan
     }
 
     std::string lang_token = "<|" + lang_str->first + "|>";
-    ids.push_back(convert_tokens_to_ids(lang_token));
+    ids.push_back(tokenizer_->GetTokenId(lang_token));
   }
 
   if (task != nullptr) {
diff --git a/shared/api/tokenizer_impl.h b/shared/api/tokenizer_impl.h
index f5b51c28..9eab1e8f 100644
--- a/shared/api/tokenizer_impl.h
+++ b/shared/api/tokenizer_impl.h
@@ -27,6 +27,11 @@ class TokenizerImpl : public OrtxObjectImpl {
     return BatchDecode(t_ids, t_text);
   }
 
+  OrtxStatus Token2Id(const std::string& token, extTokenId_t& id) const {
+    id = tokenizer_->GetTokenId(token);
+    return {};
+  }
+
   OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
     BPEDecoderState* state_ptr = cache.get();
     OrtxStatus status = Id2Token(id, token, &state_ptr);
@@ -50,7 +55,6 @@ class TokenizerImpl : public OrtxObjectImpl {
                          std::vector<std::vector<extTokenId_t>>& t_ids) const;
 
  private:
-  bool tiktoken = false;
   std::string tokenizer_dir_;
   std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
   std::unique_ptr<JsonFastTokenizer> tokenizer_;
diff --git a/test/pp_api_test/test_tokenizer.cc b/test/pp_api_test/test_tokenizer.cc
index 6f06acc4..3c2f64cf 100644
--- a/test/pp_api_test/test_tokenizer.cc
+++ b/test/pp_api_test/test_tokenizer.cc
@@ -410,10 +410,11 @@ TEST(OrtxTokenizerTest, WhisperTokenizer) {
   const extTokenId_t* token_ids = NULL;
   OrtxTokenId2DArrayGetItem(prompt_ids.get(), 0, &token_ids, &length);
   std::vector<extTokenId_t> ids(token_ids, token_ids + length);
-  // std::cout << "Prompt IDs: ";
-  // for (const auto& id : ids) {
-  //   std::cout << id << " ";
-  // }
 
   EXPECT_EQ(ids, std::vector<extTokenId_t>({50259, 50358, 50363}));
+
+  extTokenId_t sot_id{};
+  err = OrtxConvertTokenToId(tokenizer.get(), "<|startoftranscript|>", &sot_id);
+  EXPECT_EQ(err, kOrtxOK);
+  EXPECT_EQ(sot_id, 50258);
 }
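
For reference, a minimal sketch of calling the new API through the C interface, not part of the patch itself. It assumes a tokenizer model directory at the placeholder path "data/tokenizer", and that OrtxDisposeOnly from ortx_utils.h (pulled in by ortx_tokenizer.h) is the matching teardown call for OrtxCreateTokenizer; the expected id 50258 mirrors the Whisper test above.

#include <stdio.h>

#include "ortx_tokenizer.h"

int main(void) {
  OrtxTokenizer* tokenizer = NULL;
  /* "data/tokenizer" is a placeholder; point it at a real tokenizer
     model directory, e.g. a Whisper tokenizer. */
  extError_t err = OrtxCreateTokenizer(&tokenizer, "data/tokenizer");
  if (err != kOrtxOK) {
    return 1;
  }

  /* Convert a single token string to its id. Per the patch, added/special
     tokens are looked up first (BpeModel::GetAddedTokenId), then the
     regular vocabulary (BpeModel::GetTokenId). */
  extTokenId_t id = 0;
  err = OrtxConvertTokenToId(tokenizer, "<|startoftranscript|>", &id);
  if (err == kOrtxOK) {
    printf("token id: %u\n", (unsigned)id); /* 50258 for Whisper, per the test */
  }

  OrtxDisposeOnly(tokenizer);
  return 0;
}

Compared with the removed convert_tokens_to_ids lambda, which ran a full Compute() pass and picked the middle output id, this entry point performs a direct vocabulary lookup, so no tensor allocation is involved.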