"""
The OpenAI backend. Uses large language models for black box optimization.
"""

import os
from ast import literal_eval

import openai

from bbopt import constants
from bbopt.util import printerr, mean_abs_dev, mean
from bbopt.params import param_processor
from bbopt.backends.util import StandardBackend, sorted_params


# Utilities:

def get_prompt(params, data_points, losses, hoped_for_loss) =
    """Get the base OpenAI API prompt."""
    r'''# black box function to be minimized
def f({func_params}) -> float:
    """
    parameters:
{docstring}
    returns:
        float: the loss
    """
    return black_box_loss({names})
# experimentally observed data
# new experiments MUST stay within the bounds, SHOULD fully explore the bounds, and SHOULD converge to minimum
# bounds: f({bounds})
{values}{hoped_for_loss} == f('''.format(
        func_params=", ".join(
            "{name}: {type}".format(
                name=name,
                type=(
                    "int" if func == "randrange"
                    else type(args[0][0]).__name__ if func == "choice" and all_equal(map(type, args[0]))
                    else "typing.Any" if func == "choice"
                    else "float"
                ),
            )
            for name, (func, args, _) in params.items()
        ),
        docstring="\n".join(
            "        {name}: in {func}({args})".format(
                name=name,
                func="range" if func == "randrange" else func,
                args=", ".join((args[:2] if func == "randrange" and args[-1] == 1 else args) |> map$(repr)),
            )
            for name, (func, args, _) in params.items()
        ),
        names=", ".join(params),
        bounds=", ".join(
            "{name}: {func}({args})".format(
                name=name,
                func="range" if func == "randrange" else func,
                args=", ".join((args[:2] if func == "randrange" and args[-1] == 1 else args) |> map$(repr)),
            )
            for name, (func, args, _) in params.items()
        ),
        values="".join(
            "{loss} == f({named_args})\n".format(
                named_args=", ".join(params |> map$(name -> f"{name}={point[name]!r}")),
                loss=loss,
            )
            for point, loss in zip(data_points, losses)
        ),
        hoped_for_loss=int(hoped_for_loss) if int(hoped_for_loss) == hoped_for_loss else hoped_for_loss,
    )
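
# As a rough illustration (hypothetical values, not from the original source):
# for params {"x": ("randrange", (1, 10, 1), {})}, one observed point {"x": 3}
# with loss 4.0, and hoped_for_loss 2, the prompt produced above looks like:
#
#     # black box function to be minimized
#     def f(x: int) -> float:
#         """
#         parameters:
#             x: in range(1, 10)
#         returns:
#             float: the loss
#         """
#         return black_box_loss(x)
#     # experimentally observed data
#     # new experiments MUST stay within the bounds, SHOULD fully explore the bounds, and SHOULD converge to minimum
#     # bounds: f(x: range(1, 10))
#     4.0 == f(x=3)
#     2 == f(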

def is_int(x) = int(x) == x


def typed(average):
    """Get a typed version of average."""
    def typed_average(xs) =
        int(round(mu)) if all(xs |> map$(is_int)) else mu where:
            xs = tuple(xs)
            mu = average(xs)
    return typed_average
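
# For example (illustrative): (typed mean)([1, 2, 4]) == 2, since every input
# is int-valued, while (typed mean)([1.0, 2.5]) == 1.75, since 2.5 is not.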

def get_loss_eps(typical_loss):
    """Get a reasonably sized hoped-for loss improvement."""
    a, b = float(abs(typical_loss)).as_integer_ratio()
    little_a = int("1" * len(str(a)))
    return little_a / b
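
# For example (illustrative): get_loss_eps(0.75) == 0.25, since
# (0.75).as_integer_ratio() == (3, 4) and int("1" * len("3")) / 4 == 0.25.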

# Backend:

class OpenAIBackend(StandardBackend):
    """OpenAI large language model BBopt backend."""
    backend_name = "openai"

    implemented_funcs = (
        "randrange",
        "uniform",
        "normalvariate",
        "choice",
    )

    # maximum allowed prompt length; shrunk when the API reports a max context length error
    max_prompt_len = float("inf")

    def setup_backend(self, params, engine=None, temperature=None, max_retries=None, api_key=None, debug=False):
        self.params = sorted_params(params)

        self.engine = engine ?? constants.openai_default_engine
        self.temp = temperature ?? constants.openai_default_temp
        self.max_retries = max_retries ?? constants.openai_default_max_retries
        openai.api_key = api_key ?? os.getenv("OPENAI_API_KEY")
        self.debug = debug

        self.data_points = []
        self.losses = []
        self.cached_values = ()

    def tell_data(self, new_data, new_losses):
        for point, loss in zip(new_data, new_losses):
            # avoid (point, loss) duplicates since they cause GPT to repeat itself
            try:
                existing_index = self.data_points.index(point)
            except ValueError:
                existing_index = None
            if existing_index is None or self.losses[existing_index] != loss:
                self.data_points.append(point)
                self.losses.append(loss)

    def get_prompt(self) =
        """Get the OpenAI API prompt to use."""
        (
            get_prompt(
                self.params,
                self.data_points,
                self.losses,
                # all terms are structured to ensure that they don't use too many more sig figs than the loss
                hoped_for_loss=min(self.losses) - (typed mean_abs_dev)(self.losses) - get_loss_eps((typed mean)(self.losses)),
            )
            + ", ".join(zip(self.params, self.cached_values) |> starmap$((name, val) -> f"{name}={val!r}"))
            # only "," not ", " since the prompt shouldn't end in a space
            + ("," if self.cached_values else "")
        )
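
    # For example (illustrative): with params ("x", "y") and cached_values (3,),
    # the prompt ends in "f(x=3," so the model only needs to complete "y=...".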

    @property
    def expected_params(self) =
        """The parameters that still need to be generated by the completion."""
        self.params.keys()$[len(self.cached_values):]

    def get_completion_len(self) =
        """Get the maximum number of characters in a completion."""
        # used as max_tokens, which overshoots safely, since a completion never
        # needs more tokens than characters
        max(
            len(", ".join(self.expected_params |> map$(name -> f"{name}={point[name]!r}")))
            for point in self.data_points
        ) + 1

    def to_python(self, completion):
        """Convert a completion to Python code as best as possible."""
        completion = completion.strip(",")
        for repl, to in (
            ("\u2212", "-"),  # minus sign
            ("\u2018", "'"),  # left single quote
            ("\u2019", "'"),  # right single quote
            ("\u201c", '"'),  # left double quote
            ("\u201d", '"'),  # right double quote
            ("\u221e", 'float("inf")'),  # infinity
        ) :: (
            (f"{name}:", f"{name}=") for name in self.params
        ):
            completion = completion.replace(repl, to)

        # treat as dictionary
        if all(f"{name}=" in completion for name in self.expected_params):
            for name in self.expected_params:
                completion = completion.replace(f"{name}=", f'"{name}":')
            completion = "{" + completion + "}"
            if self.debug:
                print("== PYTHON ==\n" + completion)
            valdict = literal_eval(completion)
            return tuple(valdict[name] for name in self.expected_params)

        # treat as tuple
        else:
            for name in self.params:
                completion = completion.replace(f"{name}=", "")
            completion = "(" + completion + ",)"
            if self.debug:
                print("== PYTHON ==\n" + completion)
            return literal_eval(completion)
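
    # For example (illustrative): with expected_params ("x", "y"), the completion
    # 'x=1, y="a"' becomes '{"x":1, "y":"a"}' and is parsed as a dict, while the
    # completion '1, "a"' becomes '(1, "a",)' and is parsed as a tuple.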

    def get_next_values(self):
        # generate prompt
        prompt = self.get_prompt()
        while len(prompt) > self.max_prompt_len:
            # drop the oldest data points until the prompt fits
            self.data_points.pop(0)
            self.losses.pop(0)
            prompt = self.get_prompt()
        if self.debug:
            print("\n== PROMPT ==\n" + prompt)

        # query api
        try:
            response = openai.Completion.create(
                engine=self.engine,
                prompt=prompt,
                temperature=self.temp,
                max_tokens=self.get_completion_len(),
            )
        except openai.error.InvalidRequestError as api_err:
            if self.debug:
                print("== END ==")
            if not str(api_err).startswith(constants.openai_max_context_err_prefix):
                raise
            if self.debug:
                print(f"ERROR: got max context length error with {self.max_prompt_len=}")
            if self.max_prompt_len == float("inf"):
                self.max_prompt_len = len(prompt.rsplit("\n", 1)[0])
            else:
                self.max_prompt_len -= self.get_completion_len()
            return self.retry_get_values()

        # parse response
        try:
            completion = response["choices"][0]["text"]
            if self.debug:
                print("== COMPLETION ==\n" + completion)
            valvec = self.to_python(completion.split(")", 1)[0].strip())
        except BaseException as parse_err:
            if self.debug:
                print("== END ==")
                print(f"ERROR: {parse_err} for API response:\n{response}")
            return self.retry_get_values(temp=(self.temp + constants.openai_default_temp) / 2)
        if self.debug:
            print("== END ==")

        # keep only the longest prefix of values that are in the support of their parameters
        legal_values = (
            self.cached_values + valvec
            |> .[:len(self.params)]
            |> zip$(?, self.params.items())
            |> takewhile$(def ((val, (name, (func, args, kwargs)))) ->
                param_processor.in_support(name, val, func, *args, **kwargs))
            |> map$(.[0])
            |> tuple
        )
        if len(legal_values) < len(self.params):
            if self.debug:
                if len(valvec) < len(self.params) - len(self.cached_values):
                    print(f"ERROR: insufficient values (got {len(valvec)}; expected {len(self.params) - len(self.cached_values)})")
                else:
                    print(f"ERROR: got illegal values: {valvec!r}")
            return self.retry_get_values(temp=(self.temp + constants.openai_default_temp) / 2, cached_values=legal_values)

        # return values
        values = {name: val for name, val in zip(self.params, legal_values)}
        if values in self.data_points:
            if self.debug:
                print(f"ERROR: got duplicate point: {legal_values!r}")
            return self.retry_get_values(temp=self.temp + (constants.openai_max_temp - self.temp) / 2, cached_values=())
        return values

    def retry_get_values(self, temp=None, cached_values=None):
        """Retry get_next_values with the given settings, restoring the old state afterwards."""
        if not self.max_retries:
            if self.debug:
                print()
            printerr(f"BBopt Warning: Maximum number of OpenAI API retries exceeded for {self.engine} on:\n== PROMPT ==\n{self.get_prompt()}\n== END ==")
            return {}  # return empty values so that the fallback random backend will be used instead
        if self.debug:
            if temp is None:
                print(f"RETRYING with: {self.max_prompt_len=}")
            else:
                print(f"RETRYING with new temperature: {self.temp} -> {temp}")
        old_retries, self.max_retries = self.max_retries, self.max_retries - 1
        if temp is not None:
            old_temp, self.temp = self.temp, temp
        if cached_values is not None:
            if self.debug:
                print(f"CACHING values: {self.cached_values} -> {cached_values}")
            self.cached_values = cached_values
        try:
            return self.get_next_values()
        finally:
            self.max_retries = old_retries
            if temp is not None:
                self.temp = old_temp
            if cached_values is not None:
                self.cached_values = ()

# Registered names:

OpenAIBackend.register()
OpenAIBackend.register_alg("openai")
OpenAIBackend.register_alg("openai_debug", debug=True)
OpenAIBackend.register_alg("openai_davinci", engine=constants.openai_davinci_engine)