diff --git a/pebblo/entity_classifier/entity_classifier.py b/pebblo/entity_classifier/entity_classifier.py index 4acea1cb..60346b45 100644 --- a/pebblo/entity_classifier/entity_classifier.py +++ b/pebblo/entity_classifier/entity_classifier.py @@ -6,6 +6,7 @@ ConfidenceScore, Entities, SecretEntities, + entity_conf_mapping, ) from pebblo.entity_classifier.utils.utils import ( add_custom_regex_analyzer_registry, @@ -41,15 +42,52 @@ def custom_analyze(self): ) def analyze_response(self, input_text, anonymize_all_entities=True): - # Returns analyzed output + """ + Analyze the given input text to detect and classify entities based on predefined criteria. + + Args: + input_text (str): The text to be analyzed for detecting entities. + anonymize_all_entities (bool): Flag to determine if all detected entities should be anonymized. + (Currently not used in the function logic.) + + Returns: + list: A list of detected entities that meet the criteria for classification. + """ + # Analyze the text to detect entities using the Presidio analyzer analyzer_results = self.analyzer.analyze(text=input_text, language="en") - analyzer_results = [ - result - for result in analyzer_results - if result.score >= float(ConfidenceScore.Entity.value) - and result.entity_type in self.entities - ] - return analyzer_results + + # Initialize the list to hold the final classified entities + final_results = [] + + # Iterate through the detected entities + for entity in analyzer_results: + try: + mapped_entity = None + + # Map entity type to predefined entities if it exists in the Entities enumeration + if entity.entity_type in Entities.__members__: + mapped_entity = Entities[entity.entity_type].value + + # Check if the entity type exists in SecretEntities enumeration + elif entity.entity_type in SecretEntities.__members__: + mapped_entity = SecretEntities[entity.entity_type].value + + # Append entity to final results if it meets the confidence threshold and is in the desired entities list + if ( + mapped_entity + and entity.score >= float(entity_conf_mapping[mapped_entity]) + and entity.entity_type in self.entities + ): + final_results.append(entity) + + # Handle any exceptions that occur during entity classification + except Exception as ex: + logger.warning( + f"Error in analyze_response in entity classification. {str(ex)}" + ) + + # Return the list of classified entities that met the criteria + return final_results def anonymize_response(self, analyzer_results, input_text): # Returns anonymized output @@ -63,6 +101,7 @@ def anonymize_response(self, analyzer_results, input_text): def get_analyzed_entities_response(data, anonymized_response=None): # Returns entities with its location i.e. start to end and confidence score response = [] + for index, value in enumerate(data): location = f"{value.start}_{value.end}" if anonymized_response: diff --git a/pebblo/entity_classifier/utils/config.py b/pebblo/entity_classifier/utils/config.py index d5c0a12b..31b38bee 100644 --- a/pebblo/entity_classifier/utils/config.py +++ b/pebblo/entity_classifier/utils/config.py @@ -38,6 +38,27 @@ class SecretEntities(Enum): GOOGLE_API_KEY = "google-api-key" +entity_conf_mapping = { + # Identification + Entities.US_SSN.value: 0.8, + Entities.US_PASSPORT.value: 0.4, + Entities.US_DRIVER_LICENSE.value: 0.4, + # Financial + Entities.US_ITIN.value: 0.8, + Entities.CREDIT_CARD.value: 0.8, + Entities.US_BANK_NUMBER.value: 0.4, + Entities.IBAN_CODE.value: 0.8, + # Secret + SecretEntities.GITHUB_TOKEN.value: 0.8, + SecretEntities.SLACK_TOKEN.value: 0.8, + SecretEntities.AWS_ACCESS_KEY.value: 0.45, + SecretEntities.AWS_SECRET_KEY.value: 0.8, + SecretEntities.AZURE_KEY_ID.value: 0.8, + SecretEntities.AZURE_CLIENT_SECRET.value: 0.8, + SecretEntities.GOOGLE_API_KEY.value: 0.8, +} + + class ConfidenceScore(Enum): Entity = "0.8" # based on this score entity output is finalized EntityMinScore = "0.45" # It denotes the pattern's strength diff --git a/pebblo/entity_classifier/utils/regex_pattern.py b/pebblo/entity_classifier/utils/regex_pattern.py index 4368005a..a42641d1 100644 --- a/pebblo/entity_classifier/utils/regex_pattern.py +++ b/pebblo/entity_classifier/utils/regex_pattern.py @@ -11,6 +11,6 @@ "aws-access-key": r"""\b((?:AKIA|ABIA|ACCA|ASIA)[0-9A-Z]{16})\b""", "aws-secret-key": r"""\b([A-Za-z0-9+/]{40})[ \r\n'"\x60]""", "azure-key-id": r"""(?i)(%s).{0,20}([a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12})""", - "azure-client-secret": r"""(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})""", + "azure-client-secret": r"""\b(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})\b""", "google-api-key": r"""(?i)(?:youtube)(?:.|[\n\r]){0,40}\bAIza[0-9A-Za-z\-_]{35}\b""", } diff --git a/pebblo/entity_classifier/utils/utils.py b/pebblo/entity_classifier/utils/utils.py index 6fd20b13..5fcd620b 100644 --- a/pebblo/entity_classifier/utils/utils.py +++ b/pebblo/entity_classifier/utils/utils.py @@ -15,8 +15,9 @@ def get_entities(entities_list, response): - entity_groups = dict() - entity_details = dict() + entity_groups: dict = dict() + entity_details: dict = dict() + mapped_entity = None total_count = 0 for entity in response: diff --git a/tests/app/service/test_prompt_gov.py b/tests/app/service/test_prompt_gov.py index 5710c22f..0ddd111f 100644 --- a/tests/app/service/test_prompt_gov.py +++ b/tests/app/service/test_prompt_gov.py @@ -20,7 +20,14 @@ def test_process_request_success(mock_entity_classifier): {"us-ssn": 1}, 1, "anonymized document", - {"us-ssn": [{"location": "16_27", "confidence_score": "HIGH"}]}, + { + "us-ssn": [ + { + "location": "16_27", + "confidence_score": "HIGH", + } + ] + }, ) data = {"prompt": "Sachin's SSN is 222-85-4836"} diff --git a/tests/app/test_prompt_api.py b/tests/app/test_prompt_api.py index 243847cc..677a057f 100644 --- a/tests/app/test_prompt_api.py +++ b/tests/app/test_prompt_api.py @@ -52,12 +52,12 @@ def test_app_prompt_success(mock_write_json_to_file): } response = client.post("/v1/prompt", json=test_payload) - assert response.status_code == 200 assert response.json()["message"] == "AiApp prompt request completed successfully" assert response.json()["retrieval_data"]["prompt"] == { "entities": {}, } + assert response.json()["retrieval_data"]["response"] == { "entities": {"us-ssn": 1}, "topics": {}, diff --git a/tests/entity_classifier/test_entity_classifier.py b/tests/entity_classifier/test_entity_classifier.py index 9f9f985a..4f87059a 100644 --- a/tests/entity_classifier/test_entity_classifier.py +++ b/tests/entity_classifier/test_entity_classifier.py @@ -235,10 +235,30 @@ def test_presidio_entity_classifier_and_anonymizer( assert total_count == 4 assert anonymized_text == input_text1 assert entity_details == { - "us-ssn": [{"location": "17_28", "confidence_score": "HIGH"}], - "us-itin": [{"location": "42_53", "confidence_score": "HIGH"}], - "aws-access-key": [{"location": "77_97", "confidence_score": "HIGH"}], - "github-token": [{"location": "120_210", "confidence_score": "HIGH"}], + "us-ssn": [ + { + "location": "17_28", + "confidence_score": "HIGH", + } + ], + "us-itin": [ + { + "location": "42_53", + "confidence_score": "HIGH", + } + ], + "aws-access-key": [ + { + "location": "77_97", + "confidence_score": "HIGH", + } + ], + "github-token": [ + { + "location": "120_210", + "confidence_score": "HIGH", + } + ], } ( @@ -256,10 +276,30 @@ def test_presidio_entity_classifier_and_anonymizer( assert total_count == 4 assert anonymized_text == mock_input_text1_anonymize_snippet_true assert entity_details == { - "us-ssn": [{"location": "17_25", "confidence_score": "HIGH"}], - "us-itin": [{"location": "39_48", "confidence_score": "HIGH"}], - "aws-access-key": [{"location": "72_88", "confidence_score": "HIGH"}], - "github-token": [{"location": "111_125", "confidence_score": "HIGH"}], + "us-ssn": [ + { + "location": "17_25", + "confidence_score": "HIGH", + } + ], + "us-itin": [ + { + "location": "39_48", + "confidence_score": "HIGH", + } + ], + "aws-access-key": [ + { + "location": "72_88", + "confidence_score": "HIGH", + } + ], + "github-token": [ + { + "location": "111_125", + "confidence_score": "HIGH", + } + ], } ( @@ -281,16 +321,57 @@ def test_presidio_entity_classifier_and_anonymizer( assert total_count == 9 assert anonymized_text == input_text2 assert entity_details == { - "credit-card-number": [{"location": "1367_1382", "confidence_score": "HIGH"}], - "iban-code": [{"location": "1406_1434", "confidence_score": "HIGH"}], - "us-ssn": [{"location": "1178_1189", "confidence_score": "HIGH"}], - "us-itin": [{"location": "1450_1461", "confidence_score": "HIGH"}], - "aws-access-key": [{"location": "1545_1565", "confidence_score": "HIGH"}], - "aws-secret-key": [{"location": "1587_1628", "confidence_score": "HIGH"}], - "github-token": [{"location": "1646_1736", "confidence_score": "HIGH"}], + "credit-card-number": [ + { + "location": "1367_1382", + "confidence_score": "HIGH", + } + ], + "iban-code": [ + { + "location": "1406_1434", + "confidence_score": "HIGH", + } + ], + "us-ssn": [ + { + "location": "1178_1189", + "confidence_score": "HIGH", + } + ], + "us-itin": [ + { + "location": "1450_1461", + "confidence_score": "HIGH", + } + ], + "aws-access-key": [ + { + "location": "1545_1565", + "confidence_score": "HIGH", + } + ], + "aws-secret-key": [ + { + "location": "1587_1628", + "confidence_score": "HIGH", + } + ], + "github-token": [ + { + "location": "1646_1736", + "confidence_score": "HIGH", + } + ], "slack-token": [ - {"location": "1812_1835", "confidence_score": "HIGH"}, - {"location": "1911_1968", "confidence_score": "HIGH"}, + { + "location": "1812_1835", + "confidence_score": "HIGH", + }, + { + "location": "1911_1968", + "confidence_score": "HIGH", + }, ], } @@ -315,16 +396,57 @@ def test_presidio_entity_classifier_and_anonymizer( assert total_count == 9 assert anonymized_text == mock_input_text2_anonymize_snippet_true assert entity_details == { - "credit-card-number": [{"location": "1178_1186", "confidence_score": "HIGH"}], - "iban-code": [{"location": "1364_1377", "confidence_score": "HIGH"}], - "us-ssn": [{"location": "1401_1412", "confidence_score": "HIGH"}], - "us-itin": [{"location": "1428_1437", "confidence_score": "HIGH"}], - "aws-access-key": [{"location": "1521_1537", "confidence_score": "HIGH"}], - "aws-secret-key": [{"location": "1559_1575", "confidence_score": "HIGH"}], - "github-token": [{"location": "1593_1607", "confidence_score": "HIGH"}], + "credit-card-number": [ + { + "location": "1178_1186", + "confidence_score": "HIGH", + } + ], + "iban-code": [ + { + "location": "1364_1377", + "confidence_score": "HIGH", + } + ], + "us-ssn": [ + { + "location": "1401_1412", + "confidence_score": "HIGH", + } + ], + "us-itin": [ + { + "location": "1428_1437", + "confidence_score": "HIGH", + } + ], + "aws-access-key": [ + { + "location": "1521_1537", + "confidence_score": "HIGH", + } + ], + "aws-secret-key": [ + { + "location": "1559_1575", + "confidence_score": "HIGH", + } + ], + "github-token": [ + { + "location": "1593_1607", + "confidence_score": "HIGH", + } + ], "slack-token": [ - {"location": "1683_1696", "confidence_score": "HIGH"}, - {"location": "1772_1785", "confidence_score": "HIGH"}, + { + "location": "1683_1696", + "confidence_score": "HIGH", + }, + { + "location": "1772_1785", + "confidence_score": "HIGH", + }, ], }