Skip to content

Commit

Permalink
add a convert_token_string_to_an_id API for the prompt ids (#794)
Browse files Browse the repository at this point in the history
* add a convert token string to an id API for the prompt ids

* fix the build issues on Linux
  • Loading branch information
wenbingl authored Aug 19, 2024
1 parent 6ce22f8 commit 711a2cf
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 42 deletions.
11 changes: 11 additions & 0 deletions include/ortx_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
extError_t ORTX_API_CALL OrtxTokenize(
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);


/**
 * @brief Converts a token to its corresponding ID.
 *
 * @param tokenizer The tokenizer object.
 * @param token The input token string to be converted.
 * @param id Output pointer that receives the converted token ID.
 * @return The error code indicating the success or failure of the conversion.
 */
extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, const char* token, extTokenId_t* id);

/**
* @brief Retrieves the decoder prompt IDs from the tokenizer.
*
Expand Down
11 changes: 10 additions & 1 deletion operators/tokenizer/bpe_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
return {};
}

// Looks up a token's ID, giving added (user-defined/special) tokens
// precedence over the base BPE vocabulary.
uint32_t KernelBpeTokenizer::GetTokenId(const std::string& token) const {
  const auto added_id = bbpe_tokenizer_->GetAddedTokenId(token);
  return added_id != bpe::kInvalidTokenId ? added_id : bbpe_tokenizer_->GetTokenId(token);
}

std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
int64_t max_length,
bool compute_offset_mapping,
Expand Down Expand Up @@ -778,4 +787,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
}
3 changes: 2 additions & 1 deletion operators/tokenizer/bpe_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct KernelBpeTokenizer {
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

const std::string& ModelName() const { return model_name_; }
uint32_t GetTokenId(const std::string& token) const;

protected:
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
Expand Down Expand Up @@ -104,7 +105,7 @@ struct SpmTokenizer : KernelBpeTokenizer {
}
};

class JsonFastTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : public KernelBpeTokenizer {
public:
JsonFastTokenizer();
bool tiktoken_ = false;
Expand Down
39 changes: 19 additions & 20 deletions operators/tokenizer/bpe_tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,8 @@ class BpeModel {
}
}

OrtxStatus Load(std::istream& vocab_stream,
std::istream& merges_stream,
const char* unk_token,
const char* special_tokens,
bool spm_converted) {
OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
const char* special_tokens, bool spm_converted) {
nlohmann::json tok_json;
vocab_stream >> tok_json;
tok_json.get_to(vocab_map_);
Expand Down Expand Up @@ -125,9 +122,7 @@ class BpeModel {
return {};
}

OrtxStatus Load(const json& bpe_model,
const char* /* special_tokens */,
bool spm_converted) {
OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
const json& vocab_json = bpe_model["vocab"];
const json& merges_json = bpe_model["merges"];
vocab_json.get_to(vocab_map_);
Expand Down Expand Up @@ -195,8 +190,7 @@ class BpeModel {
}

OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
std::vector<std::pair<std::string, std::string>>& merges,
const char* /* special_tokens */,
std::vector<std::pair<std::string, std::string>>& merges, const char* /* special_tokens */,
bool spm_converted) {
vocab_map_ = vocab;

Expand All @@ -207,7 +201,7 @@ class BpeModel {
}

uint32_t index = 0;
for (auto& tuple : merges){
for (auto& tuple : merges) {
std::string w1 = tuple.first;
std::string w2 = tuple.second;
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
Expand Down Expand Up @@ -269,11 +263,10 @@ class BpeModel {
return {};
}

std::vector<std::string> BuildDecoder() const {
return id2token_map_;
}
std::vector<std::string> BuildDecoder() const { return id2token_map_; }

// REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
// REF:
// https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
// split by added tokens
bpe::TokenPairs added_result;
Expand Down Expand Up @@ -343,9 +336,7 @@ class BpeModel {
}
}

const auto& ByteEncoder() const {
return byte_encoder_;
}
const auto& ByteEncoder() const { return byte_encoder_; }

uint32_t GetTokenId(const std::string& key) const {
auto it = vocab_map_.find(key);
Expand All @@ -356,10 +347,18 @@ class BpeModel {
}
}

const std::string& GetEndOfWordSuffix() const {
return end_of_word_suffix_;
// Looks up `key` in the added-tokens trie and returns its ID, or
// bpe::kInvalidTokenId when no match is found (idx stays 0).
// NOTE(review): FindLongest appears to report the matched prefix length in
// `idx`; a partial-prefix match (0 < idx < key.size()) would still return an
// id here — confirm whether callers require an exact whole-string match.
uint32_t GetAddedTokenId(const std::string& key) const {
size_t idx = 0;
int id = added_tokens_.FindLongest(ustring(key), idx);
if (idx == 0) {
return bpe::kInvalidTokenId;
}

return static_cast<uint32_t>(id);
}

const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }

private:
struct BpeNode {
uint32_t id;
Expand Down
15 changes: 15 additions & 0 deletions shared/api/c_api_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char
return extError_t();
}

// C-API entry point: convert a single token string to its ID.
// Validates arguments, confirms the handle really is a tokenizer object,
// then delegates to TokenizerImpl::Token2Id.
extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, const char* token, extTokenId_t* id) {
  // Null-argument guard, mirroring the other C-API entry points.
  if (tokenizer == nullptr || token == nullptr || id == nullptr) {
    ReturnableStatus::last_error_message_ = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  const auto* tok_impl = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status(tok_impl->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
  if (status.IsOk()) {
    status = tok_impl->Token2Id(token, *id);
  }

  return status.Code();
}

extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang,
const char* task, int no_timestamps, OrtxTokenId2DArray** output) {
if (tokenizer == nullptr || output == nullptr) {
Expand Down
19 changes: 4 additions & 15 deletions shared/api/tokenizer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,9 @@ OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lan
}
// since it was only supported by Whisper model, should we check it here?

auto convert_tokens_to_ids = [this](const std::string& token) -> extTokenId_t {
ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(token)});
auto status = this->tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
if (!status.IsOk()) {
return static_cast<extTokenId_t>(-1);
}
auto num = ts_output.NumberOfElement();
return static_cast<extTokenId_t>(ts_output.Data()[num / 2]); // get the middle token
};

auto translate_token_id = convert_tokens_to_ids("<|translate|>");
auto transcribe_token_id = convert_tokens_to_ids("<|transcribe|>");
auto notimestamps_token_id = convert_tokens_to_ids("<|notimestamps|>");
auto translate_token_id = tokenizer_->GetTokenId("<|translate|>");
auto transcribe_token_id = tokenizer_->GetTokenId("<|transcribe|>");
auto notimestamps_token_id = tokenizer_->GetTokenId("<|notimestamps|>");
std::vector<extTokenId_t> ids;
ids.reserve(4);
if (lang != nullptr) {
Expand All @@ -133,7 +122,7 @@ OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lan
}

std::string lang_token = "<|" + lang_str->first + "|>";
ids.push_back(convert_tokens_to_ids(lang_token));
ids.push_back(tokenizer_->GetTokenId(lang_token));
}

