
review #2

Open · wants to merge 80 commits into base: test

Commits (80)
9893a9a
initial commit
Nov 10, 2021
062cc38
add bash scripts
Nov 11, 2021
41a4101
new scripts
Nov 11, 2021
b13790a
add symmetric search
Nov 15, 2021
6430db8
add second method for semantic search
Nov 22, 2021
fa3a3ae
config
Nov 23, 2021
f398f26
minor edits
Nov 30, 2021
a8bf943
refactor
Nov 30, 2021
23443fe
vcr
Dec 5, 2021
55fcf9b
minor fixes
Dec 6, 2021
c868171
reformats
sahithyaravi Dec 6, 2021
dc1d706
reqs
sahithyaravi Dec 6, 2021
72f5d68
remove prints add plots
sahithyaravi Dec 9, 2021
95e0d67
small changes
sahithyaravi Dec 10, 2021
957b315
add top method
sahithyaravi Dec 12, 2021
6baf5af
small changes
sahithyaravi Dec 12, 2021
2815ba2
srl
sahithyaravi Dec 13, 2021
5647c72
reqs
sahithyaravi Dec 13, 2021
fa23888
process exp
sahithyaravi Dec 13, 2021
8ccb3e1
qn phrase
sahithyaravi Dec 14, 2021
7d4d4a1
changes to process exp
sahithyaravi Dec 15, 2021
690018c
substring change
sahithyaravi Dec 15, 2021
69f3082
reorg1
sahithyaravi Jan 25, 2022
c6389e7
Added image search
aditya10 Feb 8, 2022
0c3afd9
Merge pull request #3 from sahithyaravi1493/image_search
sahithyaravi Feb 8, 2022
d86931a
restructuring v1
sahithyaravi Feb 8, 2022
93afaf6
fix errors
sahithyaravi Feb 8, 2022
c3fa513
add map
sahithyaravi Feb 8, 2022
9f7239e
semantic model change
sahithyaravi Feb 9, 2022
3cfd262
refactor
sahithyaravi Feb 10, 2022
11eb5bc
add plot file
sahithyaravi Feb 11, 2022
ef712cb
more edits
sahithyaravi Feb 11, 2022
cf270ba
remove parallel
sahithyaravi Feb 12, 2022
0d7746f
im search error
sahithyaravi Feb 13, 2022
96e3a70
correct image
sahithyaravi Feb 13, 2022
28e58e7
semv2
sahithyaravi Feb 27, 2022
71404f0
qn_to_phrase
sahithyaravi Feb 28, 2022
8990c07
process_expansions.py
sahithyaravi Feb 28, 2022
3d71d04
Merge branch 'main' of https://github.com/sahithyaravi1493/vlc_transf…
sahithyaravi Feb 28, 2022
e4278a6
grad norm to text
sahithyaravi Mar 13, 2022
baf00b8
clean up 1
sahithyaravi Jul 19, 2022
9d1e74f
cleanups
sahithyaravi Jul 20, 2022
5eea5c7
merges
sahithyaravi Jul 20, 2022
9b630c5
scripts
sahithyaravi Jul 20, 2022
6a580e2
compare aok
sahithyaravi Aug 8, 2022
04d9329
minor fixes
sahithyaravi Aug 8, 2022
aab3fca
reorgs
sahithyaravi Aug 10, 2022
b08d82e
compare working
sahithyaravi Aug 10, 2022
92b416e
code for new versions
sahithyaravi Aug 15, 2022
000d450
refactor
sahithyaravi Aug 15, 2022
46cd8bd
fixes
sahithyaravi Aug 16, 2022
bd0c4e8
entity match
sahithyaravi Aug 16, 2022
b6187c1
phrase entities
sahithyaravi Aug 16, 2022
87254bc
minor fixes
sahithyaravi Aug 18, 2022
6f41563
debugs
sahithyaravi Aug 18, 2022
76586ee
plot changes
sahithyaravi Aug 20, 2022
f6f95d9
expansions
sahithyaravi Aug 22, 2022
a5ac4f0
fixes
sahithyaravi Aug 22, 2022
1d7e9e9
phrase shorten
sahithyaravi Aug 22, 2022
0bb7496
small fixes
sahithyaravi Aug 23, 2022
5e06e2f
triplets
sahithyaravi Aug 23, 2022
6d8a071
Merge branch 'main' of https://github.com/sahithyaravi1493/vlc_transf…
sahithyaravi Aug 23, 2022
229b712
first triplets
sahithyaravi Aug 23, 2022
d29ebe5
more changes to svo
sahithyaravi Aug 24, 2022
04a7ffb
more changes to triplets
sahithyaravi Aug 24, 2022
76b8b07
final q version
sahithyaravi Aug 24, 2022
c352daa
older q2 version
sahithyaravi Aug 25, 2022
04312a6
small fixes to older version
sahithyaravi Aug 25, 2022
ffd9744
more fixes to svo
sahithyaravi Aug 25, 2022
2baf12d
fixes
sahithyaravi Aug 25, 2022
91e586b
change ssearch
sahithyaravi Aug 26, 2022
5787f12
ss
sahithyaravi Aug 26, 2022
fbf0cac
trained
sahithyaravi Aug 26, 2022
fe47879
training sbert modifs
sahithyaravi Aug 26, 2022
f753e28
Reorg
sahithyaravi Aug 26, 2022
2b34e33
okvqa
sahithyaravi Aug 27, 2022
0c81da9
small fixes
sahithyaravi Aug 28, 2022
3690411
final
sahithyaravi Sep 28, 2022
e783115
Update README.md
sahithyaravi Oct 19, 2022
9147dc8
Update README.md
sahithyaravi Jan 1, 2023
23 changes: 22 additions & 1 deletion README.md
@@ -1 +1,22 @@
# vlc_transformer
# VLC-Commonsense
This repo is part of the VLC-BERT project (https://github.com/aditya10/VLC-BERT).


### Download data and organize expansions
Follow the data download and organization steps from https://github.com/aditya10/VLC-BERT/blob/master/README.md, then link the data directory:

```
mkdir data
cd data
ln -s DATA_PATH ./
```

### Install requirements
```
pip install -r requirements.txt
```

### Configure
Set the paths for images and COMET expansions, and the method used to generate context, in config.py
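The scripts load their settings with `from config import *`; the sketch below shows what a minimal config.py could look like. The names `images_path`, `captions_path` and `final_expansion_save_path` are the ones referenced elsewhere in this PR, but the example values and the `method` option are placeholders, not the repo's actual defaults:

```python
# config.py -- minimal sketch; all values below are placeholder paths.
images_path = "scratch/data/coco/images/"          # COCO image root
captions_path = "scratch/data/coco/captions.json"  # image captions (JSON)
final_expansion_save_path = "scratch/data/coco/expansions.json"  # picked COMET expansions
method = "semantic_search"  # context-generation method (assumed option name)
```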

### To process expansions
python process_expansions.py

### To train S-BERT (Augmented S-BERT)
python train_sbert_search.py

153 changes: 153 additions & 0 deletions analysis/check_exp_quality.py
@@ -0,0 +1,153 @@
import json
import string

from nltk.corpus import stopwords
from pytorch_pretrained_bert import BertTokenizer

# Run once if the stopwords corpus is missing:
# nltk.download('stopwords')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
s = set(stopwords.words('english'))

root = "/Users/sahiravi/Documents/Research/VL project/scratch/data/coco"


def _load_json(path):
with open(path, 'r') as f:
return json.load(f)


def check_exp_quality_aok(exp_name, set_name):
# Load questions:
annotations = _load_json(f'{root}/aokvqa/aokvqa_v1p0_' + set_name + '.json')

# Load expansions:
expansions = _load_json(f'{root}/aokvqa/commonsense/expansions/' + exp_name + '_aokvqa_' + set_name + '.json')
raw_expansions = _load_json(f'{root}/aokvqa/commonsense/expansions/question_expansion_sentences_{set_name}_aokvqa_{exp_name}.json')

K = 10
helpful = 0
total = 0
common_tokens = {}
print("USING K", K)
for q in annotations:
q_id = q['question_id']
im_id = q['image_id']

ans_text = ' '.join(q['direct_answers'])
ans_text = ans_text.translate(str.maketrans('', '', string.punctuation))
ans_text = ans_text.lower()

        # Top-K sentences from the picked expansions (overridden below).
        exp_text = expansions['{:012d}.jpg'.format(im_id)][str(q_id)][0]
        exp_text = ' '.join(exp_text.split('.')[:K])

        # Override with the raw expansion sentences for this question.
        exp_text = ' '.join(raw_expansions[str(q_id)])

exp_text = exp_text.translate(str.maketrans('', '', string.punctuation))
exp_text = exp_text.lower()

ans_tokens = tokenizer.tokenize(ans_text)
ans_tokens = [t for t in ans_tokens if t not in s]
exp_tokens = tokenizer.tokenize(exp_text)
exp_tokens = [t for t in exp_tokens if t not in s]

is_helpful = False
for token in ans_tokens:
if token in exp_tokens:
is_helpful = True
if token not in common_tokens:
common_tokens[token] = 1
else:
common_tokens[token] += 1

if is_helpful:
helpful += 1
total += 1

common_tokens = dict(sorted(common_tokens.items(), key=lambda item: item[1], reverse=True))
common_tokens = list(common_tokens.keys())[:15]

return helpful, total, common_tokens


def check_exp_quality_ok(exp_name, set_name):
# Load questions:
annotations = _load_json(f'{root}/annotations/mscoco_{set_name}2014_annotations.json')["annotations"]

# Load expansions:
expansions = _load_json(
f'{root}/okvqa/commonsense/expansions/{exp_name}/' + exp_name + '_okvqa_' + set_name + '.json')

# Load raw expansions:
# raw_expansions = _load_json(f'{root}/okvqa/commonsense/expansions/question_expansion_sentences_{set_name}_okvqa_{exp_name}.json')
helpful = 0
total = 0
common_tokens = {}
raw_commons = {}
K = 10
print("USING K", K)
for q in annotations:
q_id = q['question_id']
image_id = str(q['image_id'])
ans = q["answers"]
ans_text = ' '.join([a['answer'] for a in ans])
ans_text = ans_text.translate(str.maketrans('', '', string.punctuation))
ans_text = ans_text.lower()
        # COCO file names zero-pad the image id to 12 digits.
        filename = f'COCO_{set_name}2014_{image_id.zfill(12)}.jpg'
# raw_exp_text = ' '.join(raw_expansions[str(q_id)])

try:
exp_text = expansions[filename][str(q_id)][0]
except KeyError:
exp_text = expansions[image_id][str(q_id)]
        # Keep at most the first K sentences.
        exp_text = ' '.join(exp_text.split('.')[:K])

exp_text = exp_text.translate(str.maketrans('', '', string.punctuation))
exp_text = exp_text.lower()

ans_tokens = tokenizer.tokenize(ans_text)
ans_tokens = [t for t in ans_tokens if t not in s]
exp_tokens = tokenizer.tokenize(exp_text)
exp_tokens = [t for t in exp_tokens if t not in s]

is_helpful = False
for token in ans_tokens:
if token in exp_tokens:
is_helpful = True
if token not in common_tokens:
common_tokens[token] = 1
else:
common_tokens[token] += 1
if is_helpful:
helpful += 1
total += 1

common_tokens = dict(sorted(common_tokens.items(), key=lambda item: item[1], reverse=True))
common_tokens = list(common_tokens.keys())[:10]

return helpful, total, common_tokens


if __name__ == '__main__':

exp_names = ['semq.4']
sets = ['val', 'train']
datasets = ["aokvqa"]
for dataset in datasets:
for exp_name in exp_names:
for set_name in sets:
if dataset == "okvqa":
helpful, total, common_tokens = check_exp_quality_ok(exp_name, set_name)
else:
helpful, total, common_tokens = check_exp_quality_aok(exp_name, set_name)
print('{} {} {}: {}/{} ({:.2f}%)'.format(dataset, exp_name, set_name, helpful, total,
helpful / total * 100),
' Most common tokens: ', common_tokens)
print()
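The "helpful" criterion above reduces to a token-overlap test: an expansion counts as helpful when it shares at least one non-stopword token with any ground-truth answer. A self-contained sketch of that heuristic, using a toy stopword set and whitespace tokenisation in place of NLTK and the BERT tokenizer used by the script:

```python
import string

STOPWORDS = {"the", "a", "is", "of", "in"}  # toy stand-in for NLTK stopwords

def tokens(text):
    # Lowercase, strip punctuation, drop stopwords (whitespace tokeniser
    # instead of the BERT subword tokenizer used above).
    text = text.translate(str.maketrans('', '', string.punctuation)).lower()
    return [t for t in text.split() if t not in STOPWORDS]

def is_helpful(expansion, answers):
    # Helpful == at least one non-stopword token shared with any answer.
    return bool(set(tokens(expansion)) & set(tokens(' '.join(answers))))

print(is_helpful("A dog is playing fetch.", ["dog", "puppy"]))  # True
print(is_helpful("The sky is blue.", ["dog", "puppy"]))         # False
```

With the real BERT tokenizer, subword pieces can make the overlap somewhat more permissive than this whole-word matching.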
126 changes: 126 additions & 0 deletions analysis/compare_results_aok.py
@@ -0,0 +1,126 @@
import os
import random
from utils import load_json, imageid_to_path
from analysis.plot_picked_expansions import show_image
from config import *

def get_count(res, answers):
    # VQA-style soft accuracy: fully correct once at least 3 annotators
    # gave this exact answer.
    human_count = answers.count(res)
    return min(1.0, float(human_count) / 3)


if __name__ == '__main__':
    # Load the AOK-VQA val annotations (questions and answers share one file):
annotations = load_json("scratch/data/coco/aokvqa/aokvqa_v1p0_val.json")
questions = load_json("scratch/data/coco/aokvqa/aokvqa_v1p0_val.json")
captions = load_json(captions_path)
expansion = load_json(final_expansion_save_path)

# the two results to compare
results1 = load_json('../result_files/aokvqa/captions_aokvqa_val2017.json')
results2 = load_json('../result_files/aokvqa/sem11_aokvqa_val2017.json')


ans_list = annotations
q_list = questions


# get all difference between results
diffs = []
total = len(results1)
for i in range(len(results1)):
if results1[i]['answer'] != results2[i]['answer']:
diffs.append(i)

    # Sample 100 random differing predictions (with replacement).
seed = 42
random.seed(seed)
for rand in range(100):
rand_idx = random.randint(0, len(diffs) - 1)
rand_idx = diffs[rand_idx]
# print(results1[rand_idx])
# print(results2[rand_idx])

q_id = results1[rand_idx]["question_id"]

ans = None
for a in ans_list:
if a['question_id'] == q_id:
# print(a)
ans = a
break

ques = None
for q in q_list:
if q['question_id'] == q_id:
# print(q)
ques = q
image_id = q['image_id']
break

res = {}
res['question_id'] = q_id
res['image_id'] = image_id

image_id = str(image_id)
print(image_id, q_id)
res['expansion'] = expansion[imageid_to_path(image_id)][str(q_id)]

        img_path = image_id
        caption = captions[imageid_to_path(image_id)]
# print(caption)

res['image_path'] = imageid_to_path(img_path)
res['question'] = ques['question']
res['caption'] = caption
res['rationale'] = ques['rationales']

# first and second answer
res['answer_1'] = results1[rand_idx]['answer']
res['answer_2'] = results2[rand_idx]['answer']

answers = [a for a in ans['direct_answers']]
acc_res1 = get_count(res['answer_1'], answers)
acc_res2 = get_count(res['answer_2'], answers)
res['possible_answers'] = answers
print(res['possible_answers'])

if acc_res1 > acc_res2:
res['state'] = 'bad'
elif acc_res2 > acc_res1:
res['state'] = 'good'
elif acc_res2 == 0 and acc_res1 == 0:
res['state'] = 'neutral'
else:
res['state'] = 'both-good'

# print(res)
if res['state'] == 'neutral':
# norms = grad_norms_dict[int(q_id)]

            exp_list = res['expansion'].split(".")[:5]
            print(exp_list)
            # Re-ranking is currently disabled; keep the sentences in order.
            sorted_expansions = exp_list
image_path = f"{images_path}" + res['image_path']
title = caption
print(title)

            text = (res['question']
                    + '\n Context:' + str(sorted_expansions)
                    + '\n\n base Answer: ' + res['answer_1']
                    + '\nimproved Answer: ' + res['answer_2']
                    + '\n\n GT Answers: ' + ", ".join(res["possible_answers"])
                    + "\n" + ", ".join(res["rationale"]))
print(text)
# print("gpt3", gpt3[str(res['question_id'])])

# Save path - change this
if not os.path.exists(res['state']):
os.mkdir(res['state'])

save_name = res['state'] + "/"+ str(res['question_id']) + '_'+ str(seed)
show_image(image_path, text, title, save_name)

print(len(diffs))
print(total)
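The comparison above scores each answer with `get_count`, i.e. the soft accuracy used for direct answers: a prediction scores min(1, n/3), where n is the number of annotators who gave exactly that answer. Standalone:

```python
def soft_accuracy(prediction, direct_answers):
    # Mirrors get_count above: fully correct once >= 3 annotators agree.
    n = direct_answers.count(prediction)
    return min(1.0, n / 3)

answers = ["cat", "cat", "cat", "kitten", "cat"]
print(soft_accuracy("cat", answers))     # 1.0 (4 matches, capped at 1)
print(soft_accuracy("kitten", answers))  # 0.3333333333333333
print(soft_accuracy("dog", answers))     # 0.0
```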