
Commit

added token supporting
wenbingl committed Jan 29, 2024
1 parent f870ff5 commit b994688
Showing 4 changed files with 157 additions and 81 deletions.
35 changes: 1 addition & 34 deletions tfmtok/bpe_encoder.hpp
@@ -105,37 +105,6 @@ class BpeEncoder {
return {};
}

TfmStatus LoadAddedTokens(const std::string_view added_tokens[], size_t tok_num) {
int id = bpe::kInvalidTokenId;
for (size_t n = 0; n < tok_num; ++n) {
std::string token(added_tokens[n]); // Convert std::string_view to std::string
id = GetTokenId(token);
added_tokens_.Add(FromUTF8(added_tokens[n]), 0, std::make_optional(id));
}

return {};
}

// REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const std::u32string& input) const {
// split by added tokens
bpe::TokenPairs added_result;
bpe::TokenPairs final_result;
added_tokens_.Split(input, added_result);
for (const auto& [token, id] : added_result) {
if (id != bpe::kInvalidTokenId) {
final_result.emplace_back(token, id);
} else {
auto special_result = special_tokens_.SplitBySpecialTokens(token);
for (const auto& [_token, _id] : special_result) {
final_result.emplace_back(_token, _id);
}
}
}

return final_result;
}

void PerformBPE(std::list<std::pair<uint32_t, uint32_t>>& vals) const {
while (vals.size() >= 2) {
auto pos_it = vals.end();
@@ -186,7 +155,7 @@ class BpeEncoder {
}
}

uint32_t GetTokenId(const std::string& key) {
uint32_t GetTokenId(const std::string& key) const {
auto it = vocab_map_.find(key);
if (it != end(vocab_map_)) {
return it->second;
@@ -227,8 +196,6 @@ class BpeEncoder {
uint32_t unk_id_ = std::numeric_limits<uint32_t>::max();
uint32_t max_token_id_ = 0;
std::string end_of_word_suffix_;
bpe::SpecialTokenMap special_tokens_;
TrieTree<char32_t> added_tokens_;
};

} // namespace tfm
79 changes: 79 additions & 0 deletions tfmtok/bpe_extended.hpp
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "utils/unescape.h"
#include "trietree.hpp"
#include "bpe_utils.hpp"

namespace tfm::bpe {

class ExtendedToken {
public:
TfmStatus LoadAddedTokens(const std::string_view added_tokens[], size_t tok_num) {
// int id = bpe::kInvalidTokenId;
// for (size_t n = 0; n < tok_num; ++n) {
// std::string token(added_tokens[n]); // Convert std::string_view to std::string
// id = GetTokenId(token);
// added_tokens_.Add(FromUTF8(added_tokens[n]), 0, std::make_optional(id));
// }

return {};
}

TfmStatus LoadAddedTokens(const simdjson::dom::element& added_tokens, std::map<int64_t, std::string>& dict) {
for (simdjson::dom::object tok: added_tokens) {
int id = bpe::kInvalidTokenId;
std::string_view content;
bool special = false;
for (auto field: tok) {
if (field.key == "id") {
id = gsl::narrow_cast<int>(field.value.get_int64());
} else if (field.key == "content") {
content = field.value;
} else if (field.key == "special") {
special = field.value;
}
}
if (id == bpe::kInvalidTokenId || content.empty()) {  // skip tokens missing an id or content
continue;
}
// TODO: handle the case where added_tokens does not fully cover the special tokens
// if (special) {
// special_tokens_.Add(content, id);
// }
dict.emplace(id, std::string(content));
added_tokens_.Add(FromUTF8(content), 0, std::make_optional(id));
}
return {};
}


// REF:
// https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs Split(const std::u32string& input) const {
// split by added tokens
bpe::TokenPairs added_result;
bpe::TokenPairs final_result;
added_tokens_.Split(input, added_result);
for (const auto& [token, id] : added_result) {
if (id != bpe::kInvalidTokenId) {
final_result.emplace_back(token, id);
} else {
auto special_result = special_tokens_.SplitBySpecialTokens(token);
for (const auto& [_token, _id] : special_result) {
final_result.emplace_back(_token, _id);
}
}
}

return final_result;
}

private:
SpecialTokenMap special_tokens_;
TrieTree<char32_t> added_tokens_;
};

} // namespace tfm::bpe
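
For context, here is a minimal, self-contained sketch (not part of the commit) of the "added_tokens" array that ExtendedToken::LoadAddedTokens walks. Each entry carries exactly the three fields the inner loop reads: "id", "content", and "special". The ids and token strings below are hypothetical, and explicit error-code .get() calls stand in for the implicit conversions the committed code relies on.

#include <cstdint>
#include <iostream>
#include <string>
#include <string_view>
#include "simdjson.h"

int main() {
  // Hypothetical excerpt of a HuggingFace-style tokenizer.json "added_tokens" array.
  std::string json = R"([
    {"id": 50264, "content": "<|im_start|>", "special": true},
    {"id": 50265, "content": "<|im_end|>", "special": true}
  ])";
  simdjson::dom::parser parser;
  simdjson::dom::element added_tokens;
  if (parser.parse(json.data(), json.size()).get(added_tokens)) {
    return 1;  // parse error
  }
  // Same iteration shape as LoadAddedTokens above.
  for (simdjson::dom::object tok : added_tokens) {
    int64_t id = -1;
    std::string_view content;
    bool special = false;
    for (auto field : tok) {
      if (field.key == "id") {
        field.value.get(id);
      } else if (field.key == "content") {
        field.value.get(content);
      } else if (field.key == "special") {
        field.value.get(special);
      }
    }
    std::cout << id << " -> " << content << (special ? " (special)" : "") << "\n";
  }
  return 0;
}

Note that the commit parses the special flag but does not yet route it into special_tokens_ (the Add call stays commented out until the uncovered-special-token case is resolved), and the std::string_view overload of LoadAddedTokens survives only as an empty stub for the fallback path in token_bpe.cc below.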
110 changes: 66 additions & 44 deletions tfmtok/token_bpe.cc
@@ -89,19 +89,16 @@ void BPETokenizer::CreateByteEncoder() {
)
*/
if ((/* i >= 0 && */ i < 33) || (i >= 127 && i < 161) || (i == 173)) {
byte_encoder_[i] = index;
byte_decoder_[index] = i;
index++;
byte_encoder_[i] = bbpe_encoder_.GetTokenId(EncodeUTF8Char(index++));
} else {
byte_encoder_[i] = i;
byte_encoder_[i] = bbpe_encoder_.GetTokenId(EncodeUTF8Char(i));
byte_decoder_[i] = i;
}
}
}

BPETokenizer::BPETokenizer() {
CreateByteEncoder();
bbpe_encoder_ = std::make_unique<BpeEncoder>();
}

BPETokenizer::~BPETokenizer() = default;
@@ -116,10 +113,10 @@ void BPETokenizer::LoadPredefinedTokens(const TokenConfig& config) {
eos_token_ = config.eos_token_.content_;
pad_token_ = config.pad_token_;

unk_token_id_ = bbpe_encoder_->GetTokenId(unk_token_);
bos_token_id_ = bbpe_encoder_->GetTokenId(bos_token_);
eos_token_id_ = bbpe_encoder_->GetTokenId(eos_token_);
pad_token_id_ = bbpe_encoder_->GetTokenId(pad_token_);
unk_token_id_ = bbpe_encoder_.GetTokenId(unk_token_);
bos_token_id_ = bbpe_encoder_.GetTokenId(bos_token_);
eos_token_id_ = bbpe_encoder_.GetTokenId(eos_token_);
pad_token_id_ = bbpe_encoder_.GetTokenId(pad_token_);

added_tokens_.emplace(std::pair(unk_token_id_, unk_token_));
added_tokens_.emplace(std::pair(bos_token_id_, bos_token_));
@@ -132,45 +129,63 @@ void BPETokenizer::LoadPredefinedTokens(const TokenConfig& config) {
all_special_ids_.emplace(pad_token_id_);
}

TfmStatus BPETokenizer::DecodeExtraArgs(const simdjson::dom::element& root) {
const simdjson::dom::element& decoder_obj = root.at_key("decoder");
if (decoder_obj.is_null()) {
return {kTfmErrorInvalidFile, "Cannot find the decoder key in the the tokenizer.json"};
}
TryToGetJson(decoder_obj, "add_prefix_space", decode_extra_args_.add_prefix_space);
return TfmStatus::OK();
}

TfmStatus BPETokenizer::Onload() {
simdjson::dom::parser parser;
simdjson::dom::element root;
std::string tokenizer_file = GetDataDir() + "/tokenizer.json";
auto error = parser.load(tokenizer_file).get(root);
if (error) {
return {kTfmErrorInvalidFile, "Invalid tokenizer file"};
return {kTfmErrorInvalidFile, "Failed to parse tokenizer.json"};
}

auto& config = *GetConfig();
model_name_ = std::string_view(config.tokenizer_class_.c_str(),
config.tokenizer_class_.find("Tokenizer"));
auto status = bbpe_encoder_->Load(root, config);
auto status = bbpe_encoder_.Load(root, config);
if (!status.ok()) {
return status;
}

// update the byte encoder with token-ids
// it's not a bug to use byte_decoder_.size() here, since the encoder/decoder are symmetric
for(uint32_t i = 0; i < byte_decoder_.size(); ++i) {
byte_encoder_[i] = bbpe_encoder_->GetTokenId(EncodeUTF8Char(i));
}
CreateByteEncoder();

// Get AddedTokens from config
std::string_view added_tokens[] = {
config.bos_token_.content_,
config.eos_token_.content_,
config.unk_token_.content_,
config.pad_token_};
simdjson::dom::element added_tokens_obj;
if (error = root["added_tokens"].get(added_tokens_obj); error) {
// Get AddedTokens from config
std::string_view added_tokens[] = {
config.unk_token_.content_,
config.eos_token_.content_,
config.bos_token_.content_,
config.pad_token_};
size_t num_added_tokens = sizeof(added_tokens) / sizeof(added_tokens[0]);

if (config.pad_token_.empty()) {
num_added_tokens--;
}
if (config.bos_token_.content_.empty()) {
num_added_tokens--;
}

status = bbpe_encoder_->LoadAddedTokens(added_tokens,
sizeof(added_tokens) / sizeof(added_tokens[0]));
status = extended_token_.LoadAddedTokens(added_tokens, num_added_tokens);
} else {
status = extended_token_.LoadAddedTokens(added_tokens_obj, added_tokens_);
}

if (!status.ok()) {
return status;
}

LoadPredefinedTokens(config);
arr_vocab_ = bbpe_encoder_->BuildDecoder();
arr_vocab_ = bbpe_encoder_.BuildDecoder();
status = DecodeExtraArgs(root);

return status;
}
@@ -212,7 +227,7 @@ std::vector<tfmTokenId_t> BPETokenizer::Encode(std::string_view sv_input,
return res;
}

if (ModelName() != kModel_GPT2) {
if (ModelName() != kModel_GPT2 && ModelName() != kModel_CodeGen) {
// Add BOS token to result
res.push_back(bos_token_id_);
}
@@ -222,7 +237,7 @@ std::vector<tfmTokenId_t> BPETokenizer::Encode(std::string_view sv_input,
}

// Parse input
auto special_token_split_res = bbpe_encoder_->SplitByAddedAndSpecial(input);
auto special_token_split_res = extended_token_.Split(input);
bpe::TokenWithRegularExp regcmp;

for (auto& seg_id : special_token_split_res) {
@@ -273,7 +288,7 @@ std::vector<tfmTokenId_t> BPETokenizer::Encode(std::string_view sv_input,
for (int i = 0; i < utf8_token.length(); i++) {
if (i == utf8_token.length() - 1) {
std::string boundary(1, utf8_token[i]);
byte_list.emplace_back(bbpe_encoder_->GetTokenId(boundary + "</w>"), 1);
byte_list.emplace_back(bbpe_encoder_.GetTokenId(boundary + "</w>"), 1);
} else {
byte_list.emplace_back(byte_encoder_[static_cast<unsigned char>(utf8_token[i])], 1);
}
@@ -285,7 +300,7 @@ std::vector<tfmTokenId_t> BPETokenizer::Encode(std::string_view sv_input,
}

// Perform BPE
bbpe_encoder_->PerformBPE(byte_list);
bbpe_encoder_.PerformBPE(byte_list);

// Add output to result
for (auto p : byte_list) {
@@ -317,7 +332,7 @@ std::vector<tfmTokenId_t> BPETokenizer::Encode(std::string_view sv_input,
}
}

if (ModelName() != kModel_GPT2) {
if (ModelName() != kModel_GPT2 && ModelName() != kModel_CodeGen) {
// Add EOS token to result
res.push_back(eos_token_id_);
}
@@ -341,12 +356,14 @@ TfmStatus BPETokenizer::Decode(const span<tfmTokenId_t const>& ids, std::string&
bool f_special = false;
auto count = static_cast<size_t>(ids.size());
auto p_ids = ids.data();
auto ewsuffix = bbpe_encoder_->GetEndWordSuffix();

auto& args = decode_extra_args_;
auto end_word_suffix = bbpe_encoder_.GetEndWordSuffix();
for (size_t tok_idx = 0; tok_idx < count; ++tok_idx) {
const auto token = *(p_ids + tok_idx);
std::string decoded_token;
f_special = all_special_ids_.count(token) ? true : false;
if (skip_special_tokens_ && f_special) {
if (args.skip_special_tokens_ && f_special) {
f_special_last = f_special;
continue;
}
@@ -355,36 +372,41 @@ TfmStatus BPETokenizer::Decode(const span<tfmTokenId_t const>& ids, std::string&
const std::string ws = added_tokens_.at(token);
decoded_token = (std::string)ws;
} else if (static_cast<size_t>(token) < arr_vocab_.size()) {
const auto str = arr_vocab_[token];
for (auto wchr : str) {
unsigned char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
const auto str = FromUTF8(arr_vocab_[token]);
for (unsigned char wchr : str) {
if (byte_decoder_.count(wchr) == 0 && wchr <= 0xFF) {
// std::cout << "Error: cannot find the byte_decoder_ for " << (uint32_t)(unsigned char)wchr << std::endl;
decoded_token.push_back(gsl::narrow<unsigned char>(wchr));
} else {
unsigned char uchr = byte_decoder_.at(wchr);
decoded_token.push_back(uchr);
}
}
} else {
if (skip_special_tokens_) {
if (args.skip_special_tokens_) {
continue;
} else {
decoded_token = unk_token_;
}
}

// ugly hack for CLIP
if (ewsuffix.size() > 0) {
if (decoded_token.size() >= ewsuffix.size() &&
decoded_token.substr(decoded_token.size() - ewsuffix.size()) == ewsuffix) {
decoded_token = decoded_token.substr(0, decoded_token.size() - ewsuffix.size());
// remove the end_word_suffix like </w> or </s> etc.
if (end_word_suffix.size() > 0) {
if (decoded_token.size() >= end_word_suffix.size() &&
decoded_token.substr(decoded_token.size() - end_word_suffix.size()) == end_word_suffix) {
decoded_token = decoded_token.substr(0, decoded_token.size() - end_word_suffix.size());
decoded_token += ' ';
}
}

if (whitespace_token_ &&
if (args.whitespace_token_ &&
f_special && (tok_idx > 0 && !f_special_last)) {
text.push_back(' ');
}

text.append(decoded_token);

if (whitespace_token_ &&
if (args.whitespace_token_ &&
f_special && tok_idx != count - 1) {
text.push_back(' ');
}
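The reworked CreateByteEncoder above resolves each byte to a real vocabulary id by rendering its byte-level code point as UTF-8 and looking the string up with GetTokenId; that is why Onload can simply call CreateByteEncoder after BpeEncoder::Load and drop the old post-load fix-up loop. Below is a sketch of that conversion, under the assumption that the repo's EncodeUTF8Char helper (defined elsewhere in the tree) is a standard code-point-to-UTF-8 encoder; byte-level BPE only remaps bytes into code points below 0x800, so two bytes always suffice.

#include <string>

// Assumed shape of the EncodeUTF8Char helper: render one code point as the
// UTF-8 byte sequence that serves as the vocabulary lookup key.
std::string EncodeUTF8Char(char32_t c) {
  std::string s;
  if (c < 0x80) {
    s.push_back(static_cast<char>(c));  // one byte: plain ASCII
  } else {
    // two bytes cover U+0080..U+07FF, enough for the byte-level BPE alphabet
    s.push_back(static_cast<char>(0xC0 | (c >> 6)));
    s.push_back(static_cast<char>(0x80 | (c & 0x3F)));
  }
  return s;
}

The decode side becomes correspondingly forgiving: code points missing from byte_decoder_ are now passed through as raw bytes rather than faulting in byte_decoder_.at().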

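The new DecodeExtraArgs reads decoding options from the top-level "decoder" object of tokenizer.json. Here is a minimal sketch of that lookup against a hypothetical ByteLevel decoder section; the explicit .get() calls approximate the repo's TryToGetJson helper, which is not reproduced here.

#include <iostream>
#include <string>
#include "simdjson.h"

int main() {
  // Hypothetical "decoder" section of a GPT-2-style tokenizer.json.
  std::string json = R"({
    "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true}
  })";
  simdjson::dom::parser parser;
  simdjson::dom::element root;
  if (parser.parse(json.data(), json.size()).get(root)) {
    return 1;  // parse error
  }
  simdjson::dom::element decoder_obj;
  if (root.at_key("decoder").get(decoder_obj)) {
    return 1;  // no "decoder" key: the kTfmErrorInvalidFile branch above
  }
  bool add_prefix_space = false;  // default when the key is absent
  simdjson::dom::element flag;
  if (!decoder_obj.at_key("add_prefix_space").get(flag)) {
    flag.get(add_prefix_space);
  }
  std::cout << std::boolalpha << add_prefix_space << "\n";
  return 0;
}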