Split-Embedding of long texts #4

Merged 1 commit on Jun 29, 2024

Changes from all commits

6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,8 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed
+
+## [0.4.0]
+
+### Added
+- Simplify embedding with the tiny models and provide native support for splitting long strings into several smaller ones (kwarg `split_instead_trunc`)
+
 ## [0.3.0]

 ### Added
 - Added a base Bert Tiny model to support lightning-fast embeddings (alias `tiny_embed`). See `?embed` for more details.

2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "FlashRank"
 uuid = "22cc3f58-1757-4700-bb45-2032706e5a8d"
 authors = ["J S <[email protected]> and contributors"]
-version = "0.3.0"
+version = "0.4.0"

 [deps]
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"

22 changes: 20 additions & 2 deletions src/embedding.jl
@@ -75,8 +75,26 @@ function embed(
     EmbedResult(embeddings, t)
 end

-function embed(embedder::EmbedderModel, passages::AbstractString)
-    embed(embedder, [passages])
+"""
+    embed(
+        embedder::EmbedderModel, passage::AbstractString; split_instead_trunc::Bool = false)
+
+Embeds a single `passage`.
+
+If the passage is too long for the model AND `split_instead_trunc` is true, it is split into several smaller chunks of size `embedder.encoder.trunc` and embedded separately.
+"""
+function embed(
+        embedder::EmbedderModel, passage::AbstractString; split_instead_trunc::Bool = false)
+    t = @elapsed begin
+        token_ids, token_type_ids, attention_mask = encode(
+            embedder.encoder, passage; split_instead_trunc)
+        ## transpose as the model expects row-major inputs
+        onnx_input = Dict("input_ids" => token_ids', "attention_mask" => attention_mask')
+        out = embedder.session(onnx_input)
+        ## permute dimensions to return column-major embeddings, ie, embedding-size x batch-size
+        embeddings = out["avg_embeddings"] |> permutedims
+    end
+    EmbedResult(embeddings, t)
 end

 function (embedder::EmbedderModel)(

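For context, a minimal usage sketch of the new single-string method (not part of the diff). The `EmbedderModel(:tiny_embed)` constructor is an assumption based on the `tiny_embed` alias added in 0.3.0; the expected sizes mirror the tests below.

using FlashRank: EmbedderModel, embed

embedder = EmbedderModel(:tiny_embed)  # assumed alias-based constructor

## short text: splitting has no effect, a single embedding column
res = embed(embedder, "Hello, how are you?"; split_instead_trunc = true)
size(res.embeddings)  # (312, 1)

## long text: each 510-token chunk becomes its own column
long_text = repeat("Hello, how are you? ", 200)
res = embed(embedder, long_text; split_instead_trunc = true)
size(res.embeddings)  # (312, 3)
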
68 changes: 59 additions & 9 deletions src/encoder.jl
@@ -36,17 +36,20 @@ end

 """
     tokenize(enc::BertTextEncoder, text::AbstractString;
-        add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false)
+        add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false,
+        max_tokens::Union{Nothing, Int} = enc.trunc)

 Tokenizes the text and returns the tokens or token IDs (to skip looking up the IDs twice).

 # Arguments
 - `add_special_tokens::Bool = true`: Add special tokens at the beginning and end of the text.
 - `add_end_token::Bool = true`: Add the end token at the end of the text.
 - `token_ids::Bool = false`: If true, return the token IDs directly. Otherwise, return the tokens.
+- `max_tokens::Union{Nothing, Int} = enc.trunc`: The maximum number of tokens to return (usually defined by the model).
 """
 function tokenize(enc::BertTextEncoder, text::AbstractString;
-        add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false)
+        add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false,
+        max_tokens::Union{Nothing, Int} = enc.trunc)
     tokens = token_ids ? Int[] : String[]
     if add_special_tokens
         token = token_ids ? enc.vocab[enc.startsym] : enc.startsym
@@ -55,8 +58,8 @@ function tokenize(enc::BertTextEncoder, text::AbstractString;
     for token in bert_uncased_tokenizer(text)
         append!(tokens, enc.wp(token; token_ids))
     end
-    if !isnothing(enc.trunc) && length(tokens) > (enc.trunc - 1)
-        tokens = tokens[1:(enc.trunc - 1)]
+    if !isnothing(max_tokens) && length(tokens) > (max_tokens - 1)
+        tokens = tokens[1:(max_tokens - 1)]
     end
     if add_special_tokens || add_end_token
         token = token_ids ? enc.vocab[enc.endsym] : enc.endsym
@@ -65,11 +68,58 @@ function tokenize(enc::BertTextEncoder, text::AbstractString;
     return tokens
 end

-function encode(enc::BertTextEncoder, text::String; add_special_tokens::Bool = true)
-    token_ids = tokenize(enc, text; add_special_tokens, token_ids = true)
-    # Zero indexed as models are trained for Python
-    token_type_ids = zeros(Int, length(token_ids))
-    attention_mask = ones(Int, length(token_ids))
+"""
+    encode(enc::BertTextEncoder, text::String; add_special_tokens::Bool = true,
+        max_tokens::Int = enc.trunc, split_instead_trunc::Bool = false)
+
+Encodes the text and returns the token IDs, token type IDs, and attention mask.
+
+We enforce `max_tokens` to be a concrete number here to be able to do `split_instead_trunc`.
+`split_instead_trunc` splits any long sequences into several smaller ones.
+"""
+function encode(enc::BertTextEncoder, text::String; add_special_tokens::Bool = true,
+        max_tokens::Int = enc.trunc, split_instead_trunc::Bool = false)
+    if !split_instead_trunc
+        ## Standard run - if the text is too long, we truncate it and ignore the rest
+        token_ids = tokenize(enc, text; add_special_tokens, token_ids = true, max_tokens)
+        # Zero-indexed as the models are trained for Python
+        token_type_ids = zeros(Int, length(token_ids))
+        attention_mask = ones(Int, length(token_ids))
+    else
+        ## Split run - if the text is too long, we split it into multiple chunks and encode them separately
+        ## Only possible for a single string, so we know which chunks belong together
+        ## tokenize without special tokens at first
+        token_ids = tokenize(enc, text; add_special_tokens = false,
+            token_ids = true, max_tokens = nothing)
+        ## determine the correct chunk size
+        start_token = enc.vocab[enc.startsym]
+        end_token = enc.vocab[enc.endsym]
+        chunk_size = max_tokens - 2 * add_special_tokens
+        itr = Iterators.partition(token_ids, chunk_size)
+        num_chunks = length(itr)
+        ## split the vector into several zero-padded columns
+        mat_token_ids = zeros(Int, max_tokens, num_chunks)
+        token_type_ids = zeros(Int, max_tokens, num_chunks)
+        attention_mask = zeros(Int, max_tokens, num_chunks)
+        @inbounds for (i, chunk) in enumerate(itr)
+            if add_special_tokens
+                mat_token_ids[1, i] = start_token
+                attention_mask[1, i] = 1
+            end
+            for ri in eachindex(chunk)
+                ## if we added the start token, shift all items down by one row
+                row_idx = add_special_tokens ? ri + 1 : ri
+                mat_token_ids[row_idx, i] = chunk[ri]
+                attention_mask[row_idx, i] = 1
+            end
+            if add_special_tokens
+                row_idx = 2 + length(chunk)
+                mat_token_ids[row_idx, i] = end_token
+                attention_mask[row_idx, i] = 1
+            end
+        end
+        token_ids = mat_token_ids
+    end
     return token_ids, token_type_ids, attention_mask
 end

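To make the chunk layout concrete, here is a self-contained sketch of the same scheme outside the encoder (illustrative values only; the real code takes the special-token IDs from `enc.vocab` and `max_tokens` from `enc.trunc`):

token_ids = collect(1:1200)        # stand-in for wordpiece IDs of a long text
max_tokens = 512                   # model context window
start_token, end_token = 101, 102  # typical BERT [CLS] / [SEP] IDs

chunk_size = max_tokens - 2        # reserve two slots for the special tokens
itr = Iterators.partition(token_ids, chunk_size)
num_chunks = length(itr)           # cld(1200, 510) == 3

mat_ids = zeros(Int, max_tokens, num_chunks)  # zero-padded columns
mask = zeros(Int, max_tokens, num_chunks)
for (i, chunk) in enumerate(itr)
    mat_ids[1, i] = start_token
    mat_ids[2:(1 + length(chunk)), i] .= chunk
    mat_ids[2 + length(chunk), i] = end_token
    mask[1:(2 + length(chunk)), i] .= 1       # attend only to real tokens
end
size(mat_ids)  # (512, 3); the last column holds 180 tokens plus padding
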
9 changes: 9 additions & 0 deletions test/embedding.jl
@@ -40,6 +40,15 @@ using FlashRank: EmbedderModel, embed, EmbedResult
     result = embedder(texts[1])
     @test result.embeddings isa AbstractArray{Float32}
     @test size(result.embeddings) == (312, 1)
+
+    ## Splitting - no effect on short texts
+    result = embed(embedder, texts[1]; split_instead_trunc = true)
+    @test size(result.embeddings) == (312, 1)
+
+    ## Splitting - a long text is split into several chunks
+    long_text = repeat("Hello, how are you? ", 200)
+    result = embed(embedder, long_text; split_instead_trunc = true)
+    @test size(result.embeddings) == (312, 3)
 end

 @testset "show-embedding" begin

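As a rough sanity check on the three chunks expected in the long-text test above (approximate, since the exact count depends on the wordpiece vocabulary):

tokens_per_repeat = 6            # roughly: "hello", ",", "how", "are", "you", "?"
total = 200 * tokens_per_repeat  # about 1200 tokens for the long text
chunk_size = 512 - 2             # context window minus [CLS]/[SEP]
cld(total, chunk_size)           # 3 chunks, hence embeddings of size (312, 3)
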
32 changes: 32 additions & 0 deletions test/encoder.jl
@@ -221,4 +221,36 @@ using FlashRank: RankerModel, tokenize, encode
         102 1012; 0 102]
     @test all(iszero, output[2])
     @test output[3] == [1 1; 1 1; 1 1; 1 1; 1 1; 1 1; 1 1; 1 1; 0 1]
+
+    ## Encode with splitting
+    long_text = repeat("Hello, how are you? ", 100)
+
+    output = encode(encoder, long_text; split_instead_trunc = true)
+    @test size(output[1]) == (512, 2)
+    ## check that the special tokens were added correctly
+    @test output[1][1, :] == [101, 101]
+    @test output[1][end, :] == [102, 0]
+    @test output[1][93, 2] == 102
+    @test output[1][94, 2] == 0
+    @test size(output[3]) == (512, 2)
+    @test output[3][1, :] == [1, 1]
+    @test output[3][end, :] == [1, 0]
+    @test output[3][93, :] == [1, 1]
+    @test output[3][94, :] == [1, 0]
+    @test iszero(output[2])
+
+    ## Do not add special tokens
+    output = encode(
+        encoder, long_text; split_instead_trunc = true, add_special_tokens = false)
+    @test size(output[1]) == (512, 2)
+    ## check that the tokens were laid out correctly
+    @test output[1][1, :] == [7592, 2129]
+    @test output[1][end, :] == [1010, 0]
+    @test output[1][93, 2] == 0
+    @test size(output[3]) == (512, 2)
+    @test output[3][1, :] == [1, 1]
+    @test output[3][end, :] == [1, 0]
+    @test output[3][89, :] == [1, 1]
+    @test output[3][90, :] == [1, 0]
+    @test iszero(output[2])
 end
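A note for reading the expected IDs above: assuming the standard bert-base-uncased vocabulary, 101 and 102 are the [CLS] and [SEP] special tokens, 7592 is "hello", 2129 is "how", and 1010 is ","; the trailing zeros in the shorter second column are padding, matched by zeros in the attention mask.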