Skip to content

Commit

Permalink
Cleanup-at-gpt-helper (#380)
Browse files Browse the repository at this point in the history
* Clarify types

* Api restored
  • Loading branch information
SavenkovIgor authored Jul 1, 2024
1 parent de35d60 commit 68db951
Showing 1 changed file with 54 additions and 30 deletions.
84 changes: 54 additions & 30 deletions tools/gpt-helper/term_helper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,44 @@
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 3,
"id": "413b2b8e",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'Hello! How can I assist you today?'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"import os\n",
"import json\n",
"\n",
"from typing import List, Dict, Any\n",
"from pathlib import Path\n",
"from dataclasses import dataclass\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"_ = load_dotenv(find_dotenv())\n",
"\n",
"openai.api_key = os.getenv('OPENAI_API_KEY')\n",
"\n",
"# print(openai.Engine.list())\n",
"\n",
"global_path = \"../../lib/StaticDataStorage/data/Global.json\"\n",
"\n",
"global_path: Path = Path(\"../../lib/StaticDataStorage/data/Global.json\").resolve()\n",
"\n",
"def load_global_static_data() -> dict:\n",
" with open(global_path, \"r\") as f:\n",
" globalJson = json.load(f)\n",
" \n",
" return globalJson\n",
"def load_global_static_data() -> Dict[str, Any]:\n",
" return json.loads(global_path.read_text())\n",
"\n",
"def get_terms(globalJson: dict) -> list:\n",
" ret = []\n",
"def get_terms(globalJson: Dict[str, Any]) -> List[str]:\n",
" ret: List[str] = []\n",
"\n",
" for item in globalJson['terms']:\n",
" termDef = item['termDef']\n",
Expand All @@ -42,8 +53,8 @@
" terms = get_terms(globalJson)\n",
" return \", \".join(terms)\n",
"\n",
"def print_lines(text, max_line_length=120):\n",
" lines = []\n",
"def print_lines(text: str, max_line_length: int = 120):\n",
" lines: List[str] = []\n",
" line = \"\"\n",
" for word in text.split():\n",
" if len(line + word) > max_line_length:\n",
Expand All @@ -59,35 +70,48 @@
" model = \"\"\n",
" temp = 0\n",
"\n",
" def __init__(self, model=\"gpt-3.5-turbo\", temp=0):\n",
" def __init__(self, model: str = \"gpt-4o\", temp: float = 0):\n",
" self.model = model\n",
" self.temp = temp\n",
"\n",
"class Chat:\n",
" messages: list = []\n",
"@dataclass\n",
"class Message:\n",
" role: str\n",
" content: str\n",
"\n",
" def to_dict(self) -> Dict[str, str]:\n",
" return {\"role\": self.role, \"content\": self.content}\n",
"\n",
"class TermChat:\n",
" messages: List[Message] = []\n",
" engine: ChatEngine\n",
"\n",
" def __init__(self, engine, system_message, user_messages = []):\n",
" def __init__(self, engine: ChatEngine, system_message: str, user_messages: List[str] = []):\n",
" self.engine = engine\n",
" self.messages.append({\"role\": \"system\", \"content\": system_message})\n",
" self.messages.extend([{\"role\": \"user\", \"content\": message} for message in user_messages])\n",
" self.messages.append(Message(\"system\", system_message))\n",
" self.messages.extend([Message(\"user\", message) for message in user_messages])\n",
"\n",
" def get_answer(self, user_message = \"\"):\n",
" def get_answer(self, user_message: str = \"\") -> str | None:\n",
" if user_message != \"\":\n",
" self.messages.append({\"role\": \"user\", \"content\": user_message})\n",
" response = openai.ChatCompletion.create(\n",
" self.messages.append(Message(\"user\", user_message))\n",
"\n",
" chat_messages = [message.to_dict() for message in self.messages]\n",
" response = openai.chat.completions.create(\n",
" model=self.engine.model,\n",
" messages=self.messages,\n",
" messages=chat_messages,\n",
" temperature=self.engine.temp,\n",
" )\n",
" answer = response.choices[0].message[\"content\"]\n",
" answer = response.choices[0].message.content\n",
" self.messages.append({\"role\": \"assistant\", \"content\": answer})\n",
" return answer\n",
" \n",
"\n",
" def clear_messages(self):\n",
" first_message = self.messages[0]\n",
" self.messages = []\n",
" self.messages.append(first_message)\n"
" self.messages.append(first_message)\n",
"\n",
"\n",
"TermChat(ChatEngine(), \"\").get_answer()"
]
},
{
Expand All @@ -104,7 +128,7 @@
"\n",
"def predict_next_term():\n",
" engine = ChatEngine(temp=0.5)\n",
" chat = Chat(engine, continue_term_prompt)\n",
" chat = TermChat(engine, continue_term_prompt)\n",
" terms_str = get_terms_string()\n",
" answer = chat.get_answer(terms_str)\n",
" print_lines(answer)\n",
Expand Down Expand Up @@ -152,7 +176,7 @@
" terms_str = get_terms_string()\n",
"\n",
" engine = ChatEngine(temp=0.2)\n",
" chat = Chat(engine, give_definition_prompt.format(terms_str, new_term))\n",
" chat = TermChat(engine, give_definition_prompt.format(terms_str, new_term))\n",
"\n",
" answer = chat.get_answer()\n",
" print(answer)\n",
Expand All @@ -179,7 +203,7 @@
"\n",
" with open(global_path, \"w\") as f:\n",
" json.dump(globalJson, f, indent=4)\n",
" \n",
"\n",
"add_term(\"Symmetry\", f\"{{system}} {{property}} that remains unchanged after a certain {{transformation}}\", \"phys\")"
]
}
Expand All @@ -200,7 +224,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 68db951

Please sign in to comment.