Skip to content

Commit

Permalink
feat: Added GPT Rating for prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanpooniwala committed Apr 9, 2023
1 parent 049da89 commit d862a91
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
40 changes: 33 additions & 7 deletions server/api/gpt_rating.py → server/commons/gpt_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,35 @@
api_key = OPENAI_API_KEY
openai.api_key = api_key

SYSTEM_PROMPT = """You are an AI which rates a conversation betweeen User and a Bot. You rate the reply of the Bot on a scale of 1 to 10.
The conversation is rated on the following criteria:
1. Relevance of the conversation
2. Accuracy of the answer
3. Grammar and spelling
4. If the user's task is completed
You can only reply with a number between 1 to 10.
You are very strict and will only give a 10 if the bot's reply is perfect.
If there is even a single mistake or the a, you will give a 1.
If the bot's reply is not relevant, you will give a 1.
If user's task is not completed, you will give a 1 else you will give a 10.
"""

RATING_PROMPT = """Rate the following message
```
{message}
```
"""

# TODOs: Handle max tokens reached
# TODOs: Handle no text in response
# Function to send a rate the chatbot's response and return the rating
def rate_the_conversation(rating_log):

response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=rating_log,
max_tokens=4000,
max_tokens=200,
stop=None,
temperature=0.7,
)
Expand All @@ -26,20 +46,19 @@ def rate_the_conversation(rating_log):


def ask_for_rating(message):
print("Rating -> ", message)
message_log = [
{"role": "system", "content": "You are talking to a feedback chatbot, rate the conversation between 1 to 10."},
({"role": "user", "content": "message"}),
{"role": "system", "content": SYSTEM_PROMPT},
({"role": "user", "content": RATING_PROMPT.format(message=message)}),
]
response = rate_the_conversation(message_log)
score = process_rating_response(response)
print("Score -> ", score)
return score


def process_rating_response(response):
score = extract_number_from_text(response)
print("*" * 30)
print("score")
print(score)
return score


Expand All @@ -60,3 +79,10 @@ def calculate_ratings_metrics_score(metrics_scores):
total_score += score
# TODOs: Check for the range and send score accordingly
return total_score


if __name__ == "__main__":
# low score
print("Low score", ask_for_rating("User: This is a test message\nBot: Earth revoles around the sun"))
# high score
print("High score", ask_for_rating("User: This is a test message\nBot: This is a test reply"))
8 changes: 5 additions & 3 deletions server/commons/langflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from database_utils.intermediate_step import insert_intermediate_steps
from database_utils.prompt import create_prompt
from schemas.prompt_schema import Prompt
from server.api.gpt_rating import ask_for_rating
from commons.gpt_rating import ask_for_rating


def format_intermediate_steps(intermediate_steps):
Expand Down Expand Up @@ -84,7 +84,7 @@ def process_graph(message, chat_history, data_graph):
# We have to save it here because if the
# memory is updated we need to keep the new values
print("Saving langchain object to cache")
save_cache(computed_hash, langchain_object, is_first_message)
# save_cache(computed_hash, langchain_object, is_first_message)
print("Saved langchain object to cache")
return {"result": str(result), "thought": thought}

Expand All @@ -102,7 +102,9 @@ def get_prompt(chatbot_id: int, prompt: Prompt, db: Session):
prompt_row.response = result["result"]
prompt_row.time_taken = float(time.time() - start) # type: ignore
insert_intermediate_steps(db, prompt_row.id, result["thought"]) # type: ignore
prompt_row.gpt_rating = ask_for_rating() # type: ignore

message = f"User: {prompt.new_message}\nBot: {result['result']}"
prompt_row.gpt_rating = ask_for_rating(message) # type: ignore
db.commit()

result["prompt_id"] = prompt_row.id
Expand Down
2 changes: 1 addition & 1 deletion server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Prompt(Base):
id = Column(Integer, primary_key=True)
chatbot_id = Column(Integer, ForeignKey("chatbot.id"), nullable=False)
input_prompt = Column(Text, nullable=False)
gpt_rating = Column(Enum(PromptRating), nullable=True)
gpt_rating = Column(String(5), nullable=True)
user_rating = Column(Enum(PromptRating), nullable=True)
chatbot_user_rating = Column(Enum(PromptRating), nullable=True)
response = Column(Text, nullable=False)
Expand Down

0 comments on commit d862a91

Please sign in to comment.