diff --git a/application/prompt_client/prompt_client.py b/application/prompt_client/prompt_client.py
index c3838298..500c5bcf 100644
--- a/application/prompt_client/prompt_client.py
+++ b/application/prompt_client/prompt_client.py
@@ -1,6 +1,10 @@
 from application.database import db
-from application.defs import cre_defs
-from application.prompt_client import openai_prompt_client, vertex_prompt_client
+from application.defs import cre_defs as defs
+from application.prompt_client import (
+    openai_prompt_client,
+    vertex_prompt_client,
+    spacy_prompt_client,
+)
 from datetime import datetime
 from multiprocessing import Pool
 from nltk.corpus import stopwords
@@ -25,6 +29,8 @@
 def is_valid_url(url):
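+    # hyperlink fields can be None or empty; calling startswith on None would raise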
""" db_ids = [] - if item_name == cre_defs.Credoctypes.CRE.value: + if item_name == defs.Credoctypes.CRE.value: db_ids = [a[0] for a in database.list_cre_ids()] else: db_ids = [a[0] for a in database.list_node_ids_by_name(item_name)] @@ -144,11 +150,13 @@ def generate_embeddings_for(self, database: db.Node_collection, item_name: str): def generate_embeddings( self, database: db.Node_collection, missing_embeddings: List[str] ): - """method generate embeddings accepts a list of Database IDs of object which do not have embeddings and generates embeddings for those objects""" + """ + accepts a list of Database IDs of object which do not have embeddings and generates embeddings for those objects + """ logger.info(f"generating {len(missing_embeddings)} embeddings") for id in missing_embeddings: cre = database.get_cre_by_db_id(id) - node = database.get_nodes(db_id=id) + node = database.get_nodes(db_id=id)[0] content = "" if node: if is_valid_url(node.hyperlink): @@ -174,9 +182,16 @@ def generate_embeddings( if not dbcre: logger.fatal(node, "cannot be converted to database Node") dbcre.id = id - database.add_embedding( - dbcre, cre_defs.Credoctypes.CRE, embedding, content - ) + database.add_embedding(dbcre, defs.Credoctypes.CRE, embedding, content) + + def generate_embeddings_for_document(self, node: defs.Node): + content = "" + if is_valid_url(node.hyperlink): + content = self.clean_content(self.get_content(node.hyperlink)) + else: + content = node.__repr__() + logger.info(f"making embedding for {node.id}") + return self.ai_client.get_text_embeddings(content) class PromptHandler: @@ -197,8 +212,9 @@ def __init__(self, database: db.Node_collection, load_all_embeddings=False) -> N os.getenv("OPENAI_API_KEY") ) else: - logger.error( - "cannot instantiate ai client, neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS are set " + self.ai_client = spacy_prompt_client.SpacyPromptClient() + logger.info( + "cannot instantiate ai client, neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS are set, using spacy " ) self.database = database self.embeddings_instance = in_memory_embeddings.instance().with_ai_client( @@ -219,6 +235,12 @@ def __init__(self, database: db.Node_collection, load_all_embeddings=False) -> N f"there are {len(missing_embeddings)} embeddings missing from the dataset, db inclompete" ) + def generate_embeddings_for_document(self, node: defs.Node): + self.embeddings_instance.setup_playwright() + embeddings = self.embeddings_instance.generate_embeddings_for_document(node) + self.embeddings_instance.teardown_playwright() + return embeddings + def generate_embeddings_for(self, item_name: str): self.embeddings_instance.setup_playwright() self.embeddings_instance.generate_embeddings_for(self.database, item_name) @@ -277,7 +299,7 @@ def get_id_of_most_similar_cre(self, item_embedding: List[float]) -> Optional[st self.existing_cre_embeddings, self.existing_cre_ids, ) = self.__load_cre_embeddings( - self.database.get_embeddings_by_doc_type(cre_defs.Credoctypes.CRE.value) + self.database.get_embeddings_by_doc_type(defs.Credoctypes.CRE.value) ) if not self.existing_cre_embeddings.getnnz() or not len(self.existing_cre_ids): raise ValueError( @@ -316,7 +338,7 @@ def get_id_of_most_similar_node(self, standard_text_embedding: List[float]) -> s self.existing_node_ids, ) = self.__load_node_embeddings( self.database.get_embeddings_by_doc_type( - cre_defs.Credoctypes.Standard.value + defs.Credoctypes.Standard.value ) ) if not self.existing_node_embeddings.getnnz() or not len( @@ -341,6 +363,7 @@ def 
+    def generate_embeddings_for_document(self, node: defs.Node):
+        self.embeddings_instance.setup_playwright()
+        embeddings = self.embeddings_instance.generate_embeddings_for_document(node)
+        self.embeddings_instance.teardown_playwright()
+        return embeddings
+
     def generate_embeddings_for(self, item_name: str):
         self.embeddings_instance.setup_playwright()
         self.embeddings_instance.generate_embeddings_for(self.database, item_name)
@@ -277,7 +299,7 @@ def get_id_of_most_similar_cre(self, item_embedding: List[float]) -> Optional[str]:
                 self.existing_cre_embeddings,
                 self.existing_cre_ids,
             ) = self.__load_cre_embeddings(
-                self.database.get_embeddings_by_doc_type(cre_defs.Credoctypes.CRE.value)
+                self.database.get_embeddings_by_doc_type(defs.Credoctypes.CRE.value)
             )
         if not self.existing_cre_embeddings.getnnz() or not len(self.existing_cre_ids):
             raise ValueError(
@@ -316,7 +338,7 @@ def get_id_of_most_similar_node(self, standard_text_embedding: List[float]) -> str:
                 self.existing_node_ids,
             ) = self.__load_node_embeddings(
                 self.database.get_embeddings_by_doc_type(
-                    cre_defs.Credoctypes.Standard.value
+                    defs.Credoctypes.Standard.value
                 )
             )
         if not self.existing_node_embeddings.getnnz() or not len(
             self.existing_node_ids
         ):
@@ -341,6 +363,7 @@
     def get_id_of_most_similar_cre_paginated(
         self,
         item_embedding: List[float],
         similarity_threshold: float = SIMILARITY_THRESHOLD,
+        refresh_embeddings: bool = False,
     ) -> Optional[Tuple[str, float]]:
         """this method is meant to be used when CRE runs in a web server with limited memory (e.g. firebase/heroku)
         instead of loading all our embeddings in memory we take the slower approach of paginating them
@@ -354,13 +377,12 @@ def get_id_of_most_similar_cre_paginated(
         embedding_array = sparse.csr_matrix(
             np.array(item_embedding).reshape(1, -1)
         )  # convert embedding into a 1-dimentional numpy array
-
         (
             embeddings,
             total_pages,
             starting_page,
         ) = self.database.get_embeddings_by_doc_type_paginated(
-            cre_defs.Credoctypes.CRE.value
+            defs.Credoctypes.CRE.value
         )
         max_similarity = -1
         most_similar_index = 0
@@ -378,7 +400,7 @@ def get_id_of_most_similar_cre_paginated(
                 total_pages,
                 _,
             ) = self.database.get_embeddings_by_doc_type_paginated(
-                cre_defs.Credoctypes.CRE.value, page=page
+                defs.Credoctypes.CRE.value, page=page
             )
 
         if max_similarity < similarity_threshold:
@@ -411,7 +433,7 @@ def get_id_of_most_similar_node_paginated(
             total_pages,
             starting_page,
         ) = self.database.get_embeddings_by_doc_type_paginated(
-            doc_type=cre_defs.Credoctypes.Standard.value,
+            doc_type=defs.Credoctypes.Standard.value,
             page=1,
         )
@@ -429,7 +451,7 @@ def get_id_of_most_similar_node_paginated(
                 most_similar_id = existing_standard_ids[most_similar_index]
 
             embeddings, _, _ = self.database.get_embeddings_by_doc_type_paginated(
-                doc_type=cre_defs.Credoctypes.Standard.value, page=page
+                doc_type=defs.Credoctypes.Standard.value, page=page
             )
         if max_similarity < similarity_threshold:
             logger.info(
diff --git a/application/prompt_client/spacy_prompt_client.py b/application/prompt_client/spacy_prompt_client.py
new file mode 100644
index 00000000..17031169
--- /dev/null
+++ b/application/prompt_client/spacy_prompt_client.py
@@ -0,0 +1,35 @@
+import spacy
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class SpacyPromptClient:
+
+    def __init__(self) -> None:
+        try:
+            self.nlp = spacy.load("en_core_web_sm")
+        except OSError:
+            logger.info(
+                "Downloading language model for the spaCy POS tagger\n"
+                "(don't worry, this will only happen once)"
+            )
+            from spacy.cli import download
+
+            download("en_core_web_sm")
+            self.nlp = spacy.load("en_core_web_sm")
+
+    def get_text_embeddings(self, text: str):
+        return self.nlp(text).vector
+
+    def create_chat_completion(self, prompt, closest_object_str) -> str:
+        raise NotImplementedError(
+            "Spacy does not support chat completion; you need to set up a different client if you need this functionality"
+        )
+
+    def query_llm(self, raw_question: str) -> str:
+        raise NotImplementedError(
+            "Spacy does not support LLM queries; you need to set up a different client if you need this functionality"
+        )
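+
+
+# Usage sketch (illustrative): this client only mirrors the embedding surface of the
+# OpenAI/Vertex clients, so the chat-oriented methods intentionally raise.
+#
+#   client = SpacyPromptClient()
+#   vector = client.get_text_embeddings("some standard text")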
diff --git a/application/tests/spreadsheet_parsers_test.py b/application/tests/spreadsheet_parsers_test.py
index 800f1b5c..3511f350 100644
--- a/application/tests/spreadsheet_parsers_test.py
+++ b/application/tests/spreadsheet_parsers_test.py
@@ -1,11 +1,14 @@
 import json
 from pprint import pprint
 import unittest
+from application.database import db
 from application.tests.utils import data_gen
 from application.defs import cre_defs as defs
+from application import create_app, sqla  # type: ignore
 from application.utils.spreadsheet_parsers import (
     parse_export_format,
     parse_hierarchical_export_format,
+    suggest_from_export_format,
 )
@@ -37,6 +40,47 @@ def test_parse_hierarchical_export_format(self) -> None:
         for element in v:
             self.assertIn(element, output[k])
 
+    def test_suggest_from_export_format(self) -> None:
+        self.app = create_app(mode="test")
+        self.app_context = self.app.app_context()
+        self.app_context.push()
+        sqla.create_all()
+        collection = db.Node_collection()
+
+        input_data, expected_output = data_gen.export_format_data()
+        for cre in expected_output[defs.Credoctypes.CRE.value]:
+            collection.add_cre(cre=cre)
+
+        # strip the CRE mappings from every other line
+        index = 0
+        input_data_no_cres = []
+        for line in input_data:
+            no_cre_line = line.copy()
+            if index % 2 == 0:
+                [no_cre_line.pop(key) for key in line.keys() if key.startswith("CRE")]
+            index += 1
+            input_data_no_cres.append(no_cre_line)
+        output = suggest_from_export_format(
+            lfile=input_data_no_cres, database=collection
+        )
+        self.maxDiff = None
+
+        empty_lines = 0
+        for line in output:
+            cres_in_line = [
+                line[c] for c in line.keys() if c.startswith("CRE") and line[c]
+            ]
+            if len(cres_in_line) == 0:
+                empty_lines += 1
+
+        self.assertGreater(
+            len(input_data) / 2, empty_lines
+        )  # assert that at least some mappings were suggested
+
+        sqla.session.remove()
+        sqla.drop_all()
+        self.app_context.pop()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/application/tests/web_main_test.py b/application/tests/web_main_test.py
index 0f79e68d..69f0936f 100644
--- a/application/tests/web_main_test.py
+++ b/application/tests/web_main_test.py
@@ -931,3 +931,49 @@ def test_get_cre_csv(self) -> None:
             data.getvalue(),
             response.data.decode(),
         )
+
+    def test_suggest_from_cre_csv(self) -> None:
+        collection = db.Node_collection()
+
+        input_data, expected_output = data_gen.export_format_data()
+        for cre in expected_output[defs.Credoctypes.CRE.value]:
+            collection.add_cre(cre=cre)
+
+        # strip the CRE mappings from every other line
+        index = 0
+        input_data_no_cres = []
+        keys = {}
+        for line in input_data:
+            keys.update(line)
+            no_cre_line = line.copy()
+            if index % 2 == 0:
+                [no_cre_line.pop(key) for key in line.keys() if key.startswith("CRE")]
+            index += 1
+            input_data_no_cres.append(no_cre_line)
+
+        workspace = tempfile.mkdtemp()
+        data = {}
+        with open(os.path.join(workspace, "cre.csv"), "w") as f:
+            cdw = csv.DictWriter(f, fieldnames=keys.keys())
+            cdw.writeheader()
+            cdw.writerows(input_data_no_cres)
+
+        data["cre_csv"] = open(os.path.join(workspace, "cre.csv"), "rb")
+
+        with self.app.test_client() as client:
+            response = client.post(
+                "/rest/v1/cre_csv/suggest",
+                data=data,
+                buffered=True,
+                content_type="multipart/form-data",
+            )
+            self.assertEqual(200, response.status_code)
+            empty_lines = 0
+
+            # the endpoint responds with a csv file, so parse it as csv
+            for line in csv.DictReader(response.data.decode().splitlines()):
+                cres_in_line = [
+                    line[c] for c in line.keys() if c.startswith("CRE") and line[c]
+                ]
+                if len(cres_in_line) == 0:
+                    empty_lines += 1
+            self.assertGreater(len(input_data_no_cres) / 2, empty_lines)
diff --git a/application/utils/spreadsheet_parsers.py b/application/utils/spreadsheet_parsers.py
index ade14414..33fda098 100644
--- a/application/utils/spreadsheet_parsers.py
+++ b/application/utils/spreadsheet_parsers.py
@@ -4,8 +4,9 @@
 from copy import copy
 from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
-
+from application.prompt_client import prompt_client
 from application.defs import cre_defs as defs
+from application.database import db
 
 # collection of methods to parse different versions of spreadsheet standards
 # each method returns a list of cre_defs documents
@@ -567,3 +568,64 @@ def parse_standards(
             )
         )
     return links
+
+
+def suggest_from_export_format(
+    lfile: List[Dict[str, Any]], database: db.Node_collection
+) -> List[Dict[str, Any]]:
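+    """
+    For each row of an export-format spreadsheet that has no CRE mapping, build a
+    Standard from the row's populated columns, embed it, and ask the PromptHandler
+    for the id of the most similar CRE; rows that already carry a mapping are passed
+    through unchanged. Returns the rows with a "CRE 0" column filled in wherever a
+    suggestion was found.
+    """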
+    output: List[Dict[str, Any]] = []
+    # one PromptHandler for the whole file; instantiating it per line would re-create the ai client
+    ph = prompt_client.PromptHandler(database=database, load_all_embeddings=False)
+    for line in lfile:
+        standard: Optional[defs.Node] = None
+        if any(
+            [
+                entry.startswith("CRE ")
+                for entry, value in line.items()
+                if not is_empty(value)
+            ]
+        ):  # the line already has a mapping, flush it to the buffer unchanged
+            output.append(line)
+            continue
+        for entry, value in line.items():
+            if entry.startswith("CRE "):
+                continue  # we established above there are no CRE entries in this line
+
+            if not is_empty(value):
+                standard_name = entry.split("|")[0]
+                standard = defs.Standard(
+                    name=standard_name,
+                    sectionID=line.get(
+                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.id}"
+                    ),
+                    section=line.get(
+                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.section}"
+                    ),
+                    hyperlink=line.get(
+                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.hyperlink}"
+                    ),
+                    description=line.get(
+                        f"{standard_name}{defs.ExportFormat.separator}{defs.ExportFormat.description}"
+                    ),
+                )
+        if not standard:  # nothing mappable in this line, flush it unchanged
+            output.append(line)
+            continue
+
+        # find the nearest CRE for the standards in this line
+        suggestion = ph.get_id_of_most_similar_cre_paginated(
+            item_embedding=ph.generate_embeddings_for_document(standard)
+        )
+        if not suggestion:
+            logger.warning(f"Could not find a CRE for {standard.id}")
+            output.append(line)
+            continue
+        most_similar_id, _ = suggestion
+
+        cre = database.get_cre_by_db_id(most_similar_id)
+        if not cre:
+            logger.warning(f"Could not find a CRE for {standard.id}")
+            output.append(line)
+            continue
+        line["CRE 0"] = f"{cre.id}{defs.ExportFormat.separator}{cre.name}"
+        # add the suggested mapping to the line
+        output.append(line)
+    return output
diff --git a/application/web/web_main.py b/application/web/web_main.py
index 2ac44d05..293ba2d2 100644
--- a/application/web/web_main.py
+++ b/application/web/web_main.py
@@ -713,6 +713,7 @@ def get_cre_csv() -> Any:
 
 @app.route("/rest/v1/cre_csv_import", methods=["POST"])
+@app.route("/rest/v1/cre_csv/import", methods=["POST"])
 def import_from_cre_csv() -> Any:
     if not os.environ.get("CRE_ALLOW_IMPORT"):
         abort(
@@ -760,6 +761,38 @@ def import_from_cre_csv() -> Any:
     )
 
 
+@app.route("/rest/v1/cre_csv/suggest", methods=["POST"])
+def suggest_from_cre_csv() -> Any:
+    """Given a csv file that follows the CRE import format but has missing fields,
+    return the same csv with the missing CRE mappings filled in with suggestions.
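+
+    Example (illustrative, assuming a local dev server on port 5000):
+        curl -X POST -F "cre_csv=@cre.csv" http://localhost:5000/rest/v1/cre_csv/suggest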
+
+    Returns:
+        Any: the csv file with the missing fields filled in with suggestions
+    """
+    database = db.Node_collection()
+    file = request.files.get("cre_csv")
+
+    if file is None:
+        abort(400, "No file provided")
+    contents = file.read()
+    csv_read = csv.DictReader(contents.decode("utf-8").splitlines())
+    response = spreadsheet_parsers.suggest_from_export_format(
+        list(csv_read), database=database
+    )
+    csvVal = write_csv(docs=response).getvalue().encode("utf-8")
+
+    # wrap the csv bytes in a BytesIO object so send_file can stream it
+    mem = io.BytesIO()
+    mem.write(csvVal)
+    mem.seek(0)
+
+    return send_file(
+        mem,
+        as_attachment=True,
+        download_name="CRE-Catalogue.csv",
+        mimetype="text/csv",
+    )
+
+
 # /End Importing Handlers
diff --git a/requirements.txt b/requirements.txt
index c4e004db..79757391 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -111,4 +111,5 @@ urllib3
 vertexai
 xmltodict
 google-cloud-trace
-alive-progress
\ No newline at end of file
+alive-progress
+spacy