Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add files via upload #10

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
96 changes: 54 additions & 42 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,55 +1,65 @@
# SGRAF
PyTorch implementation for AAAI2021 paper of [**“Similarity Reasoning and Filtration for Image-Text Matching”**](https://drive.google.com/file/d/1tAE_qkAxiw1CajjHix9EXoI7xu2t66iQ/view?usp=sharing).
It is built on top of the [SCAN](https://github.com/kuanghuei/SCAN) and [Cross-modal_Retrieval_Tutorial](https://github.com/Paranioar/Cross-modal_Retrieval_Tutorial).
*PyTorch implementation for AAAI2021 paper of [**“Similarity Reasoning and Filtration for Image-Text Matching”**](https://drive.google.com/file/d/1tAE_qkAxiw1CajjHix9EXoI7xu2t66iQ/view?usp=sharing).*

*It is built on top of the [SCAN](https://github.com/kuanghuei/SCAN) and [Awesome_Matching](https://github.com/Paranioar/Awesome_Matching_Pretraining_Transfering).*

*We have released two versions of SGRAF: **Branch `main` for python2.7**; **Branch `python3.6` for python3.6**.*

*If any problems, please contact me at [email protected]. ([email protected] is deprecated)*


## Introduction

**The framework of SGRAF:**

<img src="./fig/model.png" width = "100%" height="50%">

**The updated results (Better than the original paper)**
<table>
<tr> <td rowspan="2">Dataset</td> <td rowspan="2", align="center">Module</td>
<td colspan="3", align="center">Sentence retrieval</td> <td colspan="3", align="center">Image retrieval</td> </tr>
<tr> <td>R@1</td><td>R@5</td><td>R@10</td> <td>R@1</td><td>R@5</td><td>R@10</td> </tr>
<tr> <td rowspan="3">Flick30k</td>
<td>SAF</td> <td>75.6</td><td>92.7</td><td>96.9</td> <td>56.5</td><td>82.0</td><td>88.4</td> </tr>
<tr> <td>SGR</td> <td>76.6</td><td>93.7</td><td>96.6</td> <td>56.1</td><td>80.9</td><td>87.0</td> </tr>
<tr> <td>SGRAF</td> <td>78.4</td><td>94.6</td><td>97.5</td> <td>58.2</td><td>83.0</td><td>89.1</td> </tr>
<tr> <td rowspan="3">MSCOCO1k</td>
<td>SAF</td> <td>78.0</td><td>95.9</td><td>98.5</td> <td>62.2</td><td>89.5</td><td>95.4</td> </tr>
<tr> <td>SGR</td> <td>77.3</td><td>96.0</td><td>98.6</td> <td>62.1</td><td>89.6</td><td>95.3</td> </tr>
<tr> <td>SGRAF</td> <td>79.2</td><td>96.5</td><td>98.6</td> <td>63.5</td><td>90.2</td><td>95.8</td> </tr>
<tr> <td rowspan="3">MSCOCO5k</td>
<td>SAF</td> <td>55.5</td><td>83.8</td><td>91.8</td> <td>40.1</td><td>69.7</td><td>80.4</td> </tr>
<tr> <td>SGR</td> <td>57.3</td><td>83.2</td><td>90.6</td> <td>40.5</td><td>69.6</td><td>80.3</td> </tr>
<tr> <td>SGRAF</td> <td>58.8</td><td>84.8</td><td>92.1</td> <td>41.6</td><td>70.9</td><td>81.5</td> </tr>

## Requirements
We recommended the following dependencies for ***Branch `python3.6`***.

* Python 3.6
* [PyTorch (>=0.4.1)](http://pytorch.org/)
* [NumPy (>=1.12.1)](http://www.numpy.org/)
* [TensorBoard](https://github.com/TeamHG-Memex/tensorboard_logger)
[Note]: The code applies to ***Python3.6 + Pytorch1.7***.

## Acknowledgements
Thanks to the exploration and discussion with [KevinLight831](https://github.com/KevinLight831), we made some adjustments as follows:
**1. Adjust `evaluation.py`**:
*for i, (k, v) in enumerate(self.meters.iteritems()):*
***------>** ```for i, (k, v) in enumerate(self.meters.items()):```*
*for k, v in self.meters.iteritems():*
***------>** ```for k, v in self.meters.items():```*

**2. Adjust `model.py`**:
*cap_emb = (cap_emb[:, :, :cap_emb.size(2)/2] + cap_emb[:, :, cap_emb.size(2)/2:])/2*
***------>** ```cap_emb = (cap_emb[:, :, :cap_emb.size(2)//2] + cap_emb[:, :, cap_emb.size(2)//2:])/2```*

**3. Adjust `data.py`**:
*img_id = index/self.im_div*
***------>** ```img_id = index//self.im_div```*

</table>
*for line in open(loc+'%s_caps.txt' % data_split, 'rb'):*
*tokens = nltk.tokenize.word_tokenize(str(caption).lower().decode('utf-8'))*

## Requirements
We recommended the following dependencies.
***------>** ```for line in open(loc+'%s_caps.txt' % data_split, 'rb'):```*
***------>** ```tokens = nltk.tokenize.word_tokenize(caption.lower().decode('utf-8'))```*

* Python **(2.7 not 3.\*)**
* [PyTorch](http://pytorch.org/) **(0.4.1 not 1.\*)**
* [NumPy](http://www.numpy.org/) **(>1.12.1)**
* [TensorBoard](https://github.com/TeamHG-Memex/tensorboard_logger)
* Punkt Sentence Tokenizer:
```python
import nltk
nltk.download()
> d punkt
```
or

***------>** ```for line in open(loc+'%s_caps.txt' % data_split, 'r', encoding='utf-8'):```*
***------>** ```tokens = nltk.tokenize.word_tokenize(str(caption).lower())```*

## Download data and vocab
We follow [SCAN](https://github.com/kuanghuei/SCAN) to obtain image features and vocabularies, which can be downloaded by using:

```bash
wget https://scanproject.blob.core.windows.net/scan-data/data.zip
wget https://scanproject.blob.core.windows.net/scan-data/vocab.zip
https://www.kaggle.com/datasets/kuanghueilee/scan-features
```
Another download link is available below:

```bash
https://drive.google.com/drive/u/0/folders/1os1Kr7HeTbh8FajBNegW8rjJf6GIhFqC
```

## Pre-trained models and evaluation
Expand Down Expand Up @@ -82,16 +92,18 @@ For Flickr30K:

If SGRAF is useful for your research, please cite the following paper:

@inproceedings{Diao2021SGRAF,
title={Similarity Reasoning and Filtration for Image-Text Matching},
author={Diao, Haiwen and Zhang, Ying and Ma, Lin and Lu, Huchuan},
booktitle={AAAI},
year={2021}
}
@inproceedings{Diao2021SGRAF,
title={Similarity reasoning and filtration for image-text matching},
author={Diao, Haiwen and Zhang, Ying and Ma, Lin and Lu, Huchuan},
booktitle={Proceedings of the AAAI conference on artificial intelligence},
volume={35},
number={2},
pages={1218--1226},
year={2021}
}

## License

[Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0).
If any problems, please contact me at ([email protected]) or ([email protected]).


22 changes: 16 additions & 6 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ def __init__(self, data_path, data_split, vocab):

# load the raw captions
self.captions = []
with open(loc+'%s_caps.txt' % data_split, 'rb') as f:
for line in f:
self.captions.append(line.strip())

# -------- The main difference between python2.7 and python3.6 --------#
# The suggestion from Hongguang Zhu (https://github.com/KevinLight831)
# ---------------------------------------------------------------------#
# for line in open(loc+'%s_caps.txt' % data_split, 'r', encoding='utf-8'):
# self.captions.append(line.strip())

for line in open(loc+'%s_caps.txt' % data_split, 'rb'):
self.captions.append(line.strip())

# load the image features
self.images = np.load(loc+'%s_ims.npy' % data_split)
Expand All @@ -40,14 +46,18 @@ def __init__(self, data_path, data_split, vocab):

def __getitem__(self, index):
# handle the image redundancy
img_id = index/self.im_div
img_id = index//self.im_div
image = torch.Tensor(self.images[img_id])
caption = self.captions[index]
vocab = self.vocab

# -------- The main difference between python2.7 and python3.6 --------#
# The suggestion from Hongguang Zhu(https://github.com/KevinLight831)
# ---------------------------------------------------------------------#
# tokens = nltk.tokenize.word_tokenize(str(caption).lower())

# convert caption (string) to word ids.
tokens = nltk.tokenize.word_tokenize(
str(caption).lower().decode('utf-8'))
tokens = nltk.tokenize.word_tokenize(caption.lower().decode('utf-8'))
caption = []
caption.append(vocab('<start>'))
caption.extend([vocab(token) for token in tokens])
Expand Down
10 changes: 5 additions & 5 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __str__(self):
"""Concatenate the meters in one log line
"""
s = ''
for i, (k, v) in enumerate(self.meters.iteritems()):
for i, (k, v) in enumerate(self.meters.items()):
if i > 0:
s += ' '
s += k + ' ' + str(v)
Expand All @@ -68,7 +68,7 @@ def __str__(self):
def tb_log(self, tb_logger, prefix='', step=None):
"""Log using tensorboard
"""
for k, v in self.meters.iteritems():
for k, v in self.meters.items():
tb_logger.log_value(prefix + k, v.val, step=step)


Expand Down Expand Up @@ -125,7 +125,7 @@ def evalrank(model_path, data_path=None, split='dev', fold5=False):
opt.data_path = data_path

# load vocabulary used by the model
vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
vocab = deserialize_vocab('./vocab/%s_vocab.json' % opt.data_name)
opt.vocab_size = len(vocab)

# construct model
Expand Down Expand Up @@ -295,5 +295,5 @@ def t2i(images, captions, caplens, sims, npts=None, return_ranks=False):


if __name__ == '__main__':
evalrank("/apdcephfs/share_1313228/home/haiwendiao/SGRAF-master/runs/SAF_module/checkpoint/model_best.pth.tar",
data_path="/apdcephfs/share_1313228/home/haiwendiao", split="test", fold5=False)
evalrank("./runs/Flickr30K_SGRAF/f30k_SAF/model_best.pth.tar",
data_path='./data', split="test", fold5=False)
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def forward(self, captions, lengths):
cap_emb, _ = pad_packed_sequence(out, batch_first=True)

if self.use_bi_gru:
cap_emb = (cap_emb[:, :, :cap_emb.size(2)/2] + cap_emb[:, :, cap_emb.size(2)/2:])/2
cap_emb = (cap_emb[:, :, :cap_emb.size(2)//2] + cap_emb[:, :, cap_emb.size(2)//2:])/2

# normalization in the joint embedding space
if not self.no_txtnorm:
Expand Down
85 changes: 40 additions & 45 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,56 +1,51 @@
backports.functools-lru-cache==1.6.1
backports.weakref==1.0.post1
bleach==1.5.0
boto3==1.17.8
botocore==1.20.8
certifi==2019.11.28
cffi==1.14.0
absl-py==0.12.0
astor==0.8.1
boto3==1.17.53
botocore==1.20.53
cached-property==1.5.2
certifi==2020.12.5
cffi==1.14.5
chardet==4.0.0
click==7.1.2
cloudpickle==1.3.0
cycler==0.10.0
Cython==0.29.13
decorator==4.4.2
enum34==1.1.10
funcsigs==1.0.2
futures==3.3.0
html5lib==0.9999999
docopt==0.6.2
gast==0.4.0
google-pasta==0.2.0
grpcio==1.37.0
h5py==3.1.0
idna==2.10
importlib-metadata==3.10.1
jmespath==0.10.0
joblib==0.14.1
kiwisolver==1.1.0
Markdown==3.1.1
matplotlib==2.2.4
mock==3.0.5
networkx==2.2
nltk==3.4.5
numpy==1.16.5
joblib==1.0.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
Markdown==3.3.4
mkl-fft==1.3.0
mkl-random==1.1.1
mkl-service==2.3.0
nltk==3.6.1
numpy==1.16.4
olefile==0.46
opencv-python==4.2.0.32
pandas==0.24.2
Pillow==6.2.1
protobuf==3.12.2
ptflops==0.6.4
pycocotools==2.0
Pillow==8.2.0
pipreqs==0.4.10
protobuf==3.15.8
pycparser==2.20
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.1
PyWavelets==1.0.3
regex==2020.11.13
regex==2021.4.4
requests==2.25.1
s3transfer==0.3.4
sacremoses==0.0.43
scikit-image==0.14.5
scipy==1.2.3
singledispatch==3.4.0.3
s3transfer==0.3.7
scipy==1.5.4
six==1.15.0
subprocess32==3.5.4
tensorboard==1.14.0
tensorboard-logger==0.1.0
tensorflow==1.4.0
tensorflow-tensorboard==0.4.0
torch==0.4.1.post2
torchvision==0.2.0
tqdm==4.56.2
urllib3==1.26.3
tensorflow-estimator==1.14.0
tensorflow-gpu==1.14.0
termcolor==1.1.0
torch==1.1.0
torchvision==0.3.0
tqdm==4.60.0
typing-extensions==3.7.4.3
urllib3==1.26.4
Werkzeug==1.0.1
wrapt==1.12.1
yarg==0.1.9
zipp==3.4.1
16 changes: 16 additions & 0 deletions visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
# Please refer to https://github.com/Paranioar/RCAR for related visualization code.
# It now includes visualize_attention_mechanism, visualize_similarity_distribution, visualize_rank_result, and etc.

# I will continue to update more related visualization codes when I am free.
# If you find these codes are useful, please cite our papers and star our projects. (We do need it! HaHaHaHa.)
# Thanks for the interest in our projects.
"""








8 changes: 0 additions & 8 deletions vocab.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# -----------------------------------------------------------
# Stacked Cross Attention Network implementation based on
# https://arxiv.org/abs/1803.08024.
# "Stacked Cross Attention for Image-Text Matching"
# Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He
#
# Writen by Kuang-Huei Lee, 2018
# ---------------------------------------------------------------
"""Vocabulary wrapper"""

import nltk
Expand Down