Xinghui Li1, Jingyi Lu2, Kai Han2, Victor Prisacariu1
1Active Vision Lab, University of Oxford 2Visual AI Lab, University of Hong Kong
The environment can be easily installed through conda and pip. After downloading the code, run the following command:
$conda create -n sd4match python=3.10
$conda activate sd4match
$conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
$conda install xformers -c xformers
$pip install yacs pandas scipy einops matplotlib triton timm diffusers accelerate transformers datasets tensorboard pykeops scikit-learn
- Download PF-Pascal dataset from link.
- Rename the outermost directory from
PF-dataset-PASCAL
topf-pascal
. - Download lists for image pairs from link.
- Place the lists for image pairs under
pf-pascal
directory. The structure should be:
pf-pascal
├── __MACOSX
├── PF-dataset-PASCAL
├── trn_pairs.csv
├── val_pairs.csv
└── test_pairs.csv
- Download PF-Willow dataset from the link.
- Rename the outermost directory from
PF-dataset
topf-willow
. - Download lists for image pairs from link.
- Place the lists for image pairs under
pf-willow
directory. The structure should be:
pf-willow
├── __MACOSX
├── PF-dataset
└── test_pairs.csv
- Download SPair-71K dataset from link. After extraction, No more action required.
- Create symbol links to PF-Pascal, PF-Willow and SPair-71k dataset in
asset
directory. This can be done by:
ln -s /your/path/to/pf-pascal asset/pf-pascal
ln -s /your/path/to/pf-willow asset/pf-willow
ln -s /your/path/to/SPair-71k asset/SPair-71k
- Create a directory named
sd4match
underasset
. This is to save pre-computed features, checkpoints and learned prompts.
# create sd4match directly
mkdir asset/sd4match
# or create sd4match at anywhere you want and use symbol link
ln -s /your/path/to/sd4match asset/sd4match
- Run
pre_compute_dino_feature.py
. This would pre-compute DINOv2 feature for all images in PF-Pascal, PF-Willow and SPair-71k and save them inasset/sd4match
. The structure should be:
sd4match
└── asset
└── DINOv2
├── pfpascal
| └── cached_output.pt
├── pfwillow
| └── cached_output.pt
└── spair
└── cached_output.pt
The bash scripts for training are provided in script
directory, and organized based on training data and prompt type.
For example, to train SD4Match-CPM
on SPair-71k dataset, run:
cd script/spair
sh sd4match_cpm.sh
The batch size per GPU is currently set to 3
, which would take about 22G
GPU memory to train. Reduce the batch size if necessary. The training script will generate two directories in asset/sd4match
: log
and prompt
. Tensorboard logs and training states are saved in log
, and learned prompts are saved in prompt
. For example, training SD4Match-CPM
on SPair-71k dataset will generate:
sd4match
├── asset
| ├── ...
├── log
| └── spair
| └── CPM_spair_sd2-1_Pair-DINO-Feat-G25-C50_constant_lr0.01
| └── ...(Tensorboard log and training states)
└── prompt
└── CPM_spair_sd2-1_Pair-DINO-Feat-G25-C50
└── ckpt.pt
To replicate our results reported in the paper on SPair-71k, either learning the prompt by yourself or downloading our pre-trained prompt and place them under asset/sd4match/prompt
directory. Run:
python test.py \
--dataset spair \
--prompt_type $PROMPT_NAME \
--timestep 50 \
--layer 1
Replace $PROMPT_NAME
with prompt your want. It needs to have a corresponding directory under asset/sd4match/prompt
. For example, to evaluate SD4Match-CPM
, run:
python test.py \
--dataset spair \
--prompt_type CPM_spair_sd2-1_Pair-DINO-Feat-G25-C50 \
--timestep 50 \
--layer 1
Our pretrained prompt can be downloaded through link.
Kai Han is supported by Hong Kong Research Grant Council - Early Career Scheme (Grant No. 27208022), National Natural Science Foundation of China (Grant No. 62306251), and HKU Seed Fund for Basic Research.
We also sincerely thank Zirui Wang for his inspiring discussion.
@misc{li2023sd4match,
title={SD4Match: Learning to Prompt Stable Diffusion Model for Semantic Matching},
author={Xinghui Li and Jingyi Lu and Kai Han and Victor Prisacariu},
year={2023},
eprint={2310.17569},
archivePrefix={arXiv},
primaryClass={cs.CV}
}