if (task != nullptr) {
Expand Down
6 changes: 5 additions & 1 deletion shared/api/tokenizer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class TokenizerImpl : public OrtxObjectImpl {
return BatchDecode(t_ids, t_text);
}

// Converts a single token string to its numeric ID via the underlying
// tokenizer's GetTokenId. Always returns an OK status; an unrecognized
// token yields whatever sentinel GetTokenId produces (presumably an
// invalid/unk id — confirm against BpeModel::GetTokenId).
OrtxStatus Token2Id(const std::string& token, extTokenId_t& id) const {
id = tokenizer_->GetTokenId(token);
return {};
}

OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
BPEDecoderState* state_ptr = cache.get();
OrtxStatus status = Id2Token(id, token, &state_ptr);
Expand All @@ -50,7 +55,6 @@ class TokenizerImpl : public OrtxObjectImpl {
std::vector<std::vector<extTokenId_t>>& t_ids) const;

private:
bool tiktoken = false;
std::string tokenizer_dir_;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
std::unique_ptr<JsonFastTokenizer> tokenizer_;
Expand Down
9 changes: 5 additions & 4 deletions test/pp_api_test/test_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,11 @@ TEST(OrtxTokenizerTest, WhisperTokenizer) {
const extTokenId_t* token_ids = NULL;
OrtxTokenId2DArrayGetItem(prompt_ids.get(), 0, &token_ids, &length);
std::vector<extTokenId_t> ids(token_ids, token_ids + length);
// std::cout << "Prompt IDs: ";
// for (const auto& id : ids) {
// std::cout << id << " ";
// }

EXPECT_EQ(ids, std::vector<extTokenId_t>({50259, 50358, 50363}));

extTokenId_t sot_id{};
err = OrtxConvertTokenToId(tokenizer.get(), "<|startoftranscript|>", &sot_id);
EXPECT_EQ(err, kOrtxOK);
EXPECT_EQ(sot_id, 50258);
}

0 comments on commit 711a2cf

Please sign in to comment.