Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggest mappings #547

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions application/prompt_client/prompt_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from application.database import db
from application.defs import cre_defs
from application.prompt_client import openai_prompt_client, vertex_prompt_client
from application.defs import cre_defs as defs
from application.prompt_client import (
openai_prompt_client,
vertex_prompt_client,
spacy_prompt_client,
)
from datetime import datetime
from multiprocessing import Pool
from nltk.corpus import stopwords
Expand All @@ -25,6 +29,8 @@


def is_valid_url(url):
    """Return True when *url* is a non-empty string with an http(s) scheme.

    Accepts None / empty string safely (returns False) so callers can pass
    a node's possibly-missing hyperlink without pre-checking.
    """
    if not url:
        return False
    # startswith accepts a tuple of prefixes: one call instead of an `or` chain
    return url.startswith(("http://", "https://"))


Expand Down Expand Up @@ -103,9 +109,9 @@ def find_missing_embeddings(self, database: db.Node_collection) -> List[str]:
"""
logger.info(f"syncing nodes with embeddings")
missing_embeddings = []
for doc_type in cre_defs.Credoctypes:
for doc_type in defs.Credoctypes:
db_ids = []
if doc_type.value == cre_defs.Credoctypes.CRE:
if doc_type.value == defs.Credoctypes.CRE:
db_ids = [a[0] for a in database.list_cre_ids()]
else:
db_ids = [a[0] for a in database.list_node_ids_by_ntype(doc_type.value)]
Expand All @@ -128,10 +134,10 @@ def generate_embeddings_for(self, database: db.Node_collection, item_name: str):
For example if "ASVS" is passed the method will generate all embeddings for ASVS
Args:
database (db.Node_collection): the Node_collection instance to use
item_name (str): the item for which to generate embeddings, this can be either `cre_defs.Credoctypes.CRE.value` for generating all CRE embeddings or the name of any Standard or Tool.
item_name (str): the item for which to generate embeddings, this can be either `defs.Credoctypes.CRE.value` for generating all CRE embeddings or the name of any Standard or Tool.
"""
db_ids = []
if item_name == cre_defs.Credoctypes.CRE.value:
if item_name == defs.Credoctypes.CRE.value:
db_ids = [a[0] for a in database.list_cre_ids()]
else:
db_ids = [a[0] for a in database.list_node_ids_by_name(item_name)]
Expand All @@ -144,11 +150,13 @@ def generate_embeddings_for(self, database: db.Node_collection, item_name: str):
def generate_embeddings(
self, database: db.Node_collection, missing_embeddings: List[str]
):
"""method generate embeddings accepts a list of Database IDs of object which do not have embeddings and generates embeddings for those objects"""
"""
accepts a list of Database IDs of object which do not have embeddings and generates embeddings for those objects
"""
logger.info(f"generating {len(missing_embeddings)} embeddings")
for id in missing_embeddings:
cre = database.get_cre_by_db_id(id)
node = database.get_nodes(db_id=id)
node = database.get_nodes(db_id=id)[0]
content = ""
if node:
if is_valid_url(node.hyperlink):
Expand All @@ -174,9 +182,16 @@ def generate_embeddings(
if not dbcre:
logger.fatal(node, "cannot be converted to database Node")
dbcre.id = id
database.add_embedding(
dbcre, cre_defs.Credoctypes.CRE, embedding, content
)
database.add_embedding(dbcre, defs.Credoctypes.CRE, embedding, content)

def generate_embeddings_for_document(self, node: defs.Node):
    """Return the embedding vector for a single document node.

    When the node carries a valid http(s) hyperlink the embedded text is
    the cleaned page content fetched from that URL; otherwise the node's
    repr is embedded instead.
    """
    if is_valid_url(node.hyperlink):
        text = self.clean_content(self.get_content(node.hyperlink))
    else:
        text = repr(node)
    logger.info(f"making embedding for {node.id}")
    return self.ai_client.get_text_embeddings(text)


class PromptHandler:
Expand All @@ -197,8 +212,9 @@ def __init__(self, database: db.Node_collection, load_all_embeddings=False) -> N
os.getenv("OPENAI_API_KEY")
)
else:
logger.error(
"cannot instantiate ai client, neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS are set "
self.ai_client = spacy_prompt_client.SpacyPromptClient()
logger.info(
"cannot instantiate ai client, neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS are set, using spacy "
)
self.database = database
self.embeddings_instance = in_memory_embeddings.instance().with_ai_client(
Expand All @@ -219,6 +235,12 @@ def __init__(self, database: db.Node_collection, load_all_embeddings=False) -> N
f"there are {len(missing_embeddings)} embeddings missing from the dataset, db inclompete"
)

def generate_embeddings_for_document(self, node: defs.Node):
    """Generate and return the embedding for a single document node.

    Sets up playwright (used for fetching hyperlinked content) before
    generating and always tears it down afterwards, even when embedding
    generation raises.
    """
    self.embeddings_instance.setup_playwright()
    try:
        return self.embeddings_instance.generate_embeddings_for_document(node)
    finally:
        # original code leaked the playwright instance on failure; the
        # finally block guarantees teardown on every exit path
        self.embeddings_instance.teardown_playwright()

def generate_embeddings_for(self, item_name: str):
self.embeddings_instance.setup_playwright()
self.embeddings_instance.generate_embeddings_for(self.database, item_name)
Expand Down Expand Up @@ -277,7 +299,7 @@ def get_id_of_most_similar_cre(self, item_embedding: List[float]) -> Optional[st
self.existing_cre_embeddings,
self.existing_cre_ids,
) = self.__load_cre_embeddings(
self.database.get_embeddings_by_doc_type(cre_defs.Credoctypes.CRE.value)
self.database.get_embeddings_by_doc_type(defs.Credoctypes.CRE.value)
)
if not self.existing_cre_embeddings.getnnz() or not len(self.existing_cre_ids):
raise ValueError(
Expand Down Expand Up @@ -316,7 +338,7 @@ def get_id_of_most_similar_node(self, standard_text_embedding: List[float]) -> s
self.existing_node_ids,
) = self.__load_node_embeddings(
self.database.get_embeddings_by_doc_type(
cre_defs.Credoctypes.Standard.value
defs.Credoctypes.Standard.value
)
)
if not self.existing_node_embeddings.getnnz() or not len(
Expand All @@ -341,6 +363,7 @@ def get_id_of_most_similar_cre_paginated(
self,
item_embedding: List[float],
similarity_threshold: float = SIMILARITY_THRESHOLD,
refresh_embeddings: bool = False,
) -> Optional[Tuple[str, float]]:
"""this method is meant to be used when CRE runs in a web server with limited memory (e.g. firebase/heroku)
instead of loading all our embeddings in memory we take the slower approach of paginating them
Expand All @@ -354,13 +377,12 @@ def get_id_of_most_similar_cre_paginated(
embedding_array = sparse.csr_matrix(
np.array(item_embedding).reshape(1, -1)
) # convert embedding into a 1-dimentional numpy array

(
embeddings,
total_pages,
starting_page,
) = self.database.get_embeddings_by_doc_type_paginated(
cre_defs.Credoctypes.CRE.value
defs.Credoctypes.CRE.value
)
max_similarity = -1
most_similar_index = 0
Expand All @@ -378,7 +400,7 @@ def get_id_of_most_similar_cre_paginated(
total_pages,
_,
) = self.database.get_embeddings_by_doc_type_paginated(
cre_defs.Credoctypes.CRE.value, page=page
defs.Credoctypes.CRE.value, page=page
)

if max_similarity < similarity_threshold:
Expand Down Expand Up @@ -411,7 +433,7 @@ def get_id_of_most_similar_node_paginated(
total_pages,
starting_page,
) = self.database.get_embeddings_by_doc_type_paginated(
doc_type=cre_defs.Credoctypes.Standard.value,
doc_type=defs.Credoctypes.Standard.value,
page=1,
)

Expand All @@ -429,7 +451,7 @@ def get_id_of_most_similar_node_paginated(
most_similar_id = existing_standard_ids[most_similar_index]

embeddings, _, _ = self.database.get_embeddings_by_doc_type_paginated(
doc_type=cre_defs.Credoctypes.Standard.value, page=page
doc_type=defs.Credoctypes.Standard.value, page=page
)
if max_similarity < similarity_threshold:
logger.info(
Expand Down
35 changes: 35 additions & 0 deletions application/prompt_client/spacy_prompt_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import spacy
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class SpacyPromptClient:
    """Embeddings-only fallback client backed by a local spaCy model.

    Used when neither OPENAI_API_KEY nor SERVICE_ACCOUNT_CREDENTIALS is
    configured. It can produce text embeddings but does not support chat
    completion or free-form LLM queries.
    """

    def __init__(self) -> None:
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # model not installed yet: fetch it once, then retry the load
            logger.info(
                "Downloading language model for the spaCy POS tagger\n"
                "(don't worry, this will only happen once)"
            )
            from spacy.cli import download

            download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

    def get_text_embeddings(self, text: str):
        """Return the spaCy document vector for *text*."""
        return self.nlp(text).vector

    def create_chat_completion(self, prompt, closest_object_str) -> str:
        """Unsupported: spaCy has no chat-completion capability."""
        raise NotImplementedError(
            "Spacy does not support chat completion you need to set up a different client if you need this functionality"
        )

    def query_llm(self, raw_question: str) -> str:
        """Unsupported: spaCy cannot answer free-form LLM queries."""
        # fixed copy-pasted message that wrongly referred to chat completion
        raise NotImplementedError(
            "Spacy does not support querying an LLM, you need to set up a different client if you need this functionality"
        )
44 changes: 44 additions & 0 deletions application/tests/spreadsheet_parsers_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from pprint import pprint
import unittest
from application.database import db
from application.tests.utils import data_gen
from application.defs import cre_defs as defs
from application import create_app, sqla # type: ignore
from application.utils.spreadsheet_parsers import (
parse_export_format,
parse_hierarchical_export_format,
suggest_from_export_format,
)


Expand Down Expand Up @@ -37,6 +40,47 @@ def test_parse_hierarchical_export_format(self) -> None:
for element in v:
self.assertIn(element, output[k])

def test_suggest_from_export_format(self) -> None:
    """Strip the CRE columns from every other input row, run the
    suggester, and check that most stripped rows get CRE mappings back."""
    self.app = create_app(mode="test")
    self.app_context = self.app.app_context()
    self.app_context.push()
    sqla.create_all()
    collection = db.Node_collection()

    input_data, expected_output = data_gen.export_format_data()
    for cre in expected_output[defs.Credoctypes.CRE.value]:
        collection.add_cre(cre=cre)

    # drop the CRE columns from every even-indexed row
    input_data_no_cres = []
    for position, row in enumerate(input_data):
        if position % 2 == 0:
            stripped = {k: v for k, v in row.items() if not k.startswith("CRE")}
        else:
            stripped = row.copy()
        input_data_no_cres.append(stripped)

    output = suggest_from_export_format(
        lfile=input_data_no_cres, database=collection
    )
    self.maxDiff = None

    # count output rows that still carry no CRE mapping at all
    empty_lines = sum(
        1
        for row in output
        if not [row[c] for c in row.keys() if c.startswith("CRE") and row[c]]
    )

    # at least some of the stripped rows should have received suggestions
    self.assertGreater(len(input_data) / 2, empty_lines)

    sqla.session.remove()
    sqla.drop_all()
    self.app_context.pop()


if __name__ == "__main__":
unittest.main()
54 changes: 54 additions & 0 deletions application/tests/web_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,57 @@ def test_get_cre_csv(self) -> None:
data.getvalue(),
response.data.decode(),
)

def test_suggest_from_cre_csv(self) -> None:
    """POST a CRE csv with the CRE columns stripped from every other row
    and assert the suggest endpoint fills mappings back in for most rows."""
    collection = db.Node_collection()

    input_data, expected_output = data_gen.export_format_data()
    for cre in expected_output[defs.Credoctypes.CRE.value]:
        collection.add_cre(cre=cre)

    # drop the CRE columns from every other row, while accumulating the
    # union of all keys so the csv writer gets the complete header
    keys = {}
    input_data_no_cres = []
    for index, line in enumerate(input_data):
        keys.update(line)
        no_cre_line = line.copy()
        if index % 2 == 0:
            for key in line.keys():
                if key.startswith("CRE"):
                    no_cre_line.pop(key)
        input_data_no_cres.append(no_cre_line)

    workspace = tempfile.mkdtemp()
    csv_path = os.path.join(workspace, "cre.csv")
    with open(csv_path, "w") as f:
        cdw = csv.DictWriter(f, fieldnames=keys.keys())
        cdw.writeheader()
        cdw.writerows(input_data_no_cres)

    # context managers close both the upload handle (leaked before) and
    # the test client
    with open(csv_path, "rb") as cre_csv:
        with self.app.test_client() as client:
            response = client.post(
                "/rest/v1/cre_csv/suggest",
                data={"cre_csv": cre_csv},
                buffered=True,
                content_type="multipart/form-data",
            )
    self.assertEqual(200, response.status_code)

    # removed leftover debug code: pprint(...) and a bare input() call
    # that blocked the test run forever in CI

    # count responded rows that still carry no CRE mapping at all
    empty_lines = 0
    for line in json.loads(response.data.decode()):
        cres_in_line = [
            line[c] for c in line.keys() if c.startswith("CRE") and line[c]
        ]
        if len(cres_in_line) == 0:
            empty_lines += 1
    self.assertGreater(len(input_data_no_cres) / 2, empty_lines)
Loading
Loading