Skip to content

Commit

Permalink
Merge pull request #32 from homanp/validate-openai-training-data
Browse files Browse the repository at this point in the history
Add functionality for validating openai training data
  • Loading branch information
homanp authored Oct 22, 2023
2 parents 5e3bc2f + 94d963b commit 206bd20
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
4 changes: 2 additions & 2 deletions nagato/service/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Union

import openai
import requests
from llama_index import Document

from nagato.service.embedding import EmbeddingService
Expand Down Expand Up @@ -43,6 +43,6 @@ def create_finetuned_model(
training_file=formatted_training_file, webhook_url=webhook_url
)
if provider == "OPENAI":
finetune = openai.FineTune.retrieve(id=finetune.get("id"))
requests.post(webhook_url, json=finetune)
finetunning_service.cleanup(training_file=finetune.get("training_file"))
return finetune
49 changes: 45 additions & 4 deletions nagato/service/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def __init__(

def generate_prompt_and_completion(self, node):
prompt = generate_qa_pair_prompt(
context=node.text, num_of_qa_paris=10, format=GPT_DATA_FORMAT
context=node.text,
num_of_qa_pairs=self.num_questions_per_chunk,
format=GPT_DATA_FORMAT,
)
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
Expand All @@ -124,9 +126,40 @@ def generate_prompt_and_completion(self, node):
return completion.choices[0].message.content

def validate_dataset(self, training_file: str) -> str:
pass
valid_lines = []
with open(training_file, "r") as file:
lines = file.readlines()
total_lines = len(lines)
progress_bar = tqdm(
total=total_lines, desc="🟠 Validating dataset", file=sys.stdout
)
for line in lines:
try:
data = json.loads(line)
if "messages" not in data:
continue
messages = data["messages"]
if len(messages) != 3:
continue
if not (
messages[0]["role"] == "system"
and messages[1]["role"] == "user"
and messages[2]["role"] == "assistant"
):
continue
valid_lines.append(line)
except json.JSONDecodeError:
continue
finally:
progress_bar.update(1)
progress_bar.set_description("🟢 Validating dataset")
progress_bar.close()

def finetune(self, training_file: str, _webhook_url: str = None) -> Dict:
with open(training_file, "w") as file:
file.writelines(valid_lines)
return training_file

def finetune(self, training_file: str, webhook_url: str = None) -> Dict:
file = openai.File.create(file=open(training_file, "rb"), purpose="fine-tune")
finetune = openai.FineTuningJob.create(
training_file=file.get("id"), model=OPENAI_MODELS[self.base_model]
Expand Down Expand Up @@ -169,7 +202,7 @@ def validate_dataset(self, training_file: str) -> str:
total_lines = len(lines)
progress_bar = tqdm(
total=total_lines,
desc="Validating lines",
desc="🟠 Validating training data",
file=sys.stdout,
)
for i, line in enumerate(lines, start=1):
Expand All @@ -180,6 +213,7 @@ def validate_dataset(self, training_file: str) -> str:
except json.JSONDecodeError:
pass
progress_bar.update(1)
progress_bar.set_description("🟢 Validating training data")
progress_bar.close()

with open(training_file, "w") as f:
Expand All @@ -201,6 +235,13 @@ def finetune(self, training_file: str, webhook_url: str = None) -> Dict:
destination="homanp/test",
webhook=webhook_url,
)
progress_bar = tqdm(
total=1,
desc="🟢 Started model training",
file=sys.stdout,
)
progress_bar.update(1)
progress_bar.close()
return {"id": training.id, "training_file": training_file}


Expand Down
6 changes: 3 additions & 3 deletions test/mytest.py → test/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ def main():
result = create_finetuned_model(
url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf",
type="PDF",
base_model="LLAMA2_7B_CHAT",
provider="REPLICATE",
base_model="GPT_35_TURBO",
provider="OPENAI",
webhook_url="https://webhook.site/ebe803b9-1e34-4b20-a6ca-d06356961cd1",
)
print(result)
print(f"🤖 MODEL: {result}")


# Run the function
Expand Down

0 comments on commit 206bd20

Please sign in to comment.