# gpt3_baseline.py
import argparse
import json
import logging
import time
from pathlib import Path
from tqdm.auto import tqdm
from utils.file_io import read_lm_kbc_jsonl
from utils.model import gpt3, clean_up, convert_nan
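# Note: read_lm_kbc_jsonl, gpt3, clean_up and convert_nan live in the utils
# package, which is not part of this listing. As a rough sketch of what the
# gpt3() wrapper presumably does (names and details are assumptions based on
# the legacy OpenAI Completions API, not the actual utils code):
#
#     import openai
#
#     def gpt3(prompts, model="text-davinci-002"):
#         # one batched API call; choices come back indexed by prompt order
#         response = openai.Completion.create(model=model, prompt=prompts, max_tokens=100)
#         choices = sorted(response["choices"], key=lambda c: c["index"])
#         return [{"prompt": p, "text": c["text"]} for p, c in zip(prompts, choices)]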
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
SAMPLE_SIZE = 5000
MODEL_TYPES = ['text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001']
def create_prompt(subject_entity, relation):
    """Build a fixed few-shot prompt for the given relation and subject entity."""
    ### depending on the relation, we use a fixed few-shot prompt
    if relation == "CountryBordersWithCountry":
        prompt = f"""
Which countries neighbour Dominica?
['Venezuela']
Which countries neighbour North Korea?
['South Korea', 'China', 'Russia']
Which countries neighbour Serbia?
['Montenegro', 'Kosovo', 'Bosnia and Herzegovina', 'Hungary', 'Croatia', 'Bulgaria', 'Macedonia', 'Albania', 'Romania']
Which countries neighbour Fiji?
[]
Which countries neighbour {subject_entity}?
"""
    elif relation == "CountryOfficialLanguage":
        prompt = f"""
Which are the official languages of Suriname?
['Dutch']
Which are the official languages of Canada?
['English', 'French']
Which are the official languages of Singapore?
['English', 'Malay', 'Mandarin', 'Tamil']
Which are the official languages of Sri Lanka?
['Sinhala', 'Tamil']
Which are the official languages of {subject_entity}?
"""
    elif relation == "StateSharesBorderState":
        prompt = f"""
What states border San Marino?
['San Leo', 'Acquaviva', 'Borgo Maggiore', 'Chiesanuova', 'Fiorentino']
What states border Wales?
['England']
What states border Liguria?
['Tuscany', 'Auvergne-Rhône-Alpes', 'Piedmont', 'Emilia-Romagna']
What states border Mecklenburg-Western Pomerania?
['Brandenburg', 'Pomeranian', 'Schleswig-Holstein', 'Lower Saxony']
What states border {subject_entity}?
"""
    elif relation == "RiverBasinsCountry":
        prompt = f"""
What countries does the river Drava cross?
['Hungary', 'Italy', 'Austria', 'Slovenia', 'Croatia']
What countries does the river Huai cross?
['China']
What countries does the river Paraná cross?
['Bolivia', 'Paraguay', 'Argentina', 'Brazil']
What countries does the river Oise cross?
['Belgium', 'France']
What countries does the river {subject_entity} cross?
"""
    elif relation == "ChemicalCompoundElement":
        prompt = f"""
What are all the atoms that make up the molecule Water?
['Hydrogen', 'Oxygen']
What are all the atoms that make up the molecule Bismuth subsalicylate?
['Bismuth']
What are all the atoms that make up the molecule Sodium Bicarbonate?
['Hydrogen', 'Oxygen', 'Sodium', 'Carbon']
What are all the atoms that make up the molecule Aspirin?
['Oxygen', 'Carbon', 'Hydrogen']
What are all the atoms that make up the molecule {subject_entity}?
"""
    elif relation == "PersonLanguage":
        prompt = f"""
Which languages does Aamir Khan speak?
['Hindi', 'English', 'Urdu']
Which languages does Pharrell Williams speak?
['English']
Which languages does Xabi Alonso speak?
['German', 'Basque', 'Spanish', 'English']
Which languages does Shakira speak?
['Catalan', 'English', 'Portuguese', 'Spanish', 'Italian', 'French']
Which languages does {subject_entity} speak?
"""
    elif relation == "PersonProfession":
        prompt = f"""
What is Danny DeVito's profession?
['Comedian', 'Film Director', 'Voice Actor', 'Actor', 'Film Producer', 'Film Actor', 'Dub Actor', 'Activist', 'Television Actor']
What is David Guetta's profession?
['DJ']
What is Gary Lineker's profession?
['Commentator', 'Association Football Player', 'Journalist', 'Broadcaster']
What is Gwyneth Paltrow's profession?
['Film Actor', 'Musician']
What is {subject_entity}'s profession?
"""
    elif relation == "PersonInstrument":
        prompt = f"""
Which instruments does Liam Gallagher play?
['Maraca', 'Guitar']
Which instruments does Jay Park play?
[]
Which instruments does Axl Rose play?
['Guitar', 'Piano', 'Pander', 'Bass']
Which instruments does Neil Young play?
['Guitar']
Which instruments does {subject_entity} play?
"""
    elif relation == "PersonEmployer":
        prompt = f"""
Where is or was Susan Wojcicki employed?
['Google']
Where is or was Steve Wozniak employed?
['Apple Inc', 'Hewlett-Packard', 'University of Technology Sydney', 'Atari']
Where is or was Yukio Hatoyama employed?
['Senshu University', 'Tokyo Institute of Technology']
Where is or was Yahtzee Croshaw employed?
['PC Gamer', 'Hyper', 'Escapist']
Where is or was {subject_entity} employed?
"""
    elif relation == "PersonPlaceOfDeath":
        prompt = f"""
What is the place of death of Barack Obama?
[]
What is the place of death of Ennio Morricone?
['Rome']
What is the place of death of Elon Musk?
[]
What is the place of death of Prince?
['Chanhassen']
What is the place of death of {subject_entity}?
"""
    elif relation == "PersonCauseOfDeath":
        prompt = f"""
How did André Leon Talley die?
['Infarction']
How did Angela Merkel die?
[]
How did Bob Saget die?
['Injury', 'Blunt Trauma']
How did Jamal Khashoggi die?
['Murder']
How did {subject_entity} die?
"""
    elif relation == "CompanyParentOrganization":
        prompt = f"""
What is the parent company of Microsoft?
[]
What is the parent company of Sony?
['Sony Group']
What is the parent company of Saab?
['Saab Group', 'Saab-Scania', 'Spyker N.V.', 'National Electric Vehicle Sweden', 'General Motors']
What is the parent company of Max Motors?
[]
What is the parent company of {subject_entity}?
"""
    else:
        raise ValueError(f"Unknown relation: {relation}")
    return prompt
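# For illustration, a call such as create_prompt("Germany", "CountryBordersWithCountry")
# (hypothetical subject) returns the four worked examples above followed by the
# unanswered question "Which countries neighbour Germany?", leaving the model to
# complete the Python-style list of answers.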
def load_prompt(subject_entity, relation, prompt_type='triple'):
    """
    Load a prompt template from the .txt files and fill in the subject entity.
    :param subject_entity: the subject entity to insert into the template
    :param relation: the relation whose template should be loaded
    :param prompt_type: 'triple' for the triple-based (a.k.a. simple) prompts,
                        'language' for the natural-language prompts,
                        anything else for the optimized prompts
    :return: the formatted prompt string
    """
    if prompt_type == 'triple':
        prompt_path = Path('data/prompts_triple_based')
    elif prompt_type == 'language':
        prompt_path = Path('data/prompts_natural_language')
    else:
        prompt_path = Path('data/prompts_optimized')
    prompt_path = Path.joinpath(prompt_path, f"{relation}.txt")
    with open(prompt_path, "r") as f:
        prompt = f.read()
    prompt = prompt.format(subject_entity=subject_entity)
    return prompt
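# The prompt .txt files are not part of this listing. Presumably each
# data/prompts_*/<Relation>.txt template mirrors the few-shot prompts in
# create_prompt() above, with a "{subject_entity}" placeholder in the final
# question, e.g. (hypothetical file contents for CountryBordersWithCountry.txt):
#
#     Which countries neighbour Dominica?
#     ['Venezuela']
#     ...
#     Which countries neighbour {subject_entity}?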
def probe_lm(input: Path, model: str, output: Path, prompt_type='triple', batch_size=20):
    ### for every subject entity in the input, we probe the LM using the prompts above
    # Load the input file
    logger.info(f"Loading the input file \"{input}\"...")
    input_rows = read_lm_kbc_jsonl(input)
    logger.info(f"Loaded {len(input_rows):,} rows.")
    # Trim the list and batch the entities
    input_rows = input_rows[:SAMPLE_SIZE]  # cap the run at SAMPLE_SIZE rows
    batches = [input_rows[x:x + batch_size] for x in range(0, len(input_rows), batch_size)]
    results = []
    for idx, batch in tqdm(enumerate(batches), total=len(batches)):
        # TODO: Generate examples in the prompt automatically (Thiviyan)
        # TODO: Rephrase prompt automatically (Dimitris)
        ### create a relation-specific prompt for each row in the batch
        logger.info("Creating prompts...")
        prompts = []
        for row in batch:
            prompts.append(load_prompt(row['SubjectEntity'], row['Relation'], prompt_type))
        ### probe the language model and collect its completions
        logger.info("Running the model...")
        predictions = gpt3(prompts, model=model)  # TODO Figure out what to do with probabilities
        ### clean and format the results
        for row, prediction in zip(batch, predictions):
            prediction['text'] = clean_up(prediction['text'])
            prediction['text'] = convert_nan(prediction['text'])
            result = {
                "SubjectEntity": row['SubjectEntity'],
                "Relation": row['Relation'],
                "Prompt": prediction['prompt'],
                "ObjectEntities": prediction['text']
            }
            results.append(result)
        # Sleep is needed because we make many API calls; we can make 60 calls every minute
        if idx % 5 == 0:
            time.sleep(2)
    ### save the results as one JSON record per line
    logger.info(f"Saving the results to \"{output}\"...")
    with open(output, "w") as f:
        for result in results:
            f.write(json.dumps(result) + "\n")
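# With illustrative values, one output record looks roughly like:
#
#     {"SubjectEntity": "Suriname", "Relation": "CountryOfficialLanguage",
#      "Prompt": "...", "ObjectEntities": "['Dutch']"}
#
# (the exact shape of ObjectEntities depends on what clean_up/convert_nan return).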
def main():
    parser = argparse.ArgumentParser(
        description="Probe a Language Model and Run the Baseline Method on Prompt Outputs"
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        default="data/dev.jsonl",
        help="input file containing the subject entities for each relation to probe the language model",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="predictions/gpt3.pred.jsonl",
        help="output file to store the baseline predictions",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="text-davinci-002",
        help="The models provided by OpenAI. \
             Options: 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001'"
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        default="triple",
        help="Triple-based vs natural-language vs optimized prompts. \
             Options: 'triple', 'language', 'optimized'",
    )
    args = parser.parse_args()
    print(args)
    probe_lm(Path(args.input), args.model, Path(args.output), args.prompt_type)
if __name__ == "__main__":
    main()
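# Example invocation (defaults shown explicitly):
#   python gpt3_baseline.py -i data/dev.jsonl -o predictions/gpt3.pred.jsonl \
#       -m text-davinci-002 --prompt_type triple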