diff --git a/application/database/db.py b/application/database/db.py index de2b24f1..71f777fe 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -735,11 +735,11 @@ def __get_unlinked_cres(self) -> List[CRE]: .all() ) return cres + def get_all_nodes_and_cres(self): + return self.__get_all_nodes_and_cres() def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]: result = [] - nodes = [] - cres = [] node_ids = self.session.query(Node.id).all() for nid in node_ids: result.extend(self.get_nodes(db_id=nid[0])) diff --git a/application/prompt_client/openai_prompt_client.py b/application/prompt_client/openai_prompt_client.py index b2fdc684..21fa4047 100644 --- a/application/prompt_client/openai_prompt_client.py +++ b/application/prompt_client/openai_prompt_client.py @@ -1,3 +1,4 @@ +from typing import List import openai import logging @@ -58,3 +59,22 @@ def query_llm(self, raw_question: str) -> str: messages=messages, ) return response.choices[0].message["content"].strip() + + def create_mapping_completion(self, prompt:str, cre_id_and_name_in_export_format:List[str], standard_id_or_content :str) -> str: + messages = [ + { + "role": "system", + "content": f"You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will respond ONLY with the most relevant candidate.", + }, + { + "role": "user", + "content": f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. 
Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks.", + }, + ] + openai.api_key = self.api_key + response = openai.chat.completions.create( + model="gpt-3.5-turbo", + messages=messages, + temperature=0.0, + ) + return response.choices[0].message.content.strip() diff --git a/application/prompt_client/prompt_client.py b/application/prompt_client/prompt_client.py index 500c5bcf..916dce35 100644 --- a/application/prompt_client/prompt_client.py +++ b/application/prompt_client/prompt_client.py @@ -363,7 +363,6 @@ def get_id_of_most_similar_cre_paginated( self, item_embedding: List[float], similarity_threshold: float = SIMILARITY_THRESHOLD, - refresh_embeddings: bool = False, ) -> Optional[Tuple[str, float]]: """this method is meant to be used when CRE runs in a web server with limited memory (e.g. firebase/heroku) instead of loading all our embeddings in memory we take the slower approach of paginating them @@ -518,3 +517,22 @@ def generate_text(self, prompt: str) -> Dict[str, str]: table = [closest_object] result = f"Answer: {answer}" return {"response": result, "table": table, "accurate": accurate} + + def get_id_of_most_similar_cre_using_chat( + self, item: defs.Document + ) -> Optional[str]: + # load all cres + content = "" + if item.hyperlink: + content = self.embeddings_instance.get_content(item.hyperlink) + else: + content = item.__repr__() + database = self.database + res = database.get_all_nodes_and_cres() + cres = [r for r in res if r.doctype == defs.Credoctypes.CRE.value] + cres_in_export_format = [f"{c.id}|{c.name}" for c in cres] + return self.ai_client.create_mapping_completion( + prompt="", + cre_id_and_name_in_export_format=cres_in_export_format, + standard_id_or_content=content, + ) diff --git a/application/prompt_client/vertex_prompt_client.py b/application/prompt_client/vertex_prompt_client.py index 34acff06..4add9418 100644 --- a/application/prompt_client/vertex_prompt_client.py +++ 
b/application/prompt_client/vertex_prompt_client.py @@ -109,3 +109,10 @@ def query_llm(self, raw_question: str) -> str: msg = f"Your task is to answer the following cybesrsecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant." response = self.chat.send_message(msg, **parameters) return response.text + + def create_mapping_completion(self, prompt:str, cre_id_and_name_in_export_format:List[str], standard_id_or_content :str) -> str: + parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS} + + msg= f"You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will respond ONLY with the most relevant candidate."\ + f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. 
Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks." + return self.chat.send_message(msg, **parameters).text diff --git a/application/utils/spreadsheet_parsers.py b/application/utils/spreadsheet_parsers.py index 33fda098..b13ce680 100644 --- a/application/utils/spreadsheet_parsers.py +++ b/application/utils/spreadsheet_parsers.py @@ -571,7 +571,7 @@ def parse_standards( def suggest_from_export_format( - lfile: List[Dict[str, Any]], database: db.Node_collection + lfile: List[Dict[str, Any]], database: db.Node_collection, use_llm: bool = False ) -> Dict[str, Any]: output: List[Dict[str, Any]] = [] for line in lfile: @@ -608,20 +608,33 @@ def suggest_from_export_format( ) # find nearest CRE for standards in line ph = prompt_client.PromptHandler(database=database, load_all_embeddings=False) + cre = None + if use_llm: + most_similar_id = ph.get_id_of_most_similar_cre_using_chat(item=standard) + if not most_similar_id: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue + c = most_similar_id.split(defs.ExportFormat.separator) + cres = database.get_CREs(name=c[1]) + if not cres: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue + cre = cres[0] + else: + most_similar_id,_ = ph.get_id_of_most_similar_cre_paginated(item_embedding= ph.generate_embeddings_for_document(standard)) + if not most_similar_id: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue + + cre = database.get_cre_by_db_id(most_similar_id) + if not cre: + logger.warning(f"Could not find a CRE for {standard.id}") + output.append(line) + continue - most_similar_id, _ = ph.get_id_of_most_similar_cre_paginated( - item_embedding=ph.generate_embeddings_for_document(standard) - ) - if not most_similar_id: - logger.warning(f"Could not find a CRE for {standard.id}") - output.append(line) - continue - - cre = 
database.get_cre_by_db_id(most_similar_id) - if not cre: - logger.warning(f"Could not find a CRE for {standard.id}") - output.append(line) - continue line[f"CRE 0"] = f"{cre.id}{defs.ExportFormat.separator}{cre.name}" # add it to the line output.append(line) diff --git a/application/web/web_main.py b/application/web/web_main.py index 293ba2d2..659a8a50 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -793,6 +793,36 @@ def suggest_from_cre_csv() -> Any: ) +@app.route("/rest/v1/cre_csv/suggest_chat", methods=["POST"]) +def suggest_from_cre_csv_using_chat() -> Any: + """Given a csv file that follows the CRE import format but has missing fields, this function will return a csv file with the missing fields filled in with suggestions. + + Returns: + Any: the csv file with the missing fields filled in with suggestions + """ + database = db.Node_collection() + file = request.files.get("cre_csv") + + if file is None: + abort(400, "No file provided") + contents = file.read() + csv_read = csv.DictReader(contents.decode("utf-8").splitlines()) + response = spreadsheet_parsers.suggest_from_export_format( + list(csv_read), database=database, use_llm=True + ) + csvVal = write_csv(docs=response).getvalue().encode("utf-8") + + # Creating the byteIO object from the StringIO Object + mem = io.BytesIO() + mem.write(csvVal) + mem.seek(0) + + return send_file( + mem, + as_attachment=True, + download_name="CRE-Catalogue.csv", + mimetype="text/csv", + ) # /End Importing Handlers