-
Notifications
You must be signed in to change notification settings - Fork 1
/
integrity_checking.py
206 lines (160 loc) · 5.39 KB
/
integrity_checking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Get positive and negative fact queries
# See with what probabilities it predicts true or false
import argparse
import json
import time
from pathlib import Path
from typing import List, Tuple
import pandas as pd
from utils.file_io import read_lm_kbc_jsonl_to_df, df_to_jsonl
from utils.model import gpt3
def logical_integrity(batch: pd.DataFrame) -> List[Tuple[Tuple[int, str], dict]]:
    """Ask the LM to judge each checkable predicted fact as True or False.

    For every row of *batch* whose relation supports integrity checking,
    build one true/false prompt per non-empty predicted object and send the
    prompts to GPT-3 in batches.

    Args:
        batch: DataFrame of predictions; rows are iterated positionally as
            (index, subject, relation, <unused>, objects).
            NOTE(review): assumes this exact column order — confirm against
            read_lm_kbc_jsonl_to_df.

    Returns:
        A list pairing each (row_index, object_entity) with the LM's
        prediction record (whose 'text' field holds "True"/"False").
    """
    # Only these relations have a positive/negative prompt pair defined
    # below AND are considered reliable enough to filter on.
    checkable = {
        "CompanyParentOrganization",
        "CountryOfficialLanguage",
        "PersonCauseOfDeath",
        "PersonInstrument",
        "PersonLanguage",
    }
    prompts = []
    indices = []
    for index, subject, relation, _, objects in batch.itertuples(index=True):
        if relation not in checkable:
            continue
        for obj in objects:  # renamed from `object` — avoid shadowing the builtin
            if obj == '':
                continue
            prompts.append(positive_negative_prompt_pairs(relation, subject, obj))
            indices.append((index, obj))
    predictions = []
    # Send prompts in batches of 20 with a pause between calls to stay
    # under the API rate limit. Slicing already clamps at the list end,
    # so no explicit min() is needed.
    for start in range(0, len(prompts), 20):
        predictions.extend(gpt3(prompts[start:start + 20]))
        time.sleep(2)
    return list(zip(indices, predictions))
def positive_negative_prompt_pairs(relation, subject_entity, object_entity):
    """Build a few-shot True/False prompt for one (subject, relation, object) fact.

    Each prompt shows one true and one false example for the relation, then
    states the candidate fact so the LM can answer "True" or "False".

    Args:
        relation: Relation name selecting the prompt template.
        subject_entity: Subject of the candidate fact.
        object_entity: Object of the candidate fact (lower-cased for the
            profession/instrument templates to match the example style).

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If *relation* has no template. (Previously this fell
        through and crashed with an opaque UnboundLocalError.)
    """
    if relation == "CountryBordersWithCountry":
        prompt = f"""Niger neighbours Libya.
True
North Korea neighbours the Netherlands.
False
{subject_entity} neighbours {object_entity}.
"""
    elif relation == "CountryOfficialLanguage":
        prompt = f"""Swedish is an official language of Finland.
True
French is an official language of India.
False
{object_entity} is an official language of {subject_entity}.
"""
    elif relation == "StateSharesBorderState":
        prompt = f"""San Marino shares a border with San Leo.
True
Texas shares a border with Hamburg.
False
{subject_entity} shares a border with {object_entity}.
"""
    elif relation == "RiverBasinsCountry":
        prompt = f"""The river Drava crosses Hungary.
True
The river Huai crosses the Netherlands.
False
The river {subject_entity} crosses {object_entity}.
"""
    elif relation == "ChemicalCompoundElement":
        prompt = f"""The molecule water is made up of the element Hydrogen.
True
The molecule aspirin is made up of the element Germanium.
False
The molecule {subject_entity} is made up of the element {object_entity}.
"""
    elif relation == "PersonLanguage":
        prompt = f"""Aamir Khan speaks Hindi.
True
Pharrell Williams speaks French.
False
{subject_entity} speaks {object_entity}.
"""
    elif relation == "PersonProfession":
        prompt = f"""Danny DeVito is a director.
True
Christina Aguilera is a businessperson.
False
{subject_entity} is a {object_entity.lower()}.
"""
    elif relation == "PersonInstrument":
        prompt = f"""Liam Gallagher plays the guitar.
True
Jay Park plays the piano.
False
{subject_entity} plays the {object_entity.lower()}.
"""
    elif relation == "PersonEmployer":
        prompt = f"""Susan Wojcicki is or was employed by Google.
True
Steve Wozniak is or was employed by Microsoft.
False
{subject_entity} is or was employed by {object_entity}.
"""
    elif relation == "PersonPlaceOfDeath":
        prompt = f"""The place of death of Elvis Presley is Graceland.
True
The place of death of Barack Obama is Washington.
False
The place of death of {subject_entity} is {object_entity}.
"""
    elif relation == "PersonCauseOfDeath":
        prompt = f"""Aretha Franklin died of pancreatic cancer.
True
Bill Gates died of femoral fracture.
False
{subject_entity} died of {object_entity}.
"""
    elif relation == "CompanyParentOrganization":
        # Note the reversed argument order: the candidate parent is the object.
        prompt = f"""Apple is the parent company of Microsoft.
False
Sony Group is the parent company of Sony.
True
{object_entity} is the parent company of {subject_entity}.
"""
    else:
        raise ValueError(f"No true/false prompt template for relation: {relation}")
    return prompt
def fact_checking(input_file, output_file):
    """Drop predicted objects that the LM judges as False and write the result.

    Reads predictions from *input_file*, runs the logical-integrity check,
    removes every object the LM labelled "False" from its row, and writes
    the filtered predictions to *output_file* as JSON lines.
    """
    predictions_df = read_lm_kbc_jsonl_to_df(input_file)
    checked = logical_integrity(predictions_df)
    # Collect (row_index, object) pairs the LM rejected.
    rejected = [idx for idx, verdict in checked
                if verdict['text'].strip() == 'False']
    for row_index, obj in rejected:
        # Mutate the object list in place for that row.
        predictions_df["ObjectEntities"][row_index].remove(obj)
    with open(output_file, "w") as out:
        for record in df_to_jsonl(predictions_df):
            out.write(json.dumps(record) + "\n")
    # TODO: If by filtering a fact, there are no more objects for a certain subject, make sure to add NONE
def main():
    """CLI entry point: parse input/output paths and run fact checking."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-i",
        "--input_file",
        type=str,
        default="./predictions/gpt3.pred.jsonl",
        help="input directory containing the baseline or your method output",
    )
    arg_parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default="./predictions/gpt3_fact_check.pred.jsonl",
        help="Output file (required)",
    )
    parsed = arg_parser.parse_args()
    print(parsed)

    in_path, out_path = Path(parsed.input_file), Path(parsed.output_file)
    assert in_path.exists()
    fact_checking(in_path, out_path)


if __name__ == "__main__":
    main()