NUDGE is a lightweight tool to fine-tune pre-trained embeddings for retrieval and RAG pipelines, presented in the paper NUDGE: Lightweight Non-Parametric Embedding Fine-Tuning (see this blog post for a simple overview). It runs in minutes and often improves retrieval accuracy by over 10%.
NUDGE modifies data embeddings non-parametrically, i.e., it does not change any model parameters but instead moves the data embeddings themselves to maximize accuracy. NUDGE solves a constrained optimization problem to do so, moving data embeddings towards the embedding of training queries for which they are the ground-truth answer. NUDGE-M and NUDGE-N are two variants of the approach, each solving the optimization problem with different constraints.
As the figure above shows, NUDGE changes data embeddings within a constrained region (shown in dashed lines) to maximize similarity with training queries. Data embeddings in the figure are colored based on the queries for which they are the ground-truth answers. Documentation is available along with the code in nudge/nudge.py. Below, we discuss how to install and use NUDGE.
To install NUDGE, run
pip install nudge-ft
NUDGE operates on embeddings: it fine-tunes data embeddings given training and validation queries. This package provides two classes, NUDGEM and NUDGEN, implementing NUDGE-M and NUDGE-N from the paper, respectively. Both have the same interface and can be imported as
from nudge import NUDGEN, NUDGEM
To use either class, you need to have already embedded the documents and training/validation queries and have ground-truth answers for training/validation queries. Then, call
train_set = {'q_embs':train_q_embs, 'q_ans_indx':train_q_ans_indx}
val_set = {'q_embs':val_q_embs, 'q_ans_indx':val_q_ans_indx}
finetuned_embs_nudge_n = NUDGEN().finetune_embeddings(data_embs, train_set, val_set)
finetuned_embs_nudge_m = NUDGEM().finetune_embeddings(data_embs, train_set, val_set)
where data_embs is a numpy array containing the data embeddings, train_q_embs and val_q_embs are numpy arrays containing the embeddings of training and validation queries, and train_q_ans_indx and val_q_ans_indx contain the ground-truth query answers. train_q_ans_indx and val_q_ans_indx are nested Python lists, where the i-th item is the list of indexes of data records relevant to the i-th query. That is, data_embs[train_q_ans_indx[i][j]] is a positive data record for the query train_q_embs[i].
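As a concrete illustration of these data structures, here is a toy setup with made-up embeddings (the dimensionality and values below are arbitrary, chosen only to show the indexing relationship):

```python
import numpy as np

# Toy corpus of 4 data records with 3-dimensional embeddings (values arbitrary)
data_embs = np.array([[0.1, 0.2, 0.3],
                      [0.9, 0.1, 0.0],
                      [0.0, 0.8, 0.2],
                      [0.5, 0.5, 0.1]])

# Two training queries
train_q_embs = np.array([[0.1, 0.3, 0.3],
                         [0.8, 0.2, 0.1]])

# Query 0 has records 0 and 2 as ground-truth answers; query 1 has record 1
train_q_ans_indx = [[0, 2], [1]]

# data_embs[train_q_ans_indx[i][j]] is a positive data record for train_q_embs[i]
positive_for_q0 = data_embs[train_q_ans_indx[0][0]]  # record 0

train_set = {'q_embs': train_q_embs, 'q_ans_indx': train_q_ans_indx}
```

A validation set would be built the same way from held-out queries.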
After running the code, you can use finetuned_embs_nudge_n or finetuned_embs_nudge_m for similarity search instead of data_embs.
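For instance, a cosine-similarity top-k search over the fine-tuned embeddings can be sketched in plain numpy. The kNNRetriever helper used later in this README provides this functionality; topk_cosine below is a hypothetical stand-in for illustration only:

```python
import numpy as np

def topk_cosine(data_embs, q_emb, k):
    """Return indices of the k data embeddings most similar to q_emb (cosine)."""
    data_norm = data_embs / np.linalg.norm(data_embs, axis=1, keepdims=True)
    q_norm = q_emb / np.linalg.norm(q_emb)
    sims = data_norm @ q_norm          # cosine similarity of each record to the query
    return np.argsort(-sims)[:k]       # indices of the k highest similarities

# Example: among 3 random embeddings, the query itself is its own nearest neighbor
rng = np.random.default_rng(0)
embs = rng.normal(size=(3, 8))
top = topk_cosine(embs, embs[2], k=1)  # -> [2]
```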
An end-to-end example of using NUDGE to fine-tune embeddings on nfcorpus is shown below. The code is also available in this notebook, or can alternatively be run from the root of the repo with
python example.py
after installing the dependencies below.
NUDGE does not embed the queries or data; it operates on the embeddings directly. Thus, we first need to embed the data and queries. Here we use BAAI/bge-small-en-v1.5 with the sentence_transformers library for embeddings, and use datasets to load the nfcorpus dataset.
Install the two libraries:
pip install sentence_transformers datasets
Load dataset and embed the data and queries:
from util.utils import load_hf_datasets, embed_data_and_query_sets
dataset_name = 'nfcorpus'
dataset, query_sets = load_hf_datasets(dataset_name)
data_emb, query_sets = embed_data_and_query_sets(dataset, query_sets, "BAAI/bge-small-en-v1.5")
Fine-tune embeddings (alternatively, use NUDGEM):
from nudge import NUDGEN
finetuned_embs_nudge_n = NUDGEN().finetune_embeddings(data_emb, query_sets['train'], query_sets['dev'])
Use fine-tuned embeddings to answer queries:
from util.knnretriever import kNNRetriever
nudge_n_res = kNNRetriever(finetuned_embs_nudge_n).retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])
Use non-fine-tuned embeddings to answer queries:
no_ft_res = kNNRetriever(data_emb).retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])
Compare accuracy:
from util.utils import calc_metrics_batch
metrics = [('recall',10), ('ndcg',10)]
no_ft_accs = calc_metrics_batch(metrics, no_ft_res, query_sets['test']['q_ans_indx'], query_sets['test']['q_ans_indx_rel'])
nudgen_accs = calc_metrics_batch(metrics, nudge_n_res, query_sets['test']['q_ans_indx'], query_sets['test']['q_ans_indx_rel'])
print(f"No Fine-Tuning {metrics[0][0]}@{metrics[0][1]}: {no_ft_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {no_ft_accs[1]*100:.1f}")
print(f"NUDGE-N {metrics[0][0]}@{metrics[0][1]}: {nudgen_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {nudgen_accs[1]*100:.1f}")
This gives the output:
No Fine-Tuning recall@10: 31.4, ndcg@10: 33.9
NUDGE-N recall@10: 43.7, ndcg@10: 44.5
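For reference, recall@k is the fraction of a query's ground-truth answers that appear among its top-k retrieved results, averaged over queries. A minimal sketch of the metric (an illustration only, not the calc_metrics_batch implementation):

```python
def recall_at_k(retrieved, ground_truth, k):
    """Average fraction of each query's ground-truth records found in its top-k results."""
    total = 0.0
    for topk, ans in zip(retrieved, ground_truth):
        hits = len(set(topk[:k]) & set(ans))
        total += hits / len(ans)
    return total / len(retrieved)

# Two queries: the first finds 1 of its 2 answers, the second finds its single answer
retrieved = [[3, 7, 1], [5, 2, 9]]
ground_truth = [[7, 4], [5]]
recall = recall_at_k(retrieved, ground_truth, k=3)  # (0.5 + 1.0) / 2 = 0.75
```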
More Datasets. More text datasets are hosted on Hugging Face here (the datasets were created using this file). The above code can be run with any of nfcorpus, scifact, arguana, fever, nq, triviaqa, and hotpotqa.
For the larger datasets (i.e., fever, nq, triviaqa, and hotpotqa), you may run out of memory if you run the above on a GPU. These corpora contain many documents that are never a positive training sample but are still loaded onto the GPU for fine-tuning. Instead, NUDGE allows for an optimization where data records that are not an answer to any of the training or validation queries are filtered out and accounted for separately. Such data records still impact fine-tuning, but only through their impact on validation accuracy. The following code
import numpy as np

max_nontest_index = -1
for split in ["train", "dev"]:
    max_nontest_index = max(np.array([indx for curr_q_ans_indx in query_sets[split]['q_ans_indx'] for indx in curr_q_ans_indx]).max()+1, max_nontest_index)
nontrain_dataset = dataset.loc[max_nontest_index:]
if nontrain_dataset.shape[0] == 0:
    embeddings = data_emb
    nontrain_embeddings = None
else:
    embeddings = data_emb[:max_nontest_index]
    nontrain_embeddings = data_emb[max_nontest_index:]
new_embs_nudgen = NUDGEN().finetune_embeddings(embeddings, query_sets['train'], query_sets['dev'], nontrain_embeddings)
nudge_n_res = kNNRetriever(new_embs_nudgen, nontrain_embeddings).retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])
gives the same result as
finetuned_embs_nudge_n = NUDGEN().finetune_embeddings(data_emb, query_sets['train'], query_sets['dev'])
nudge_n_res = kNNRetriever(finetuned_embs_nudge_n).retrieve_topk_from_emb_batch(k=10, q_embeds=query_sets['test']['q_embs'])
but uses less memory if many data records are not an answer to any training query. Complete code running nq
using the above optimization is available here.
To reproduce all baseline experiments in the paper (e.g., Tables 3-4), follow the instructions in the paper_exps branch of the repo.
Sepanta Zeighami, Zac Wellmer, and Aditya Parameswaran. "NUDGE: Lightweight Non-Parametric Fine-Tuning of Embeddings for Retrieval." arXiv preprint arXiv:2409.02343 (2024).
@article{zeighami2024nudge,
  title={NUDGE: Lightweight Non-Parametric Fine-Tuning of Embeddings for Retrieval},
  author={Zeighami, Sepanta and Wellmer, Zac and Parameswaran, Aditya},
  journal={arXiv preprint arXiv:2409.02343},
  year={2024}
}