Commit 3455ecc: Split-Embedding of long texts

svilupp authored Jun 29, 2024
2 parents 17dc9cc + 27ddef9

Showing 6 changed files with 127 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,8 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.4.0]

### Added
- Simplified embedding with the tiny models and added native support for splitting long strings into several smaller chunks (kwarg `split_instead_trunc`)

## [0.3.0]

### Added
- Added a base Bert Tiny model to support lightning-fast embeddings (alias `tiny_embed`). See `?embed` for more details.


2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "FlashRank"
uuid = "22cc3f58-1757-4700-bb45-2032706e5a8d"
authors = ["J S <[email protected]> and contributors"]
version = "0.3.0"
version = "0.4.0"

[deps]
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
22 changes: 20 additions & 2 deletions src/embedding.jl
@@ -75,8 +75,26 @@ function embed(
EmbedResult(embeddings, t)
end

function embed(embedder::EmbedderModel, passages::AbstractString)
embed(embedder, [passages])
"""
embed(
embedder::EmbedderModel, passage::AbstractString; split_instead_trunc::Bool = false)
Embeds a single `passage`.
If the passage is too long for the model AND `split_instead_trunc` is true, the passage is split into several smaller chunks of size `embedder.encoder.trunc` and each chunk is embedded separately.
"""
function embed(
embedder::EmbedderModel, passage::AbstractString; split_instead_trunc::Bool = false)
t = @elapsed begin
token_ids, token_type_ids, attention_mask = encode(
embedder.encoder, passage; split_instead_trunc)
## transpose as the model expects row-major
onnx_input = Dict("input_ids" => token_ids', "attention_mask" => attention_mask')
out = embedder.session(onnx_input)
## Permute dimensions to return column-major embeddings, ie, embedding-size X batch-size
embeddings = out["avg_embeddings"] |> permutedims
end
EmbedResult(embeddings, t)
end

function (embedder::EmbedderModel)(
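For context, here is a hypothetical usage sketch of the new single-string method. The `EmbedderModel(:tiny_embed)` constructor call and the mean-pooling step are illustrative assumptions, not part of this diff:

using FlashRank: EmbedderModel, embed
using Statistics: mean

embedder = EmbedderModel(:tiny_embed)  # assumed loader for the tiny BERT embedder; see ?embed
long_text = repeat("Hello, how are you? ", 200)
result = embed(embedder, long_text; split_instead_trunc = true)
size(result.embeddings)  # (312, 3): one 312-dim column per chunk
## Optionally pool the per-chunk embeddings into a single document vector:
doc_vector = vec(mean(result.embeddings; dims = 2))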
68 changes: 59 additions & 9 deletions src/encoder.jl
@@ -36,17 +36,20 @@ end

"""
tokenize(enc::BertTextEncoder, text::AbstractString;
add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false)
add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false,
max_tokens::Union{Nothing, Int} = enc.trunc)
Tokenizes the text and returns the tokens or token IDs (returning IDs directly avoids a second vocabulary lookup).
# Arguments
- `add_special_tokens::Bool = true`: Add special tokens at the beginning and end of the text.
- `add_end_token::Bool = true`: Add end token at the end of the text.
- `token_ids::Bool = false`: If true, return the token IDs directly. Otherwise, return the tokens.
- `max_tokens::Union{Nothing, Int} = enc.trunc`: The maximum number of tokens to return (usually defined by the model).
"""
function tokenize(enc::BertTextEncoder, text::AbstractString;
add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false)
add_special_tokens::Bool = true, add_end_token::Bool = true, token_ids::Bool = false,
max_tokens::Union{Nothing, Int} = enc.trunc)
tokens = token_ids ? Int[] : String[]
if add_special_tokens
token = token_ids ? enc.vocab[enc.startsym] : enc.startsym
@@ -55,8 +58,8 @@ function tokenize(enc::BertTextEncoder, text::AbstractString;
for token in bert_uncased_tokenizer(text)
append!(tokens, enc.wp(token; token_ids))
end
if !isnothing(enc.trunc) && length(tokens) > (enc.trunc - 1)
tokens = tokens[1:(enc.trunc - 1)]
if !isnothing(max_tokens) && length(tokens) > (max_tokens - 1)
tokens = tokens[1:(max_tokens - 1)]
end
if add_special_tokens || add_end_token
token = token_ids ? enc.vocab[enc.endsym] : enc.endsym
@@ -65,11 +68,58 @@
return tokens
end

function encode(enc::BertTextEncoder, text::String; add_special_tokens::Bool = true)
token_ids = tokenize(enc, text; add_special_tokens, token_ids = true)
# Zero indexed as models are trained for Python
token_type_ids = zeros(Int, length(token_ids))
attention_mask = ones(Int, length(token_ids))
"""
encode(enc::BertTextEncoder, text::String; add_special_tokens::Bool = true,
max_tokens::Int = enc.trunc, split_instead_trunc::Bool = false)
Encodes the text and returns the token IDs, token type IDs, and attention mask.
We require `max_tokens` to be a concrete integer here (rather than `nothing`) so that `split_instead_trunc` can size its chunks.
`split_instead_trunc` splits any overly long sequence into several smaller ones instead of truncating it.
"""
function encode(enc::BertTextEncoder, text::String; add_special_tokens::Bool = true,
max_tokens::Int = enc.trunc, split_instead_trunc::Bool = false)
if !split_instead_trunc
## Standard run - if the text is too long, we truncate it and ignore the rest
token_ids = tokenize(enc, text; add_special_tokens, token_ids = true, max_tokens)
# Zero indexed as models are trained for Python
token_type_ids = zeros(Int, length(token_ids))
attention_mask = ones(Int, length(token_ids))
else
## Split run - if the text is too long, we split it into multiple chunks and encode them separately
## Only possible for a single string, so we know which chunks belong together
## tokenize without special tokens at first
token_ids = tokenize(enc, text; add_special_tokens = false,
token_ids = true, max_tokens = nothing)
## determine correct chunk size
start_token = enc.vocab[enc.startsym]
end_token = enc.vocab[enc.endsym]
chunk_size = max_tokens - 2 * add_special_tokens
itr = Iterators.partition(token_ids, chunk_size)
num_chunks = length(itr)
## split vector in several
mat_token_ids = zeros(Int, max_tokens, num_chunks)
token_type_ids = zeros(Int, max_tokens, num_chunks)
attention_mask = zeros(Int, max_tokens, num_chunks)
@inbounds for (i, chunk) in enumerate(itr)
if add_special_tokens
mat_token_ids[1, i] = start_token
attention_mask[1, i] = 1
end
for ri in eachindex(chunk)
## if special tokens are added, shift all items down by one row
row_idx = add_special_tokens ? ri + 1 : ri
mat_token_ids[row_idx, i] = chunk[ri]
attention_mask[row_idx, i] = 1
end
if add_special_tokens
row_idx = 2 + length(chunk)
mat_token_ids[row_idx, i] = end_token
attention_mask[row_idx, i] = 1
end
end
token_ids = mat_token_ids
end
return token_ids, token_type_ids, attention_mask
end

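The chunk-and-pad loop above can be reproduced standalone with plain integers. A minimal sketch, using the standard BERT [CLS]/[SEP] IDs 101 and 102 that also appear in the tests below:

using Base.Iterators: partition

token_ids = collect(1:10)          # stand-in for tokenize(...; max_tokens = nothing)
max_tokens = 6                     # model context window
start_token, end_token = 101, 102  # [CLS] and [SEP]
chunk_size = max_tokens - 2        # reserve two slots for the special tokens
chunks = partition(token_ids, chunk_size)
mat = zeros(Int, max_tokens, length(chunks))
mask = zeros(Int, max_tokens, length(chunks))
for (i, chunk) in enumerate(chunks)
    mat[1, i] = start_token
    mat[2:(1 + length(chunk)), i] .= chunk
    mat[2 + length(chunk), i] = end_token
    mask[1:(2 + length(chunk)), i] .= 1
end
## mat columns: [101,1,2,3,4,102], [101,5,6,7,8,102], [101,9,10,102,0,0]
## mask flags the real tokens; the short last column is zero-padded, as in encode.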
9 changes: 9 additions & 0 deletions test/embedding.jl
@@ -40,6 +40,15 @@ using FlashRank: EmbedderModel, embed, EmbedResult
result = embedder(texts[1])
@test result.embeddings isa AbstractArray{Float32}
@test size(result.embeddings) == (312, 1)

## Splitting - no effect on short texts
result = embed(embedder, texts[1]; split_instead_trunc = true)
@test size(result.embeddings) == (312, 1)

## Split into multiple chunks when the text is long
long_text = repeat("Hello, how are you? ", 200)
result = embed(embedder, long_text; split_instead_trunc = true)
@test size(result.embeddings) == (312, 3)
end
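(The expected chunk count follows from a rough token budget: assuming "Hello, how are you? " tokenizes to about 6 word pieces, 200 repetitions give roughly 1,200 tokens; at 512 - 2 = 510 usable slots per chunk, that is ceil(1200 / 510) = 3 chunks, hence the (312, 3) shape.)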

@testset "show-embedding" begin
32 changes: 32 additions & 0 deletions test/encoder.jl
@@ -221,4 +221,36 @@ using FlashRank: RankerModel, tokenize, encode
102 1012; 0 102]
@test all(iszero, output[2])
@test output[3] == [1 1; 1 1; 1 1; 1 1; 1 1; 1 1; 1 1; 1 1; 0 1]

## Encode with splitting
long_text = repeat("Hello, how are you? ", 100)

output = encode(encoder, long_text; split_instead_trunc = true)
@test size(output[1]) == (512, 2)
## check that tokens were added correctly
@test output[1][1, :] == [101, 101]
@test output[1][end, :] == [102, 0]
@test output[1][93, 2] == 102
@test output[1][94, 2] == 0
@test size(output[3]) == (512, 2)
@test output[3][1, :] == [1, 1]
@test output[3][end, :] == [1, 0]
@test output[3][93, :] == [1, 1]
@test output[3][94, :] == [1, 0]
@test iszero(output[2])

## Do not add special tokens
output = encode(
encoder, long_text; split_instead_trunc = true, add_special_tokens = false)
@test size(output[1]) == (512, 2)
## check that tokens were added correctly
@test output[1][1, :] == [7592, 2129]
@test output[1][end, :] == [1010, 0]
@test output[1][93, 2] == 0
@test size(output[3]) == (512, 2)
@test output[3][1, :] == [1, 1]
@test output[3][end, :] == [1, 0]
@test output[3][89, :] == [1, 1]
@test output[3][90, :] == [1, 0]
@test iszero(output[2])
end
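Note for reading these tests: `output` unpacks as `(token_ids, token_type_ids, attention_mask)`, matching the return order of `encode` above; `output[2]` stays all-zero because the models use Python-style zero-indexed token type IDs.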

2 comments on commit 3455ecc

@svilupp (Owner, Author) commented:

@JuliaRegistrator register

Release notes:

Added

  • Simplified embedding with the tiny models and added native support for splitting long strings into several smaller chunks (kwarg split_instead_trunc)


@JuliaRegistrator commented:

Registration pull request created: JuliaRegistries/General/110090

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 3455ecc588bd509cc83cee9e7030a4534b283b98
git push origin v0.4.0
