diff --git a/pebblo/app/service/doc_helper.py b/pebblo/app/service/doc_helper.py index 06efd218..e07acd13 100644 --- a/pebblo/app/service/doc_helper.py +++ b/pebblo/app/service/doc_helper.py @@ -182,11 +182,14 @@ def _get_classifier_response(self, doc): ) try: if doc_info.data: - topics, topic_count = topic_classifier_obj.predict(doc_info.data) + topics, topic_count, topic_details = topic_classifier_obj.predict( + doc_info.data + ) ( entities, entity_count, anonymized_doc, + entity_details, ) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer( doc_info.data, anonymize_snippets=ClassifierConstants.anonymize_snippets.value, diff --git a/pebblo/app/service/prompt_gov.py b/pebblo/app/service/prompt_gov.py index e5086beb..d143a13e 100644 --- a/pebblo/app/service/prompt_gov.py +++ b/pebblo/app/service/prompt_gov.py @@ -44,6 +44,7 @@ def _get_classifier_response(self): entities, entity_count, anonymized_doc, + entity_details, ) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer( self.input.get("prompt"), anonymize_snippets=False, diff --git a/pebblo/app/service/prompt_service.py b/pebblo/app/service/prompt_service.py index b809f8a8..ab400288 100644 --- a/pebblo/app/service/prompt_service.py +++ b/pebblo/app/service/prompt_service.py @@ -45,6 +45,7 @@ def _fetch_classified_data(self, input_data, input_type=""): entities, entity_count, _, + _, ) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer( input_data ) @@ -53,7 +54,9 @@ def _fetch_classified_data(self, input_data, input_type=""): # Topic classification is performed only for the response. if input_type == "response": - topics, topic_count = self.topic_classifier_obj.predict(input_data) + topics, topic_count, topic_details = self.topic_classifier_obj.predict( + input_data + ) data["topicCount"] = topic_count data["topics"] = topics diff --git a/pebblo/entity_classifier/README.md b/pebblo/entity_classifier/README.md index 34b5f833..b579d7ff 100644 --- a/pebblo/entity_classifier/README.md +++ b/pebblo/entity_classifier/README.md @@ -24,7 +24,7 @@ from pebblo.entity_classifier.entity_classifier import EntityClassifier text = entity_classifier_obj = EntityClassifier() -entities, total_count, anonymized_text = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(text,anonymize_snippets) +entities, total_count, anonymized_text, entity_details = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(text,anonymize_snippets) print(f"Entity Group: {entity_groups}") print(f"Entity Count: {total_entity_count}") print(f"Anonymized Text: {anonymized_text}") diff --git a/pebblo/entity_classifier/entity_classifier.py b/pebblo/entity_classifier/entity_classifier.py index 48aca284..4acea1cb 100644 --- a/pebblo/entity_classifier/entity_classifier.py +++ b/pebblo/entity_classifier/entity_classifier.py @@ -47,13 +47,8 @@ def analyze_response(self, input_text, anonymize_all_entities=True): result for result in analyzer_results if result.score >= float(ConfidenceScore.Entity.value) + and result.entity_type in self.entities ] - if not anonymize_all_entities: # Condition for anonymized document - analyzer_results = [ - result - for result in analyzer_results - if result.entity_type in self.entities - ] return analyzer_results def anonymize_response(self, analyzer_results, input_text): @@ -64,17 +59,36 @@ def anonymize_response(self, analyzer_results, input_text): return anonymized_text.items, anonymized_text.text + @staticmethod + def get_analyzed_entities_response(data, anonymized_response=None): + # Returns entities with its location i.e. start to end and confidence score + response = [] + for index, value in enumerate(data): + location = f"{value.start}_{value.end}" + if anonymized_response: + anonymized_data = anonymized_response[len(data) - index - 1] + location = f"{anonymized_data.start}_{anonymized_data.end}" + response.append( + { + "entity_type": value.entity_type, + "location": location, + "confidence_score": value.score, + } + ) + return response + def presidio_entity_classifier_and_anonymizer( self, input_text, anonymize_snippets=False ): """ Perform classification on the input data and return a dictionary with the count of each entity group. And also returns plain input text as anonymized text output - :param anonymize_snippets: Flag whether to anonymize snippets in report. :param input_text: Input string / document snippet + :param anonymize_snippets: Flag whether to anonymize snippets in report. :return: entities: containing the entity group Name as key and its count as value. total_count: Total count of entity groupsInput text in anonymized form. anonymized_text: Input text in anonymized form. + entity_details: Entities with its details such as location and confidence score. Example: input_text = " My SSN is 222-85-4836. @@ -89,21 +103,30 @@ def presidio_entity_classifier_and_anonymizer( """ entities = {} total_count = 0 - anonymized_text = "" try: logger.debug("Presidio Entity Classifier and Anonymizer Started.") analyzer_results = self.analyze_response(input_text) - anonymized_response, anonymized_text = self.anonymize_response( - analyzer_results, input_text - ) + if anonymize_snippets: # If Document snippet needs to be anonymized + anonymized_response, anonymized_text = self.anonymize_response( + analyzer_results, input_text + ) input_text = anonymized_text.replace("<", "<").replace(">", ">") - entities, total_count = get_entities(self.entities, anonymized_response) + entities_response = self.get_analyzed_entities_response( + analyzer_results, anonymized_response + ) + else: + entities_response = self.get_analyzed_entities_response( + analyzer_results + ) + entities, entity_details, total_count = get_entities( + self.entities, entities_response + ) logger.debug("Presidio Entity Classifier and Anonymizer Finished") logger.debug(f"Entities: {entities}") logger.debug(f"Entity Total count: {total_count}") - return entities, total_count, input_text + return entities, total_count, input_text, entity_details except Exception as e: logger.error( f"Presidio Entity Classifier and Anonymizer Failed, Exception: {e}" diff --git a/pebblo/entity_classifier/utils/utils.py b/pebblo/entity_classifier/utils/utils.py index 4329177f..6fd20b13 100644 --- a/pebblo/entity_classifier/utils/utils.py +++ b/pebblo/entity_classifier/utils/utils.py @@ -11,21 +11,34 @@ secret_entities_context_mapping, ) from pebblo.entity_classifier.utils.regex_pattern import regex_secrets_patterns +from pebblo.utils import get_confidence_score_label def get_entities(entities_list, response): entity_groups = dict() + entity_details = dict() + mapped_entity = None total_count = 0 for entity in response: - if entity.entity_type in entities_list: - if entity.entity_type in Entities.__members__: - mapped_entity = Entities[entity.entity_type].value - elif entity.entity_type in SecretEntities.__members__: - mapped_entity = SecretEntities[entity.entity_type].value + if entity["entity_type"] in entities_list: + if entity["entity_type"] in Entities.__members__: + mapped_entity = Entities[entity["entity_type"]].value + elif entity["entity_type"] in SecretEntities.__members__: + mapped_entity = SecretEntities[entity["entity_type"]].value entity_groups[mapped_entity] = entity_groups.get(mapped_entity, 0) + 1 + entity_data = { + "location": entity["location"], + "confidence_score": get_confidence_score_label( + entity["confidence_score"] + ), + } + if mapped_entity in entity_details.keys(): + entity_details[mapped_entity].append(entity_data) + else: + entity_details[mapped_entity] = [entity_data] total_count += 1 - return entity_groups, total_count + return entity_groups, entity_details, total_count def add_custom_regex_analyzer_registry(): diff --git a/pebblo/topic_classifier/README.md b/pebblo/topic_classifier/README.md index 7a74d87b..b6d69546 100644 --- a/pebblo/topic_classifier/README.md +++ b/pebblo/topic_classifier/README.md @@ -31,7 +31,7 @@ from pebblo.topic_classifier.topic_classifier import TopicClassifier text = "Your sample text here." topic_classifier_obj = TopicClassifier() -topics, total_topic_count = topic_classifier_obj.predict(text) +topics, total_topic_count, topic_details = topic_classifier_obj.predict(text) print(f"Topic Response: {topics}") print(f"Topic Count: {total_topic_count}") ``` diff --git a/pebblo/topic_classifier/topic_classifier.py b/pebblo/topic_classifier/topic_classifier.py index ebe0b13b..7dd9d692 100644 --- a/pebblo/topic_classifier/topic_classifier.py +++ b/pebblo/topic_classifier/topic_classifier.py @@ -18,6 +18,7 @@ TOPICS_TO_EXCLUDE, ) from pebblo.topic_classifier.enums.constants import topic_display_names +from pebblo.utils import get_confidence_score_label logger = get_logger(__name__) @@ -63,15 +64,15 @@ def predict(self, input_text): f"Text length is below {TOPIC_MIN_TEXT_LENGTH} characters. " f"Classification not performed." ) - return {}, 0 + return {}, 0, {} topic_model_response = self.classifier(input_text) - topics, total_count = self._get_topics(topic_model_response) + topics, total_count, topic_details = self._get_topics(topic_model_response) logger.debug(f"Topics: {topics}") - return topics, total_count + return topics, total_count, topic_details except Exception as e: logger.error(f"Error in topic_classifier. Exception: {e}") - return {}, 0 + return {}, 0, {} @staticmethod def _get_topics(topic_model_response): @@ -89,7 +90,17 @@ def _get_topics(topic_model_response): topics[mapped_topic] = topic["score"] final_topic = {} + topic_details = {} if len(topics) > 0: most_possible_advice = max(topics, key=lambda t: topics[t]) final_topic = {most_possible_advice: 1} - return final_topic, len(final_topic.keys()) + topic_details = { + most_possible_advice: [ + { + "confidence_score": get_confidence_score_label( + (topics[most_possible_advice]) + ) + } + ] + } + return final_topic, len(final_topic.keys()), topic_details diff --git a/pebblo/utils.py b/pebblo/utils.py new file mode 100644 index 00000000..85497b2a --- /dev/null +++ b/pebblo/utils.py @@ -0,0 +1,20 @@ +""" +Copyright (c) 2024 Cloud Defense, Inc. All rights reserved. +""" + +from enum import Enum + + +class ConfidenceScoreLabel(Enum): + HIGH = "HIGH" + MEDIUM = "MEDIUM" + LOW = "LOW" + + +def get_confidence_score_label(confidence_score): + if float(confidence_score) >= 0.8: + return ConfidenceScoreLabel.HIGH.value + elif 0.4 <= float(confidence_score) < 0.8: + return ConfidenceScoreLabel.MEDIUM.value + else: + return ConfidenceScoreLabel.LOW.value diff --git a/tests/app/service/test_prompt_gov.py b/tests/app/service/test_prompt_gov.py index 962cb16a..5710c22f 100644 --- a/tests/app/service/test_prompt_gov.py +++ b/tests/app/service/test_prompt_gov.py @@ -20,6 +20,7 @@ def test_process_request_success(mock_entity_classifier): {"us-ssn": 1}, 1, "anonymized document", + {"us-ssn": [{"location": "16_27", "confidence_score": "HIGH"}]}, ) data = {"prompt": "Sachin's SSN is 222-85-4836"} diff --git a/tests/entity_classifier/test_entity_classifier.py b/tests/entity_classifier/test_entity_classifier.py index f60f83c0..9f9f985a 100644 --- a/tests/entity_classifier/test_entity_classifier.py +++ b/tests/entity_classifier/test_entity_classifier.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import List, Tuple from unittest.mock import Mock, patch import pytest @@ -43,18 +43,6 @@ def mocked_entity_classifier_response(mocker): ) anonymize_response1: Tuple[list, str] = ( - [ - TestAnonymizerResult("PERSON"), - TestAnonymizerResult("GITHUB_TOKEN"), - TestAnonymizerResult("AWS_ACCESS_KEY"), - TestAnonymizerResult("PERSON"), - TestAnonymizerResult("US_ITIN"), - TestAnonymizerResult("US_SSN"), - ], - input_text1, - ) - - anonymize_response2: Tuple[list, str] = ( [ TestAnonymizerResult("GITHUB_TOKEN"), TestAnonymizerResult("AWS_ACCESS_KEY"), @@ -63,70 +51,150 @@ def mocked_entity_classifier_response(mocker): ], mock_input_text1_anonymize_snippet_true, ) - - anonymize_response3: Tuple[list, str] = ( - [ - TestAnonymizerResult("SLACK_TOKEN"), - TestAnonymizerResult("SLACK_TOKEN"), - TestAnonymizerResult("GITHUB_TOKEN"), - TestAnonymizerResult("AWS_SECRET_KEY"), - TestAnonymizerResult("AWS_ACCESS_KEY"), - TestAnonymizerResult("US_ITIN"), - TestAnonymizerResult("IBAN_CODE"), - TestAnonymizerResult("CREDIT_CARD"), - TestAnonymizerResult("US_SSN"), - ], - input_text2, - ) - - anonymize_response4: Tuple[list, str] = ( + anonymize_response2: Tuple[list, str] = ( [ TestAnonymizerResult("SLACK_TOKEN"), - TestAnonymizerResult("PERSON"), TestAnonymizerResult("SLACK_TOKEN"), - TestAnonymizerResult("PERSON"), - TestAnonymizerResult("PERSON"), TestAnonymizerResult("GITHUB_TOKEN"), TestAnonymizerResult("AWS_SECRET_KEY"), TestAnonymizerResult("AWS_ACCESS_KEY"), TestAnonymizerResult("US_ITIN"), TestAnonymizerResult("IBAN_CODE"), TestAnonymizerResult("CREDIT_CARD"), - TestAnonymizerResult("NRP"), - TestAnonymizerResult("PERSON"), - TestAnonymizerResult("NRP"), - TestAnonymizerResult("PERSON"), TestAnonymizerResult("US_SSN"), - TestAnonymizerResult("DATE_TIME"), - TestAnonymizerResult("PERSON"), - TestAnonymizerResult("DATE_TIME"), - TestAnonymizerResult("DATE_TIME"), - TestAnonymizerResult("DATE_TIME"), - TestAnonymizerResult("DATE_TIME"), - TestAnonymizerResult("PERSON"), ], mock_input_text2_anonymize_snippet_true, ) - - anonymize_negative_response1: Tuple[list, str] = ( + anonymize_negative_response: Tuple[list, str] = ( [], negative_data, ) - - anonymize_negative_response2: Tuple[list, str] = ( - [], - negative_data, - ) - mocker.patch( "pebblo.entity_classifier.entity_classifier.EntityClassifier.anonymize_response", side_effect=[ anonymize_response1, anonymize_response2, - anonymize_response3, - anonymize_response4, - anonymize_negative_response1, - anonymize_negative_response2, + anonymize_negative_response, + ], + ) + + analyzed_entities_response1: List[dict] = [ + {"entity_type": "US_SSN", "location": "17_28", "confidence_score": 0.85}, + {"entity_type": "US_ITIN", "location": "42_53", "confidence_score": 0.85}, + { + "entity_type": "AWS_ACCESS_KEY", + "location": "77_97", + "confidence_score": 0.8, + }, + { + "entity_type": "GITHUB_TOKEN", + "location": "120_210", + "confidence_score": 0.8, + }, + ] + analyzed_entities_response2: List[dict] = [ + {"entity_type": "US_SSN", "location": "17_25", "confidence_score": 0.85}, + {"entity_type": "US_ITIN", "location": "39_48", "confidence_score": 0.85}, + { + "entity_type": "AWS_ACCESS_KEY", + "location": "72_88", + "confidence_score": 0.8, + }, + { + "entity_type": "GITHUB_TOKEN", + "location": "111_125", + "confidence_score": 0.8, + }, + ] + analyzed_entities_response3: List[dict] = [ + { + "entity_type": "CREDIT_CARD", + "location": "1367_1382", + "confidence_score": 1.0, + }, + { + "entity_type": "IBAN_CODE", + "location": "1406_1434", + "confidence_score": 1.0, + }, + {"entity_type": "US_SSN", "location": "1178_1189", "confidence_score": 0.85}, + {"entity_type": "US_ITIN", "location": "1450_1461", "confidence_score": 0.85}, + { + "entity_type": "AWS_ACCESS_KEY", + "location": "1545_1565", + "confidence_score": 0.8, + }, + { + "entity_type": "AWS_SECRET_KEY", + "location": "1587_1628", + "confidence_score": 0.8, + }, + { + "entity_type": "GITHUB_TOKEN", + "location": "1646_1736", + "confidence_score": 0.8, + }, + { + "entity_type": "SLACK_TOKEN", + "location": "1812_1835", + "confidence_score": 0.8, + }, + { + "entity_type": "SLACK_TOKEN", + "location": "1911_1968", + "confidence_score": 0.8, + }, + ] + analyzed_entities_response4: List[dict] = [ + { + "entity_type": "CREDIT_CARD", + "location": "1178_1186", + "confidence_score": 1.0, + }, + { + "entity_type": "IBAN_CODE", + "location": "1364_1377", + "confidence_score": 1.0, + }, + {"entity_type": "US_SSN", "location": "1401_1412", "confidence_score": 0.85}, + {"entity_type": "US_ITIN", "location": "1428_1437", "confidence_score": 0.85}, + { + "entity_type": "AWS_ACCESS_KEY", + "location": "1521_1537", + "confidence_score": 0.8, + }, + { + "entity_type": "AWS_SECRET_KEY", + "location": "1559_1575", + "confidence_score": 0.8, + }, + { + "entity_type": "GITHUB_TOKEN", + "location": "1593_1607", + "confidence_score": 0.8, + }, + { + "entity_type": "SLACK_TOKEN", + "location": "1683_1696", + "confidence_score": 0.8, + }, + { + "entity_type": "SLACK_TOKEN", + "location": "1772_1785", + "confidence_score": 0.8, + }, + ] + analyzed_entities_negative_response1: List = [] + analyzed_entities_negative_response2: List = [] + mocker.patch( + "pebblo.entity_classifier.entity_classifier.EntityClassifier.get_analyzed_entities_response", + side_effect=[ + analyzed_entities_response1, + analyzed_entities_response2, + analyzed_entities_response3, + analyzed_entities_response4, + analyzed_entities_negative_response1, + analyzed_entities_negative_response2, ], ) @@ -156,6 +224,7 @@ def test_presidio_entity_classifier_and_anonymizer( entities, total_count, anonymized_text, + entity_details, ) = entity_classifier.presidio_entity_classifier_and_anonymizer(input_text1) assert entities == { "github-token": 1, @@ -165,14 +234,19 @@ def test_presidio_entity_classifier_and_anonymizer( } assert total_count == 4 assert anonymized_text == input_text1 + assert entity_details == { + "us-ssn": [{"location": "17_28", "confidence_score": "HIGH"}], + "us-itin": [{"location": "42_53", "confidence_score": "HIGH"}], + "aws-access-key": [{"location": "77_97", "confidence_score": "HIGH"}], + "github-token": [{"location": "120_210", "confidence_score": "HIGH"}], + } ( entities, total_count, anonymized_text, - ) = entity_classifier.presidio_entity_classifier_and_anonymizer( - input_text1, anonymize_snippets=True - ) + entity_details, + ) = entity_classifier.presidio_entity_classifier_and_anonymizer(input_text1, True) assert entities == { "github-token": 1, "aws-access-key": 1, @@ -181,11 +255,18 @@ def test_presidio_entity_classifier_and_anonymizer( } assert total_count == 4 assert anonymized_text == mock_input_text1_anonymize_snippet_true + assert entity_details == { + "us-ssn": [{"location": "17_25", "confidence_score": "HIGH"}], + "us-itin": [{"location": "39_48", "confidence_score": "HIGH"}], + "aws-access-key": [{"location": "72_88", "confidence_score": "HIGH"}], + "github-token": [{"location": "111_125", "confidence_score": "HIGH"}], + } ( entities, total_count, anonymized_text, + entity_details, ) = entity_classifier.presidio_entity_classifier_and_anonymizer(input_text2) assert entities == { "slack-token": 2, @@ -197,16 +278,29 @@ def test_presidio_entity_classifier_and_anonymizer( "credit-card-number": 1, "us-ssn": 1, } - assert total_count == 9 assert anonymized_text == input_text2 + assert entity_details == { + "credit-card-number": [{"location": "1367_1382", "confidence_score": "HIGH"}], + "iban-code": [{"location": "1406_1434", "confidence_score": "HIGH"}], + "us-ssn": [{"location": "1178_1189", "confidence_score": "HIGH"}], + "us-itin": [{"location": "1450_1461", "confidence_score": "HIGH"}], + "aws-access-key": [{"location": "1545_1565", "confidence_score": "HIGH"}], + "aws-secret-key": [{"location": "1587_1628", "confidence_score": "HIGH"}], + "github-token": [{"location": "1646_1736", "confidence_score": "HIGH"}], + "slack-token": [ + {"location": "1812_1835", "confidence_score": "HIGH"}, + {"location": "1911_1968", "confidence_score": "HIGH"}, + ], + } ( entities, total_count, anonymized_text, + entity_details, ) = entity_classifier.presidio_entity_classifier_and_anonymizer( - input_text1, anonymize_snippets=True + input_text2, anonymize_snippets=True ) assert entities == { "slack-token": 2, @@ -220,11 +314,25 @@ def test_presidio_entity_classifier_and_anonymizer( } assert total_count == 9 assert anonymized_text == mock_input_text2_anonymize_snippet_true + assert entity_details == { + "credit-card-number": [{"location": "1178_1186", "confidence_score": "HIGH"}], + "iban-code": [{"location": "1364_1377", "confidence_score": "HIGH"}], + "us-ssn": [{"location": "1401_1412", "confidence_score": "HIGH"}], + "us-itin": [{"location": "1428_1437", "confidence_score": "HIGH"}], + "aws-access-key": [{"location": "1521_1537", "confidence_score": "HIGH"}], + "aws-secret-key": [{"location": "1559_1575", "confidence_score": "HIGH"}], + "github-token": [{"location": "1593_1607", "confidence_score": "HIGH"}], + "slack-token": [ + {"location": "1683_1696", "confidence_score": "HIGH"}, + {"location": "1772_1785", "confidence_score": "HIGH"}, + ], + } ( entities, total_count, anonymized_text, + entity_details, ) = entity_classifier.presidio_entity_classifier_and_anonymizer(negative_data) assert entities == {} assert total_count == 0 @@ -234,6 +342,7 @@ def test_presidio_entity_classifier_and_anonymizer( entities, total_count, anonymized_text, + entity_details, ) = entity_classifier.presidio_entity_classifier_and_anonymizer( negative_data, anonymize_snippets=True ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..2d54fe76 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,7 @@ +from pebblo.utils import get_confidence_score_label + + +def test_get_confidence_score_label(): + assert get_confidence_score_label(1.0) == "HIGH" + assert get_confidence_score_label(0.4) == "MEDIUM" + assert get_confidence_score_label(0.2) == "LOW" diff --git a/tests/topic_classifier/test_topic_classifier.py b/tests/topic_classifier/test_topic_classifier.py index 7e6a85a7..f6fd2a34 100644 --- a/tests/topic_classifier/test_topic_classifier.py +++ b/tests/topic_classifier/test_topic_classifier.py @@ -99,13 +99,14 @@ def test_predict_expected_topic(topic_classifier, mock_topic_display_names): # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.return_value = mock_response - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) # Assertions assert total_count == 1 assert HARMFUL_ADVICE in topics assert topics[HARMFUL_ADVICE] == 1 - assert topics == {HARMFUL_ADVICE: 1} + assert topics == {"harmful-advice": 1} + assert topic_details == {"harmful-advice": [{"confidence_score": "MEDIUM"}]} def test_predict_low_score_topics(topic_classifier, mock_topic_display_names): @@ -121,11 +122,12 @@ def test_predict_low_score_topics(topic_classifier, mock_topic_display_names): # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.return_value = mock_response - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) # Assertions assert total_count == 0 assert topics == {} + assert topic_details == {} @patch("pebblo.topic_classifier.topic_classifier.TOPIC_CONFIDENCE_SCORE", 0.4) @@ -142,13 +144,14 @@ def test_predict_confidence_score_update(topic_classifier, mock_topic_display_na # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.return_value = mock_response - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) # Assertions assert total_count == 1 assert MEDICAL_ADVICE in topics assert topics[MEDICAL_ADVICE] == 1 - assert topics == {MEDICAL_ADVICE: 1} + assert topics == {"medical-advice": 1} + assert topic_details == {"medical-advice": [{"confidence_score": "MEDIUM"}]} def test_predict_empty_topics(topic_classifier): @@ -159,7 +162,7 @@ def test_predict_empty_topics(topic_classifier): # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.return_value = mock_response - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) # Assertions assert topics == {} @@ -173,10 +176,11 @@ def test_predict_on_exception(topic_classifier): # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.side_effect = Exception("Mocked exception") - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) assert topics == {} assert total_count == 0 + assert topic_details == {} @patch("pebblo.topic_classifier.topic_classifier.TOPIC_MIN_TEXT_LENGTH", 16) @@ -193,7 +197,7 @@ def test_predict_min_len_not_met(topic_classifier, mock_topic_display_names): # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.return_value = mock_response - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) # Assertions assert topics == {} @@ -216,7 +220,7 @@ def test_predict_exclude_topics(topic_classifier, mock_topic_display_names): # Setting the return value of the classifier's predict method topic_classifier.classifier = MagicMock() topic_classifier.classifier.return_value = mock_response - topics, total_count = topic_classifier.predict(input_text) + topics, total_count, topic_details = topic_classifier.predict(input_text) # Assertions assert "HARMFUL_ADVICE" not in topics