Add more MiniLM rankers
svilupp authored Jun 11, 2024
2 parents 829d207 + c3d10ee commit a401900
Showing 7 changed files with 115 additions and 32 deletions.
21 changes: 21 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,21 @@
# Changelog
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

### Fixed

## [0.2.0]

### Added
- Added Sentence Transformers MiniLM L-4 and MiniLM L-6 models with full precision to provide more choice between TinyBERT and MiniLM L-12

## [0.1.0]

### Added
- Initial release
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "FlashRank"
uuid = "22cc3f58-1757-4700-bb45-2032706e5a8d"
authors = ["J S <[email protected]> and contributors"]
version = "0.1.0"
version = "0.2.0"

[deps]
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
21 changes: 13 additions & 8 deletions README.md
@@ -2,16 +2,22 @@

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://svilupp.github.io/FlashRank.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://svilupp.github.io/FlashRank.jl/dev/) [![Build Status](https://github.com/svilupp/FlashRank.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/svilupp/FlashRank.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/svilupp/FlashRank.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/svilupp/FlashRank.jl) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

FlashRank.jl is inspired by the awesome [FlashRank Python package](https://github.com/PrithivirajDamodaran/FlashRank), originally developed by Prithiviraj Damodaran. This package leverages model weights from [Prithiviraj's repository on Hugging Face](https://huggingface.co/prithivida/flashrank) and provides a fast and efficient way to rank documents relevant to any given query without GPUs and large dependencies. This enhances Retrieval Augmented Generation (RAG) pipelines by prioritizing the most suitable documents. The smallest model can be run on almost any machine.
FlashRank.jl is inspired by the awesome [FlashRank Python package](https://github.com/PrithivirajDamodaran/FlashRank), originally developed by Prithiviraj Damodaran. This package leverages model weights from [Prithiviraj's HF repo](https://huggingface.co/prithivida/flashrank) and [Svilupp's HF repo](https://huggingface.co/svilupp/onnx-cross-encoders) to provide **a fast and efficient way to rank documents relevant to any given query without GPUs and large dependencies**.

This enhances Retrieval Augmented Generation (RAG) pipelines by prioritizing the most suitable documents. The smallest model can be run on almost any machine.

## Features
- Two ranking models:
- **Tiny (~4MB):** [ms-marco-TinyBERT-L-2-v2 (default)](https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2) (alias `:tiny`)
- **Mini (~23MB):** [ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) (alias `:mini`)
- Four ranking models:
- **Tiny (~4MB, INT8):** [ms-marco-TinyBERT-L-2-v2 (default)](https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2) (alias `:tiny`)
- **MiniLM L-4 (~70MB, FP32):** [ms-marco-MiniLM-L-4-v2 ONNX](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-4-v2) (alias `:mini4`)
- **MiniLM L-6 (~83.4MB, FP32):** [ms-marco-MiniLM-L-6-v2 ONNX](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2) (alias `:mini6`)
- **MiniLM L-12 (~23MB, INT8):** [ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) (alias `:mini` or `:mini12`)
- Lightweight dependencies, avoiding heavy frameworks like Flux and CUDA for ease of integration.

How fast is it?
With the Tiny model, you can rank 100 documents in ~0.1 seconds on a laptop. With the Mini model, you can rank 20 documents in ~0.5 seconds to pick the best chunks for your context.
With the Tiny model, you can rank 100 documents in ~0.1 seconds on a laptop. With the MiniLM (12 layers) model, you can rank 100 documents in ~0.4 seconds.

Tip: Pick the largest model that fits within your latency budget, i.e., MiniLM L-12 is the slowest but has the best accuracy.

Note that we're using BERT models with a maximum chunk size of 512 tokens (anything over will be truncated).
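
As a quick illustration, here is a minimal sketch of exercising the new aliases through the internal loader that the updated tests use; the names and return types come from `test/loader.jl` in this diff, and the higher-level ranking API is not shown here.

```julia
using FlashRank: load_model, BertTextEncoder, ORT

# Any of the aliases above can be passed; the FP32 MiniLM models are
# downloaded and unpacked on first use. `load_model` returns the tokenizer
# encoder and the ONNX inference session, as asserted in test/loader.jl.
enc, sess = load_model(:mini6)
enc isa BertTextEncoder        # true
sess isa ORT.InferenceSession  # true
```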

@@ -80,9 +86,8 @@ result = airag(cfg, index; question, return_all = true)

## Acknowledgments
- [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) and [Transformers.jl](https://github.com/chengchingwen/Transformers.jl) have been essential in the development of this package.
- Special thanks to Prithiviraj Damodaran for the original FlashRank and model weights.
- Special thanks to Prithiviraj Damodaran for the original FlashRank and the INT8 quantized model weights.
- And to Transformers.jl for the WordPiece implementation and BERT tokenizer which have been forked for this package (to minimize dependencies).

## Roadmap
- [ ] Extend support for more models
- [ ] Provide package extension for PromptingTools
12 changes: 8 additions & 4 deletions docs/src/index.md
@@ -7,13 +7,17 @@ CurrentModule = FlashRank
FlashRank.jl is inspired by the awesome [FlashRank Python package](https://github.com/PrithivirajDamodaran/FlashRank), originally developed by Prithiviraj Damodaran. This package leverages model weights from [Prithiviraj's repository on Hugging Face](https://huggingface.co/prithivida/flashrank) and provides a fast and efficient way to rank documents relevant to any given query without GPUs and large dependencies. This enhances Retrieval Augmented Generation (RAG) pipelines by prioritizing the most suitable documents. The smallest model can be run on almost any machine.

## Features
- Two ranking models:
- **Tiny (~4MB):** [ms-marco-TinyBERT-L-2-v2 (default)](https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2) (alias `:tiny`)
- **Mini (~23MB):** [ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) (alias `:mini`)
- Four ranking models:
- **Tiny (~4MB, INT8):** [ms-marco-TinyBERT-L-2-v2 (default)](https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2) (alias `:tiny`)
- **MiniLM L-4 (~70MB, FP32):** [ms-marco-MiniLM-L-4-v2 ONNX](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-4-v2) (alias `:mini4`)
- **MiniLM L-6 (~83.4MB, FP32):** [ms-marco-MiniLM-L-6-v2 ONNX](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2) (alias `:mini6`)
- **MiniLM L-12 (~23MB, INT8):** [ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) (alias `:mini` or `:mini12`)
- Lightweight dependencies, avoiding heavy frameworks like Flux and CUDA for ease of integration.

How fast is it?
With the Tiny model, you can rank 100 documents in ~0.1 seconds on a laptop. With the Mini model, you can rank 20 documents in ~0.5 seconds to pick the best chunks for your context.
With the Tiny model, you can rank 100 documents in ~0.1 seconds on a laptop. With the MiniLM (12 layers) model, you can rank 100 documents in ~0.4 seconds.

Tip: Pick the largest model that fits within your latency budget, i.e., MiniLM L-12 is the slowest but has the best accuracy.

Note that we're using BERT models with a maximum chunk size of 512 tokens (anything over will be truncated).

14 changes: 14 additions & 0 deletions src/FlashRank.jl
@@ -36,6 +36,20 @@ function __init__()
"https://huggingface.co/prithivida/flashrank/resolve/main/ms-marco-MiniLM-L-12-v2.zip";
post_fetch_method = unpack
))
register(DataDep("ms-marco-MiniLM-L-4-v2",
"""
MiniLM-L-4-v2 cross-encoder trained on the ms-marco dataset, FP32 precision.
""",
"https://huggingface.co/svilupp/onnx-cross-encoders/resolve/main/ms-marco-MiniLM-L-4-v2-onnx.zip";
post_fetch_method = unpack
))
register(DataDep("ms-marco-MiniLM-L-6-v2",
"""
MiniLM-L-6-v2 cross-encoder trained on the ms-marco dataset, FP32 precision.
""",
"https://huggingface.co/svilupp/onnx-cross-encoders/resolve/main/ms-marco-MiniLM-L-6-v2-onnx.zip";
post_fetch_method = unpack
))
end

end
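
As a side note, here is a small sketch of how these registrations behave at runtime, assuming standard DataDeps.jl semantics; the `ENV` flag and `readdir` call are illustrative and not part of this diff.

```julia
using FlashRank   # __init__() registers the DataDeps shown above
using DataDeps

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"  # skip the interactive download prompt

# The first resolution downloads the zip from Hugging Face and unpacks it
# (post_fetch_method = unpack); subsequent calls return the local directory.
model_root = datadep"ms-marco-MiniLM-L-4-v2"
readdir(model_root)   # the unpacked ONNX model and tokenizer files
```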
70 changes: 52 additions & 18 deletions src/loader.jl
@@ -15,31 +15,65 @@ end
function load_model(alias::Symbol)
model_root = if alias == :tiny
datadep"ms-marco-TinyBERT-L-2-v2"
elseif alias == :mini
elseif alias == :mini12 || alias == :mini
datadep"ms-marco-MiniLM-L-12-v2"
elseif alias == :mini6
datadep"ms-marco-MiniLM-L-6-v2"
elseif alias == :mini4
datadep"ms-marco-MiniLM-L-4-v2"
else
throw(ArgumentError("Invalid model type"))
end

# Tokenizer setup
tok_path = find_file(model_root, "tokenizer.json")
@assert !isnothing(tok_path) "Could not find tokenizer.json in $model_root"
tok_config = JSON3.read(tok_path)
model_config = tok_config[:model]
vocab_list = reverse_keymap_to_list(model_config[:vocab])
extract_and_add_tokens!(tok_config[:added_tokens], vocab_list)
## 0-based indexing as we provide it to models trained in python
vocab = Dict(k => i - 1 for (i, k) in enumerate(vocab_list))

wp = WordPiece(vocab_list, model_config[:unk_token];
max_char = model_config[:max_input_chars_per_word],
subword_prefix = model_config[:continuing_subword_prefix])

## We always assume lowercasing with our current tokenizer implementation
@assert get(tok_config[:normalizer], :lowercase, true) "Tokenizer must be lowercased. Model implementation is not compatible."
## We assume truncation of 512 if not provided
trunc = get(tok_config, :truncation, nothing) |> x -> isnothing(x) ? 512 : x
enc = BertTextEncoder(wp, vocab; trunc)
tok_config_path = find_file(model_root, "tokenizer_config.json")
vocab_path = find_file(model_root, "vocab.txt")
if !isnothing(tok_path)
## Load from tokenizer.json
tok_config = JSON3.read(tok_path)
model_config = tok_config[:model]
vocab_list = reverse_keymap_to_list(model_config[:vocab])
extract_and_add_tokens!(tok_config[:added_tokens], vocab_list)
## 0-based indexing as we provide it to models trained in python
vocab = Dict(k => i - 1 for (i, k) in enumerate(vocab_list))

wp = WordPiece(vocab_list, model_config[:unk_token];
max_char = model_config[:max_input_chars_per_word],
subword_prefix = model_config[:continuing_subword_prefix])

## We always assume lowercasing with our current tokenizer implementation
@assert get(tok_config[:normalizer], :lowercase, true) "Tokenizer must be lowercased. Model implementation is not compatible."
## We assume truncation of 512 if not provided
trunc = get(tok_config, :truncation, nothing) |> x -> isnothing(x) ? 512 : x
enc = BertTextEncoder(wp, vocab; trunc)
elseif !isnothing(tok_config_path) && !isnothing(vocab_path)
## Load from tokenizer_config.json
tok_config = JSON3.read(tok_config_path)
vocab_list = readlines(vocab_path)

## Double check that all tokens are in vocab
@assert all(
sym -> in(tok_config[sym], vocab_list), [
:unk_token, :cls_token, :sep_token, :pad_token])

vocab = Dict(k => i - 1 for (i, k) in enumerate(vocab_list))

wp = WordPiece(vocab_list, tok_config[:unk_token];
max_char = get(tok_config, :max_input_chars_per_word, 200),
subword_prefix = get(tok_config, :continuing_subword_prefix, "##"))

## We always assume lowercasing with our current tokenizer implementation
@assert get(tok_config, :do_lower_case, true) "Tokenizer must be lowercased. Model implementation is not compatible."
## We assume truncation of 512 if not provided
trunc = get(tok_config, :model_max_length, nothing) |> x -> isnothing(x) ? 512 : x
enc = BertTextEncoder(wp, vocab; trunc,
startsym = tok_config[:cls_token],
endsym = tok_config[:sep_token],
padsym = tok_config[:pad_token])
else
throw(ArgumentError("Could not find tokenizer.json or tokenizer_config.json + vocab.txt in $model_root"))
end

## Double-check that padding is ID 0, because we pad with 0s in encode() function
@assert enc.vocab[enc.padsym]==0 "Padding token must be first token in vocabulary with token ID 0."
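
For context, the new fallback branch only needs a handful of fields from `tokenizer_config.json`; below is an illustrative config covering the required keys read above (the token strings are the usual BERT specials and are assumptions, not taken from this diff; optional keys such as `max_input_chars_per_word` fall back to defaults).

```julia
using JSON3

cfg = JSON3.read("""
{
  "unk_token": "[UNK]",
  "cls_token": "[CLS]",
  "sep_token": "[SEP]",
  "pad_token": "[PAD]",
  "do_lower_case": true,
  "model_max_length": 512
}
""")

cfg[:do_lower_case]     # must be true for the current tokenizer implementation
cfg[:model_max_length]  # used as the truncation length (512 if absent)
```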
7 changes: 6 additions & 1 deletion test/loader.jl
@@ -5,4 +5,9 @@ using FlashRank: load_model, BertTextEncoder, ORT
@test enc isa BertTextEncoder
@test sess isa ORT.InferenceSession
@test_throws ArgumentError load_model(:notexistent)
end

## Load different pipelines
enc, sess = load_model(:mini4)
@test enc isa BertTextEncoder
@test sess isa ORT.InferenceSession
end

2 comments on commit a401900

@svilupp
Owner Author


@JuliaRegistrator register

Release notes:

Added

  • Added Sentence Transformers MiniLM L-4 and MiniLM L-6 models with full precision (in ONNX) to provide more choice between TinyBERT and MiniLM L-12

Commits

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/108703

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.2.0 -m "<description of version>" a401900610208adac64186a29e84846344ca4e40
git push origin v0.2.0
