Skip to content

Commit

Permalink
experiment to suggest mappings using chat completions
Browse files Browse the repository at this point in the history
  • Loading branch information
northdpole committed Aug 18, 2024
1 parent a287d1b commit 33da1e2
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 17 deletions.
4 changes: 2 additions & 2 deletions application/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,11 @@ def __get_unlinked_cres(self) -> List[CRE]:
.all()
)
return cres
def get_all_nodes_and_cres(self) -> List[cre_defs.Document]:
    # Public wrapper around the name-mangled private implementation below;
    # returns every Node and CRE stored in the database as Document objects.
    return self.__get_all_nodes_and_cres()

def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]:
result = []
nodes = []
cres = []
node_ids = self.session.query(Node.id).all()
for nid in node_ids:
result.extend(self.get_nodes(db_id=nid[0]))
Expand Down
20 changes: 20 additions & 0 deletions application/prompt_client/openai_prompt_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import openai
import logging

Expand Down Expand Up @@ -58,3 +59,22 @@ def query_llm(self, raw_question: str) -> str:
messages=messages,
)
return response.choices[0].message["content"].strip()

def create_mapping_completion(
    self,
    prompt: str,
    cre_id_and_name_in_export_format: List[str],
    standard_id_or_content: str,
) -> str:
    """Ask the chat model to pick the single most relevant CRE candidate.

    Args:
        prompt: unused; kept for signature compatibility with the other
            prompt clients that implement this method.
        cre_id_and_name_in_export_format: candidate CREs, one
            "<id><separator><name>" string per candidate.
        standard_id_or_content: the id or raw content of the standard to map.

    Returns:
        The model's reply stripped of surrounding whitespace; expected to be
        the chosen candidate delimited with backticks.
    """
    messages = [
        {
            "role": "system",
            # plain string: the original used an f-string with no placeholders
            "content": "You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will response ONLY with the most relevant candidate.",
        },
        {
            "role": "user",
            "content": f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`.",
        },
    ]
    openai.api_key = self.api_key
    # temperature 0.0: candidate selection should be deterministic
    response = openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.0,
    )
    return response.choices[0].message.content.strip()
20 changes: 19 additions & 1 deletion application/prompt_client/prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ def get_id_of_most_similar_cre_paginated(
self,
item_embedding: List[float],
similarity_threshold: float = SIMILARITY_THRESHOLD,
refresh_embeddings: bool = False,
) -> Optional[Tuple[str, float]]:
"""this method is meant to be used when CRE runs in a web server with limited memory (e.g. firebase/heroku)
instead of loading all our embeddings in memory we take the slower approach of paginating them
Expand Down Expand Up @@ -518,3 +517,22 @@ def generate_text(self, prompt: str) -> Dict[str, str]:
table = [closest_object]
result = f"Answer: {answer}"
return {"response": result, "table": table, "accurate": accurate}

def get_id_of_most_similar_cre_using_chat(
    self, item: defs.Document
) -> Optional[str]:
    """Use the chat LLM to pick the CRE most relevant to *item*.

    Loads all CREs from the database, formats each as
    "<id><separator><name>", and asks the AI client to select the best
    match for the standard's content.

    Args:
        item: the standard document to find a matching CRE for.

    Returns:
        The model's answer (the chosen candidate string) or None if the
        client returned nothing.
    """
    # Prefer the linked page's content when a hyperlink is available,
    # otherwise fall back to the document's textual representation.
    if item.hyperlink:
        content = self.embeddings_instance.get_content(item.hyperlink)
    else:
        content = repr(item)

    all_documents = self.database.get_all_nodes_and_cres()
    cres = [
        doc
        for doc in all_documents
        if doc.doctype == defs.Credoctypes.CRE.value
    ]
    # Build candidates with the canonical export-format separator so the
    # caller (suggest_from_export_format) can split the answer with
    # defs.ExportFormat.separator instead of a hard-coded "|".
    candidates = [
        f"{c.id}{defs.ExportFormat.separator}{c.name}" for c in cres
    ]
    return self.ai_client.create_mapping_completion(
        prompt="",
        cre_id_and_name_in_export_format=candidates,
        standard_id_or_content=content,
    )
7 changes: 7 additions & 0 deletions application/prompt_client/vertex_prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,10 @@ def query_llm(self, raw_question: str) -> str:
msg = f"Your task is to answer the following cybesrsecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant."
response = self.chat.send_message(msg, **parameters)
return response.text

def create_mapping_completion(
    self,
    prompt: str,
    cre_id_and_name_in_export_format: List[str],
    standard_id_or_content: str,
) -> str:
    """Ask the Vertex chat model to pick the most relevant CRE candidate.

    Args:
        prompt: unused; kept for signature compatibility with the other
            prompt clients that implement this method.
        cre_id_and_name_in_export_format: candidate CREs, one
            "<id>|<name>" string per candidate.
        standard_id_or_content: the id or raw content of the standard to map.

    Returns:
        The model's text reply; expected to be the chosen candidate
        delimited with backticks.
    """
    parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS}
    # BUG FIX: the original had a trailing comma after the message literal,
    # which made `msg` a one-element tuple instead of a string, so
    # send_message received a tuple.  Prompt text is otherwise unchanged
    # (NOTE(review): the two sentences run together with no space between
    # them, as in the original concatenation — confirm this is intended).
    msg = (
        "You are map-gpt, a helpful assistant that is an expert in mapping standards to other standards. I will give you a standard to map to and a range of candidates and you will response ONLY with the most relevant candidate."
        f"Your task is to map the following standard to the most relevant candidate in the list of candidates provided. The standard to map to is: `{standard_id_or_content}`. The candidates are: `{cre_id_and_name_in_export_format}`. Answer ONLY with the most relevant candidate exactly as it is on the input, delimit the candidate with backticks`."
    )
    return self.chat.send_message(msg, **parameters).text
41 changes: 27 additions & 14 deletions application/utils/spreadsheet_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def parse_standards(


def suggest_from_export_format(
lfile: List[Dict[str, Any]], database: db.Node_collection
lfile: List[Dict[str, Any]], database: db.Node_collection, use_llm: bool = False
) -> Dict[str, Any]:
output: List[Dict[str, Any]] = []
for line in lfile:
Expand Down Expand Up @@ -608,20 +608,33 @@ def suggest_from_export_format(
)
# find nearest CRE for standards in line
ph = prompt_client.PromptHandler(database=database, load_all_embeddings=False)
cre = None
if use_llm:
most_similar_id = ph.get_id_of_most_similar_cre_using_chat(item=standard)
if not most_similar_id:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue
c = most_similar_id.split(defs.ExportFormat.separator)
cres = database.get_CREs(name=c[1])
if not cres:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue
cre = cres[0]
else:
most_similar_id,_ = ph.get_id_of_most_similar_cre_paginated(item_embedding= ph.generate_embeddings_for_document(standard))
if not most_similar_id:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue

cre = database.get_cre_by_db_id(most_similar_id)
if not cre:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue

most_similar_id, _ = ph.get_id_of_most_similar_cre_paginated(
item_embedding=ph.generate_embeddings_for_document(standard)
)
if not most_similar_id:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue

cre = database.get_cre_by_db_id(most_similar_id)
if not cre:
logger.warning(f"Could not find a CRE for {standard.id}")
output.append(line)
continue
line[f"CRE 0"] = f"{cre.id}{defs.ExportFormat.separator}{cre.name}"
# add it to the line
output.append(line)
Expand Down
30 changes: 30 additions & 0 deletions application/web/web_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,36 @@ def suggest_from_cre_csv() -> Any:
)


@app.route("/rest/v1/cre_csv/suggest_chat", methods=["POST"])
def suggest_from_cre_csv_using_chat() -> Any:
    """Fill in missing fields of an uploaded CRE-format csv via LLM suggestions.

    Expects a multipart upload under the "cre_csv" key following the CRE
    import format; runs suggest_from_export_format with use_llm=True and
    returns the completed csv as a downloadable attachment.

    Returns:
        Any: the csv file with the missing fields filled in with suggestions
    """
    collection = db.Node_collection()
    uploaded = request.files.get("cre_csv")
    if uploaded is None:
        abort(400, "No file provided")

    rows = csv.DictReader(uploaded.read().decode("utf-8").splitlines())
    suggestions = spreadsheet_parsers.suggest_from_export_format(
        list(rows), database=collection, use_llm=True
    )
    payload = write_csv(docs=suggestions).getvalue().encode("utf-8")

    # Wrap the encoded csv bytes in an in-memory buffer so send_file can
    # stream it as an attachment (position starts at 0).
    buffer = io.BytesIO(payload)
    return send_file(
        buffer,
        as_attachment=True,
        download_name="CRE-Catalogue.csv",
        mimetype="text/csv",
    )
# /End Importing Handlers


Expand Down

0 comments on commit 33da1e2

Please sign in to comment.