This repository contains the source code for our paper "Unsupervised Multilingual Dense Retrieval via Generative Pseudo Labeling", which has been accepted to Findings of EACL 2024.
The code has the following requirements:
- Python >= 3.8
- transformers
- torch
Please install all required packages listed in requirements.txt
by running the following command:
pip install -r requirements.txt
We use the XOR-TyDi QA dataset in our experiments, which includes two tasks: XOR-Retrieve and XOR-Full. Please download the datasets from the official XOR-TyDi QA release and put them in the data directory.
In XOR-Retrieve, a question is written in the target language (e.g., Japanese) and the system is required to retrieve English documents that answer the question.
In XOR-Full, a question is written in the target language (e.g., Japanese) and the system is required to retrieve from multilingual documents and output a short answer in the target language.
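For reference, each line of a question file such as data/xor_train_retrieve_eng_span.jsonl is a standalone JSON object. Based on the official XOR-TyDi QA release, a line looks roughly like the following (values are illustrative placeholders):

{"id": "12345", "question": "...", "answers": ["..."], "lang": "ja"}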
We provide the trained checkpoints and processed data for the UMR model. You can download the files from Google Drive.
Training the UMR model consists of two steps: unsupervised multilingual reranking and knowledge-distilled retriever training. We trained one retriever for each task. The following commands are examples of training the UMR model on the XOR-Retrieve dataset; you can modify them to train on the XOR-Full dataset.
First, we generate context embeddings using the mContriever model.
python3 generate_dense_embeddings.py \
--pretrained_model_cfg facebook/mcontriever \
--encoder_model_type hf_bert \
--sequence_length 256 \
--batch_size 256 \
--ctx_file data/enwiki_20190201_w100.tsv \
--shard_id 0 --num_shards 1 \
--out_file data/enwiki_embeddings_iter0 \
--fp16
If using a trained checkpoint, e.g., in the second iteration, pass the checkpoint file via the --model_file argument. To parallelize embedding generation, run the command once per shard with the corresponding --shard_id and a matching --num_shards; the glob pattern passed to --encoded_ctx_file in the next step will pick up all shards.
Then, we retrieve top-k passages for each question using the mContriever model.
python3 dense_retriever.py \
--pretrained_model_cfg facebook/mcontriever \
--encoder_model_type hf_bert \
--sequence_length 256 \
--ctx_file data/enwiki_20190201_w100.tsv \
--qa_file data/xor_train_retrieve_eng_span.jsonl \
--encoded_ctx_file "data/enwiki_embeddings_iter0*" \
--out_file data/xor_retrieve_train_retrieved_iter0.json \
--n-docs 100 \
--validation_workers 1 --batch_size 128 --search_batch_size 512
If using a trained checkpoint, e.g., in the second iteration, pass the checkpoint file via the --model_file argument.

Next, we perform unsupervised multilingual reranking of the retrieved passages with a multilingual T5 language model:
python3 -m torch.distributed.launch --nproc_per_node {NGPUS} upr-multi.py \
--num-workers 2 \
--shard-size 2 \
--topk-passages 100 \
--hf-model-name "chaoweihuang/mt5-xl-lm-adapt" \
--use-gpu \
--use-fp16 \
--report-topk-accuracies 1 5 20 100 \
--retriever-topk-passages-path data/xor_retrieve_train_retrieved_iter0.json \
--reranker-output-dir data/xor_retrieve_train_retrieved_iter0_reranked
The reranked results will be saved to data/xor_retrieve_train_retrieved_iter0_reranked/rank{RANK}.json, where {RANK} is the rank of each distributed process. You will need to merge the results from different ranks to obtain the final reranked results, for example with the helper sketched below.
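A merge-and-split sketch, assuming each rank*.json file holds a JSON list of reranked examples (adjust the loading code if the actual format differs). This helper is illustrative and not part of the repository; save it, e.g., as merge_reranked.py:

import glob
import json
import os
import random
import sys

# Directory holding the per-process outputs, e.g.,
# data/xor_retrieve_train_retrieved_iter0_reranked
rerank_dir = sys.argv[1]
n_dev = int(sys.argv[2]) if len(sys.argv) > 2 else 500  # held-out dev size

# Merge rank0.json, rank1.json, ... into a single list.
merged = []
for path in sorted(glob.glob(os.path.join(rerank_dir, "rank*.json"))):
    with open(path) as f:
        merged.extend(json.load(f))

# Shuffle once with a fixed seed, then split off a small development set
# for --dev_file; the remainder becomes --train_file.
random.seed(12345)
random.shuffle(merged)
with open(os.path.join(rerank_dir, "dev.json"), "w") as f:
    json.dump(merged[:n_dev], f)
with open(os.path.join(rerank_dir, "train.json"), "w") as f:
    json.dump(merged[n_dev:], f)

Usage: python3 merge_reranked.py data/xor_retrieve_train_retrieved_iter0_reranked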
Once the reranked results are merged, we can train the knowledge-distilled retriever on them. The following command is an example of training the knowledge-distilled retriever on the XOR-Retrieve dataset. You may want to split the reranked results into training and development sets (as the merge_reranked.py sketch above does) and set the --train_file and --dev_file arguments accordingly.
CUDA_VISIBLE_DEVICES=${DEVICES} python3 train_dense_encoder_with_llm.py \
--max_grad_norm 2.0 \
--encoder_model_type hf_bert \
--pretrained_model_cfg facebook/mcontriever \
--seed 12345 \
--sequence_length 256 \
--warmup_steps 1237 \
--num_contexts 16 \
--batch_size 16 \
--gradient_accumulation_steps 1 \
--inbatch_negative \
--temperature 10 \
--train_file data/xor_retrieve_train_retrieved_iter0_reranked/rank0.json \
--dev_file {DEV_FILE} \
--output_dir {CHECKPOINT_DIR} \
--learning_rate 2e-05 \
--num_train_epochs 10 \
--dev_batch_size 12 \
--val_av_rank_start_epoch 0 \
--global_loss_buf_sz 2000000 \
--eval_per_epoch 4 \
--grad_cache \
--q_chunk_size 16 \
--ctx_chunk_size 8 \
--restart \
--fp16 \
--wandb_project {WANDB_PROJECT} \
--wandb_name {WANDB_NAME}
This completes one iteration of UMR training. To train the UMR model for more iterations, repeat the above steps, passing the newly trained checkpoint via --model_file and incrementing the iter{N} suffixes in the file names. A driver sketch is shown below.
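A minimal driver sketch for multiple iterations, assuming a single embedding shard, the file-naming convention used above, and the hypothetical merge_reranked.py helper from the merge sketch; checkpoints/iter{N} and the newest_checkpoint helper are likewise placeholders, since the checkpoint filenames depend on train_dense_encoder_with_llm.py (wandb logging flags are omitted for brevity):

import glob
import os
import shlex
import subprocess

def run(cmd: str) -> None:
    # Launch one of the repository's command-line entry points; fail fast on error.
    subprocess.run(shlex.split(cmd), check=True)

def newest_checkpoint(ckpt_dir: str) -> str:
    # Hypothetical helper: pick the most recently written file in ckpt_dir.
    # The actual checkpoint naming depends on train_dense_encoder_with_llm.py.
    return max(glob.glob(os.path.join(ckpt_dir, "*")), key=os.path.getmtime)

model_file = None  # iteration 0 starts from the pretrained facebook/mcontriever
for it in range(2):  # number of UMR iterations
    ckpt = f"--model_file {model_file} " if model_file else ""
    # Step 1: generate context embeddings (add shards to parallelize).
    run("python3 generate_dense_embeddings.py " + ckpt +
        "--pretrained_model_cfg facebook/mcontriever --encoder_model_type hf_bert "
        "--sequence_length 256 --batch_size 256 "
        "--ctx_file data/enwiki_20190201_w100.tsv --shard_id 0 --num_shards 1 "
        f"--out_file data/enwiki_embeddings_iter{it} --fp16")
    # Step 2: retrieve top-100 passages for the training questions.
    run("python3 dense_retriever.py " + ckpt +
        "--pretrained_model_cfg facebook/mcontriever --encoder_model_type hf_bert "
        "--sequence_length 256 --ctx_file data/enwiki_20190201_w100.tsv "
        "--qa_file data/xor_train_retrieve_eng_span.jsonl "
        f"--encoded_ctx_file data/enwiki_embeddings_iter{it}* "
        f"--out_file data/xor_retrieve_train_retrieved_iter{it}.json "
        "--n-docs 100 --validation_workers 1 --batch_size 128 --search_batch_size 512")
    # Step 3: unsupervised multilingual reranking (adjust --nproc_per_node).
    rerank_dir = f"data/xor_retrieve_train_retrieved_iter{it}_reranked"
    run("python3 -m torch.distributed.launch --nproc_per_node 8 upr-multi.py "
        "--num-workers 2 --shard-size 2 --topk-passages 100 "
        "--hf-model-name chaoweihuang/mt5-xl-lm-adapt --use-gpu --use-fp16 "
        "--report-topk-accuracies 1 5 20 100 "
        f"--retriever-topk-passages-path data/xor_retrieve_train_retrieved_iter{it}.json "
        f"--reranker-output-dir {rerank_dir}")
    # Merge the per-rank outputs into train.json/dev.json (hypothetical helper).
    run(f"python3 merge_reranked.py {rerank_dir}")
    # Step 4: knowledge-distilled retriever training.
    ckpt_dir = f"checkpoints/iter{it}"  # hypothetical output directory
    run("python3 train_dense_encoder_with_llm.py "
        "--max_grad_norm 2.0 --encoder_model_type hf_bert "
        "--pretrained_model_cfg facebook/mcontriever --seed 12345 "
        "--sequence_length 256 --warmup_steps 1237 --num_contexts 16 "
        "--batch_size 16 --gradient_accumulation_steps 1 --inbatch_negative "
        f"--temperature 10 --train_file {rerank_dir}/train.json "
        f"--dev_file {rerank_dir}/dev.json --output_dir {ckpt_dir} "
        "--learning_rate 2e-05 --num_train_epochs 10 --dev_batch_size 12 "
        "--val_av_rank_start_epoch 0 --global_loss_buf_sz 2000000 "
        "--eval_per_epoch 4 --grad_cache --q_chunk_size 16 --ctx_chunk_size 8 "
        "--restart --fp16")
    model_file = newest_checkpoint(ckpt_dir)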
We provide the commands for evaluating the UMR model on the XOR-Retrieve dataset. You can modify the commands to evaluate the UMR model on the XOR-Full dataset. Note that we provide the commands for evaluating on the development set as the test set is not publicly available.
python3 generate_dense_embeddings.py \
--model_file {CHECKPOINT_FILE} \
--encoder_model_type hf_bert \
--sequence_length 256 \
--batch_size 256 \
--ctx_file data/enwiki_20190201_w100.tsv \
--shard_id 0 --num_shards 1 \
--out_file data/enwiki_embeddings_iter1 \
--fp16
python3 dense_retriever.py \
--model_file {CHECKPOINT_FILE} \
--encoder_model_type hf_bert \
--sequence_length 256 \
--ctx_file data/enwiki_20190201_w100.tsv \
--qa_file data/xor_dev_retrieve_eng_span_v1_1.jsonl \
--encoded_ctx_file "data/enwiki_embeddings_iter1*" \
--out_file data/xor_retrieve_dev_retrieved_iter1.json \
--n-docs 100 \
--validation_workers 1 --batch_size 128 --search_batch_size 1024
python3 evals/eval_xor_retrieve.py \
--pred_file data/xor_retrieve_dev_retrieved_iter1.json \
--data_file data/xor_dev_retrieve_eng_span_v1_1.jsonl
Note that for evaluating on the XOR-Full dataset, since there is no ground truth for the retrieval task, we feed the retrieval results from UMR to the CORA reader (mGEN) and evaluate the end-to-end QA performance. Please refer to the CORA repository for how to run the mGEN model and evaluate the QA performance.
If you find our work useful, please cite the following paper:
@inproceedings{huang2024umr,
    title = "Unsupervised Multilingual Dense Retrieval via Generative Pseudo Labeling",
    author = "Huang, Chao-Wei and Hsu, Tsu-Yuan and Li, Chen-An and Hsu, Chen-Yu and Chen, Yun-Nung",
    booktitle = "Findings of the Association for Computational Linguistics: EACL 2024",
    month = mar,
    year = "2024",
    publisher = "Association for Computational Linguistics",
}