Skip to content

Commit

Permalink
better fine tune errors (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
earonesty authored Nov 9, 2023
1 parent bd5893b commit 180af0f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
31 changes: 22 additions & 9 deletions ai_worker/fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,26 @@ def temp_file(self, name, wipe=False):
def massage_line(self, ln, job):
# toss our role for now, for some reason it didn't work
# todo: check for role support in template
j = json.loads(ln)

if pr := j.get("prompt"):
# todo: use templates properly to massage data for instruct vs chat
j = json.loads(ln)
cm = j["completion"]
j = {"messages": [{"role": "user", "content": pr}, {"role": "assistant", "content": cm}]}
if not ln.strip():
# skip blank
return None

if "mistral" in job["model"].lower():
try:
j = json.loads(ln)
j["messages"] = [m for m in j["messages"] if m["role"] != "system"]
ln = json.dumps(j) + "\n"
if pr := j.get("prompt"):
# todo: use templates properly to massage data for instruct vs chat
j = json.loads(ln)
cm = j["completion"]
j = {"messages": [{"role": "user", "content": pr}, {"role": "assistant", "content": cm}]}

if "mistral" in job["model"].lower():
j = json.loads(ln)
j["messages"] = [m for m in j["messages"] if m["role"] != "system"]
ln = json.dumps(j) + "\n"
except (KeyError, ValueError, json.JSONDecodeError) as ex:
log.error("fine tune: %s error %s with training line: %s", job.get("id"), repr(ex), ln)
assert False, "fine tune: invalid training data: '%s': %s " % (ln, repr(ex))

return ln

Expand All @@ -69,6 +77,8 @@ def massage_fine_tune(self, file, job):
ln = inp.readline(MAX_CONTEXT)
while ln:
ln = self.massage_line(ln, job)
if not ln:
continue
cnt += 1
if ec and (random.random() > training_split_pct or tc <= ec):
tc += 1
Expand All @@ -77,6 +87,9 @@ def massage_fine_tune(self, file, job):
ec += 1
ef.write(ln)
ln = inp.readline(MAX_CONTEXT)

assert tc != 0 and ec != 0, "not enough valid training data"

return train_file, eval_file

async def fine_tune(self, job):
Expand Down
3 changes: 2 additions & 1 deletion ai_worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ async def run_one(self):
log.info("done %s (%s secs)", model, en - st)
except (
websockets.ConnectionClosedError, websockets.ConnectionClosed, websockets.exceptions.ConnectionClosedError):
log.error("disconnected while running request: %s", req_str)
if req_str:
log.error("disconnected while running request: %s", req_str)
if event:
log.error("was sending event: %s", event)
except Exception as ex:
Expand Down

0 comments on commit 180af0f

Please sign in to comment.