This repo contains code for the paper Multilingual k-Nearest-Neighbor Machine Translation (EMNLP 2023).
Our code builds on the knn-box toolkit. To install the toolkit and our extensions, run:
conda create -n multilingual-knn python=3.8
conda activate multilingual-knn
git clone git@github.com:davidstap/multilingual-kNN-mt.git
cd multilingual-kNN-mt
pip install --editable ./
# Installing faiss with pip is not recommended, so we use conda
conda install faiss-gpu -c pytorch
To prepare the pretrained model and dataset, run:
cd knnbox-scripts
bash prepare_dataset_and_model.sh
To use a different model or dataset, edit this script.
After preparation, the next step is to build a datastore. The following script, run from the repository root, builds a Hebrew (he) to English (en) datastore using a multilingual model that supports only Hebrew and Arabic, into and out of English. The resulting datastore is saved in the data-knnds-he_en folder under $DATA_PATH. To create datastores for other language pairs, change the SRC and TGT variables.
SRC=he
TGT=en
DATA_PATH=/path/to/preprocessed/dataset
MODEL_PATH=/path/to/preprocessed/model
CUDA_VISIBLE_DEVICES=0 python knnbox-scripts/common/validate.py $DATA_PATH \
--task translation_multi_simple_epoch \
--langs en,he \
--lang-pairs en-ar,en-he,ar-en,he-en \
--seed 222 \
--source-lang $SRC \
--target-lang $TGT \
--path $MODEL_PATH/best.pt \
--model-overrides "{'eval_bleu': False, 'required_seq_len_multiple':1, 'load_alignments': False}" \
--dataset-impl mmap \
--valid-subset valid \
--skip-invalid-size-inputs-valid-test \
--max-tokens 4096 \
--criterion label_smoothed_cross_entropy \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--user-dir knnbox/models \
--arch vanilla_knn_mt@transformer \
--knn-mode build_datastore \
--knn-datastore-path $DATA_PATH/data-knnds-${SRC}_${TGT} \
--share-decoder-input-output-embed \
--decoder-langtok \
--encoder-langtok src \
--fp16 \
--dropout 0.1 \
--label-smoothing 0.1
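For orientation, a kNN-MT datastore pairs a context representation (the key) with the target token that follows it (the value), one entry per target token. The toy sketch below illustrates that layout only. `build_datastore` and `toy_state_fn` are hypothetical names, and the two-number keys are stand-ins for the hidden states that the real model produces.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Tuple[float, ...]
StateFn = Callable[[Sequence[str], Sequence[str]], Vector]


def build_datastore(
    parallel_data: Sequence[Tuple[Sequence[str], Sequence[str]]],
    state_fn: StateFn,
) -> Tuple[List[Vector], List[str]]:
    """Collect one (key, value) entry per target token.

    The key is the representation of (source, target prefix); the value is
    the gold token that comes next, as in teacher-forced decoding.
    """
    keys: List[Vector] = []
    values: List[str] = []
    for source, target in parallel_data:
        for position, token in enumerate(target):
            keys.append(state_fn(source, target[:position]))
            values.append(token)
    return keys, values


def toy_state_fn(source: Sequence[str], prefix: Sequence[str]) -> Vector:
    # Hypothetical stand-in for a decoder hidden state (assumption, not the real model).
    return (float(len(source)), float(len(prefix)))


data = [(["shalom", "olam"], ["hello", "world", "</s>"])]
keys, values = build_datastore(data, toy_state_fn)
print(len(keys), values)
```

The datastore therefore grows with the number of target tokens, not the number of sentences.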
Bilingual datastores can be combined into a multilingual datastore, which leads to better performance. The following script combines the he-en and ar-en datastores:
DATA_PATH=/path/to/preprocessed/dataset
python knnbox-scripts/common/combine_datastores.py \
--path $DATA_PATH \
--pairs ar_en he_en \
--save_path $DATA_PATH/data-knnds-ar_en-he_en
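Conceptually, a combined datastore can be thought of as the concatenation of the bilingual key/value entries. The sketch below shows that idea on in-memory lists. `concat_datastores` is a hypothetical helper, and whether `combine_datastores.py` works exactly this way (for example, how it handles files on disk or index building) is an assumption.

```python
from typing import Dict, List, Tuple

Vector = Tuple[float, ...]
Datastore = Tuple[List[Vector], List[str]]  # (keys, values)


def concat_datastores(datastores: Dict[str, Datastore]) -> Datastore:
    """Concatenate bilingual datastores into one multilingual datastore."""
    keys: List[Vector] = []
    values: List[str] = []
    # Iterate in sorted order so the result does not depend on dict ordering.
    for pair in sorted(datastores):
        pair_keys, pair_values = datastores[pair]
        if len(pair_keys) != len(pair_values):
            raise ValueError(f"{pair}: keys and values differ in length")
        keys.extend(pair_keys)
        values.extend(pair_values)
    return keys, values


he_en: Datastore = ([(0.0, 1.0)], ["hello"])
ar_en: Datastore = ([(1.0, 0.0), (1.0, 1.0)], ["peace", "</s>"])
combined = concat_datastores({"he_en": he_en, "ar_en": ar_en})
print(len(combined[0]), combined[1])
```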
Use the following script to run kNN inference and compute BLEU scores:
SRC=he
TGT=en
KNN_K=16
KNN_LAMBDA=0.5
KNN_TEMP=100
DATA_PATH=/path/to/preprocessed/dataset
MODEL_PATH=/path/to/preprocessed/model
DS=${SRC}_${TGT}  # datastore label used in RESULTS_PATH; matches --knn-datastore-path below
RESULTS_PATH=${MODEL_PATH}/generations/gen-test-${SRC}_${TGT}-${DS}-k_${KNN_K}-l_${KNN_LAMBDA}-t_${KNN_TEMP}
CUDA_VISIBLE_DEVICES=0 python knnbox-scripts/common/generate.py $DATA_PATH \
--task translation_multi_simple_epoch \
--langs en,he \
--lang-pairs en-ar,en-he,ar-en,he-en \
--source-lang $SRC \
--target-lang $TGT \
--remove-bpe 'sentencepiece' \
--path $MODEL_PATH/best.pt \
--share-decoder-input-output-embed \
--results-path $RESULTS_PATH \
--dataset-impl mmap \
--beam 5 \
--gen-subset test \
--max-tokens 1024 \
--user-dir knnbox/models \
--arch vanilla_knn_mt@transformer \
--knn-mode inference \
--knn-datastore-path $DATA_PATH/data-knnds-${SRC}_${TGT} \
--knn-k $KNN_K \
--knn-lambda $KNN_LAMBDA \
--knn-temperature $KNN_TEMP \
--model-overrides "{'eval_bleu': False, 'required_seq_len_multiple':1, 'load_alignments': False}" \
--decoder-langtok \
--encoder-langtok src \
--skip-invalid-size-inputs-valid-test \
--sampling-method temperature \
--label-smoothing 0.1 \
--criterion label_smoothed_cross_entropy \
--fp16 \
--seed 222
grep ^T $RESULTS_PATH/generate-test.txt | LC_ALL=C sort -V | cut -f2- > $RESULTS_PATH/ref.txt
grep ^D $RESULTS_PATH/generate-test.txt | LC_ALL=C sort -V | cut -f3- > $RESULTS_PATH/hyp.txt
grep ^S $RESULTS_PATH/generate-test.txt | LC_ALL=C sort -V | cut -f2- > $RESULTS_PATH/src.txt
echo knn inference result for $SRC-$TGT K=$KNN_K L=$KNN_LAMBDA T=$KNN_TEMP saved at $RESULTS_PATH
sacrebleu $RESULTS_PATH/ref.txt -i $RESULTS_PATH/hyp.txt -m bleu | grep '"score"' | grep -oE '[0-9]+(\.[0-9]+)?' > $RESULTS_PATH/bleu.txt
cat $RESULTS_PATH/bleu.txt
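The grep, sort, and cut steps above can also be sketched in Python, which may help when post-processing outputs programmatically. This assumes the usual fairseq output layout of tab-separated lines tagged `S-<id>`, `T-<id>`, and `D-<id>` (with a score column before the `D` text); `extract_field` is a hypothetical helper, not part of this repository.

```python
from typing import Iterable, List, Tuple


def extract_field(lines: Iterable[str], prefix: str, first_column: int) -> List[str]:
    """Return the text of lines tagged '<prefix>-<id>', ordered by numeric id.

    first_column is the 0-based index of the first tab-separated field to keep,
    so `cut -f2-` corresponds to 1 and `cut -f3-` corresponds to 2.
    """
    rows: List[Tuple[int, str]] = []
    for line in lines:
        fields = line.rstrip("\n").split("\t")
        tag, _, index = fields[0].partition("-")
        if tag == prefix and index.isdigit():
            rows.append((int(index), "\t".join(fields[first_column:])))
    # Sorting by integer id mirrors what `sort -V` achieves in the shell version.
    return [text for _, text in sorted(rows)]


sample = [
    "S-10\tsource ten",
    "T-10\treference ten",
    "D-10\t-0.30\thypothesis ten",
    "S-2\tsource two",
    "T-2\treference two",
    "D-2\t-0.10\thypothesis two",
]
references = extract_field(sample, "T", 1)
hypotheses = extract_field(sample, "D", 2)
print(references, hypotheses)
```

Sorting references and hypotheses by the same id keeps them aligned line by line, which sacrebleu requires.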
KNN_K is the number of neighbors; good values for the TED dataset are 8, 16, 32, 64.
KNN_LAMBDA is the weight of the kNN distribution; good values are {0.2, 0.3, ..., 0.7}.
KNN_TEMP is the temperature that smooths the kNN distribution; good values are 10 and 100.
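To make the role of the three parameters concrete, here is a minimal sketch of the standard kNN-MT formulation: retrieve the K nearest entries, turn their distances into a distribution with a temperature-scaled softmax, and interpolate with the model distribution using lambda. The function names are hypothetical, brute-force search replaces the faiss index, and knn-box's actual implementation may differ in detail.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Tuple[float, ...]


def retrieve(query: Vector, keys: Sequence[Vector], values: Sequence[str],
             k: int) -> List[Tuple[float, str]]:
    """Return the k nearest (squared L2 distance, token) pairs (KNN_K)."""
    scored = [
        (sum((q - x) ** 2 for q, x in zip(query, key)), token)
        for key, token in zip(keys, values)
    ]
    return sorted(scored)[:k]


def knn_distribution(neighbors: Sequence[Tuple[float, str]],
                     temperature: float) -> Dict[str, float]:
    """Softmax over negative distances; a higher KNN_TEMP gives a flatter result."""
    logits = [-distance / temperature for distance, _ in neighbors]
    shift = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(logit - shift) for logit in logits]
    total = sum(weights)
    probs: Dict[str, float] = {}
    for (_, token), weight in zip(neighbors, weights):
        probs[token] = probs.get(token, 0.0) + weight / total
    return probs


def interpolate(p_mt: Dict[str, float], p_knn: Dict[str, float],
                lam: float) -> Dict[str, float]:
    """p(y) = lam * p_knn(y) + (1 - lam) * p_mt(y) (lam is KNN_LAMBDA)."""
    vocab = set(p_mt) | set(p_knn)
    return {t: lam * p_knn.get(t, 0.0) + (1.0 - lam) * p_mt.get(t, 0.0) for t in vocab}


keys = [(0.0, 0.0), (1.0, 0.0), (5.0, 5.0)]
values = ["cat", "dog", "fish"]
neighbors = retrieve((0.0, 0.0), keys, values, k=2)
p_knn = knn_distribution(neighbors, temperature=10.0)
print(interpolate({"cat": 0.5, "dog": 0.5}, p_knn, lam=0.5))
```

With lambda set to 0 the output reduces to the model distribution, and with lambda set to 1 it relies on the retrieved neighbors alone.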
Instead of fixing these values, you can run a hyperparameter search on the validation set and decode the test set with the best parameters.
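Such a search can be sketched as a plain grid search over the value ranges listed above. `grid_search` and `score_fn` are hypothetical: in practice `score_fn` would run the inference script with `--gen-subset valid` for the given setting and return the resulting BLEU score, while the toy function below only stands in for that step.

```python
import itertools
from typing import Callable, Optional, Sequence, Tuple

ScoreFn = Callable[[int, float, float], float]


def grid_search(
    score_fn: ScoreFn,
    ks: Sequence[int] = (8, 16, 32, 64),
    lambdas: Sequence[float] = (0.2, 0.3, 0.4, 0.5, 0.6, 0.7),
    temperatures: Sequence[float] = (10, 100),
) -> Tuple[float, int, float, float]:
    """Return (best_score, k, lambda, temperature) over the full grid."""
    best: Optional[Tuple[float, int, float, float]] = None
    for k, lam, temp in itertools.product(ks, lambdas, temperatures):
        score = score_fn(k, lam, temp)
        # Keep the first setting that reaches the highest validation score.
        if best is None or score > best[0]:
            best = (score, k, lam, temp)
    if best is None:
        raise ValueError("the search grid is empty")
    return best


def toy_score(k: int, lam: float, temp: float) -> float:
    # Placeholder for validation BLEU (assumption); peaks at k=16, lam=0.5, temp=100.
    return -abs(k - 16) - abs(lam - 0.5) - abs(temp - 100) / 100


best_score, best_k, best_lambda, best_temp = grid_search(toy_score)
print(best_k, best_lambda, best_temp)
```

The default grid has 4 x 6 x 2 = 48 settings, so each search costs 48 decoding runs on the validation set.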
If you use our code, please cite the following:
@inproceedings{stap2023multilingualknn,
title = {Multilingual k-Nearest-Neighbor Machine Translation},
author = {David Stap and Christof Monz},
booktitle = {Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing},
year = {2023},
publisher = {Association for Computational Linguistics}
}