From 33da1e2e4db3886232f742ee6b63e14977e619f7 Mon Sep 17 00:00:00 2001 From: Spyros Date: Sun, 18 Aug 2024 11:31:27 +0100 Subject: [PATCH] experiment to suggest mappings using chat completions --- application/database/db.py | 4 +- .../prompt_client/openai_prompt_client.py | 20 +++++++++ application/prompt_client/prompt_client.py | 20 ++++++++- .../prompt_client/vertex_prompt_client.py | 7 ++++ application/utils/spreadsheet_parsers.py | 41 ++++++++++++------- application/web/web_main.py | 30 ++++++++++++++ 6 files changed, 105 insertions(+), 17 deletions(-) diff --git a/application/database/db.py b/application/database/db.py index de2b24f10..71f777fe6 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -735,11 +735,11 @@ def __get_unlinked_cres(self) -> List[CRE]: .all() ) return cres + def get_all_nodes_and_cres(self): + return self.__get_all_nodes_and_cres() def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]: result = [] - nodes = [] - cres = [] node_ids = self.session.query(Node.id).all() for nid in node_ids: result.extend(self.get_nodes(db_id=nid[0])) diff --git a/application/prompt_client/openai_prompt_client.py b/application/prompt_client/openai_prompt_client.py index b2fdc6849..21fa4047f 100644 --- a/application/prompt_client/openai_prompt_client.py +++ b/application/prompt_client/openai_prompt_client.py @@ -1,3 +1,4 @@ +from typing import List import openai import logging @@ -58,3 +59,22 @@ def query_llm(self, raw_question: str) -> str: messages=messages, ) return response.choices[0].message["content"].strip() + + def create_mapping_completion(self, prompt:str, cre_id_and_name_in_export_format:List[str], standard_id_or_content :str) -> str: + messages = [ + { + "role": "system", + "content": f"You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. 
I will give you a standard to map to and a range of candidates and you will respond ONLY with the most relevant candidate.", + }, + { + "role": "user", + "content": f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`.", + }, + ] + openai.api_key = self.api_key + response = openai.chat.completions.create( + model="gpt-3.5-turbo", + messages=messages, + temperature=0.0, + ) + return response.choices[0].message.content.strip() diff --git a/application/prompt_client/prompt_client.py b/application/prompt_client/prompt_client.py index 500c5bcf9..916dce35b 100644 --- a/application/prompt_client/prompt_client.py +++ b/application/prompt_client/prompt_client.py @@ -363,7 +363,6 @@ def get_id_of_most_similar_cre_paginated( self, item_embedding: List[float], similarity_threshold: float = SIMILARITY_THRESHOLD, - refresh_embeddings: bool = False, ) -> Optional[Tuple[str, float]]: """this method is meant to be used when CRE runs in a web server with limited memory (e.g. 
firebase/heroku) instead of loading all our embeddings in memory we take the slower approach of paginating them @@ -518,3 +517,22 @@ def generate_text(self, prompt: str) -> Dict[str, str]: table = [closest_object] result = f"Answer: {answer}" return {"response": result, "table": table, "accurate": accurate} + + def get_id_of_most_similar_cre_using_chat( + self, item: defs.Document + ) -> Optional[str]: + # load all cres + content = "" + if item.hyperlink: + content = self.embeddings_instance.get_content(item.hyperlink) + else: + content = item.__repr__() + database = self.database + res = database.get_all_nodes_and_cres() + cres = [r for r in res if r.doctype == defs.Credoctypes.CRE.value] + cres_in_export_format = [f"{c.id}|{c.name}" for c in cres] + return self.ai_client.create_mapping_completion( + prompt="", + cre_id_and_name_in_export_format=cres_in_export_format, + standard_id_or_content=content, + ) diff --git a/application/prompt_client/vertex_prompt_client.py b/application/prompt_client/vertex_prompt_client.py index 34acff062..4add94184 100644 --- a/application/prompt_client/vertex_prompt_client.py +++ b/application/prompt_client/vertex_prompt_client.py @@ -109,3 +109,10 @@ def query_llm(self, raw_question: str) -> str: msg = f"Your task is to answer the following cybesrsecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant." response = self.chat.send_message(msg, **parameters) return response.text + + def create_mapping_completion(self, prompt:str, cre_id_and_name_in_export_format:List[str], standard_id_or_content :str) -> str: + parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS} + + msg= f"You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. 
I will give you a standard to map to and a range of candidates and you will respond ONLY with the most relevant candidate."\ + f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`.", + return self.chat.send_message(msg, **parameters).text diff --git a/application/utils/spreadsheet_parsers.py b/application/utils/spreadsheet_parsers.py index 33fda0982..b13ce6808 100644 --- a/application/utils/spreadsheet_parsers.py +++ b/application/utils/spreadsheet_parsers.py @@ -571,7 +571,7 @@ def parse_standards( def suggest_from_export_format( - lfile: List[Dict[str, Any]], database: db.Node_collection + lfile: List[Dict[str, Any]], database: db.Node_collection, use_llm: bool = False ) -> Dict[str, Any]: output: List[Dict[str, Any]] = [] for line in lfile: @@ -608,20 +608,33 @@ def suggest_from_export_format( ) # find nearest CRE for standards in line ph = prompt_client.PromptHandler(database=database, load_all_embeddings=False) + cre = None + if use_llm: + most_similar_id = ph.get_id_of_most_similar_cre_using_chat(item=standard) + if not most_similar_id: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue + c = most_similar_id.split(defs.ExportFormat.separator) + cres = database.get_CREs(name=c[1]) + if not cres: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue + cre = cres[0] + else: + most_similar_id,_ = ph.get_id_of_most_similar_cre_paginated(item_embedding= ph.generate_embeddings_for_document(standard)) + if not most_similar_id: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue + + cre = database.get_cre_by_db_id(most_similar_id) + if not cre: + 
logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue - most_similar_id, _ = ph.get_id_of_most_similar_cre_paginated( - item_embedding=ph.generate_embeddings_for_document(standard) - ) - if not most_similar_id: - logger.warning(f"Could not find a CRE for {standard.id}") - output.append(line) - continue - - cre = database.get_cre_by_db_id(most_similar_id) - if not cre: - logger.warning(f"Could not find a CRE for {standard.id}") - output.append(line) - continue line[f"CRE 0"] = f"{cre.id}{defs.ExportFormat.separator}{cre.name}" # add it to the line output.append(line) diff --git a/application/web/web_main.py b/application/web/web_main.py index 293ba2d21..659a8a504 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -793,6 +793,36 @@ def suggest_from_cre_csv() -> Any: ) +@app.route("/rest/v1/cre_csv/suggest_chat", methods=["POST"]) +def suggest_from_cre_csv_using_chat() -> Any: + """Given a csv file that follows the CRE import format but has missing fields, this function will return a csv file with the missing fields filled in with suggestions. + + Returns: + Any: the csv file with the missing fields filled in with suggestions + """ + database = db.Node_collection() + file = request.files.get("cre_csv") + + if file is None: + abort(400, "No file provided") + contents = file.read() + csv_read = csv.DictReader(contents.decode("utf-8").splitlines()) + response = spreadsheet_parsers.suggest_from_export_format( + list(csv_read), database=database, use_llm=True + ) + csvVal = write_csv(docs=response).getvalue().encode("utf-8") + + # Creating the byteIO object from the StringIO Object + mem = io.BytesIO() + mem.write(csvVal) + mem.seek(0) + + return send_file( + mem, + as_attachment=True, + download_name="CRE-Catalogue.csv", + mimetype="text/csv", + ) # /End Importing Handlers