Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Confidence score changes in Entity Classifier #473

Merged
merged 9 commits into from
Aug 13, 2024
58 changes: 50 additions & 8 deletions pebblo/entity_classifier/entity_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ConfidenceScore,
Entities,
SecretEntities,
entity_conf_mapping,
)
from pebblo.entity_classifier.utils.utils import (
add_custom_regex_analyzer_registry,
Expand Down Expand Up @@ -41,15 +42,52 @@ def custom_analyze(self):
)

def analyze_response(self, input_text, anonymize_all_entities=True):
# Returns analyzed output
"""
Analyze the given input text to detect and classify entities based on predefined criteria.

Args:
input_text (str): The text to be analyzed for detecting entities.
anonymize_all_entities (bool): Flag to determine if all detected entities should be anonymized.
(Currently not used in the function logic.)

Returns:
list: A list of detected entities that meet the criteria for classification.
"""
# Analyze the text to detect entities using the Presidio analyzer
analyzer_results = self.analyzer.analyze(text=input_text, language="en")
analyzer_results = [
result
for result in analyzer_results
if result.score >= float(ConfidenceScore.Entity.value)
and result.entity_type in self.entities
]
return analyzer_results

# Initialize the list to hold the final classified entities
final_results = []

# Iterate through the detected entities
for entity in analyzer_results:
try:
mapped_entity = None

# Map entity type to predefined entities if it exists in the Entities enumeration
if entity.entity_type in Entities.__members__:
mapped_entity = Entities[entity.entity_type].value

# Check if the entity type exists in SecretEntities enumeration
elif entity.entity_type in SecretEntities.__members__:
mapped_entity = SecretEntities[entity.entity_type].value

# Append entity to final results if it meets the confidence threshold and is in the desired entities list
if (
mapped_entity
and entity.score >= float(entity_conf_mapping[mapped_entity])
and entity.entity_type in self.entities
):
final_results.append(entity)

# Handle any exceptions that occur during entity classification
except Exception as ex:
logger.warning(
f"Error in analyze_response in entity classification. {str(ex)}"
)

# Return the list of classified entities that met the criteria
return final_results

def anonymize_response(self, analyzer_results, input_text):
# Returns anonymized output
Expand All @@ -63,6 +101,7 @@ def anonymize_response(self, analyzer_results, input_text):
def get_analyzed_entities_response(data, anonymized_response=None):
# Returns entities with its location i.e. start to end and confidence score
response = []

for index, value in enumerate(data):
location = f"{value.start}_{value.end}"
if anonymized_response:
Expand Down Expand Up @@ -132,3 +171,6 @@ def presidio_entity_classifier_and_anonymizer(
f"Presidio Entity Classifier and Anonymizer Failed, Exception: {e}"
)
return entities, total_count, input_text



21 changes: 21 additions & 0 deletions pebblo/entity_classifier/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,27 @@ class SecretEntities(Enum):
GOOGLE_API_KEY = "google-api-key"


entity_conf_mapping = {
# Identification
Entities.US_SSN.value: 0.8,
Entities.US_PASSPORT.value: 0.4,
Entities.US_DRIVER_LICENSE.value: 0.4,
# Financial
Entities.US_ITIN.value: 0.8,
Entities.CREDIT_CARD.value: 0.8,
Entities.US_BANK_NUMBER.value: 0.4,
Entities.IBAN_CODE.value: 0.8,
# Secret
SecretEntities.GITHUB_TOKEN.value: 0.8,
SecretEntities.SLACK_TOKEN.value: 0.8,
SecretEntities.AWS_ACCESS_KEY.value: 0.45,
SecretEntities.AWS_SECRET_KEY.value: 0.8,
SecretEntities.AZURE_KEY_ID.value: 0.8,
SecretEntities.AZURE_CLIENT_SECRET.value: 0.8,
SecretEntities.GOOGLE_API_KEY.value: 0.8,
}


class ConfidenceScore(Enum):
Entity = "0.8" # based on this score entity output is finalized
EntityMinScore = "0.45" # It denotes the pattern's strength
Expand Down
2 changes: 1 addition & 1 deletion pebblo/entity_classifier/utils/regex_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
"aws-access-key": r"""\b((?:AKIA|ABIA|ACCA|ASIA)[0-9A-Z]{16})\b""",
"aws-secret-key": r"""\b([A-Za-z0-9+/]{40})[ \r\n'"\x60]""",
"azure-key-id": r"""(?i)(%s).{0,20}([a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12})""",
"azure-client-secret": r"""(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})""",
"azure-client-secret": r"""\b(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})\b""",
"google-api-key": r"""(?i)(?:youtube)(?:.|[\n\r]){0,40}\bAIza[0-9A-Za-z\-_]{35}\b""",
}
5 changes: 3 additions & 2 deletions pebblo/entity_classifier/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@


def get_entities(entities_list, response):
entity_groups = dict()
entity_details = dict()
entity_groups: dict = dict()
entity_details: dict = dict()

mapped_entity = None
total_count = 0
for entity in response:
Expand Down
9 changes: 8 additions & 1 deletion tests/app/service/test_prompt_gov.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ def test_process_request_success(mock_entity_classifier):
{"us-ssn": 1},
1,
"anonymized document",
{"us-ssn": [{"location": "16_27", "confidence_score": "HIGH"}]},
{
"us-ssn": [
{
"location": "16_27",
"confidence_score": "HIGH",
}
]
},
)

data = {"prompt": "Sachin's SSN is 222-85-4836"}
Expand Down
2 changes: 1 addition & 1 deletion tests/app/test_prompt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def test_app_prompt_success(mock_write_json_to_file):
}

response = client.post("/v1/prompt", json=test_payload)

assert response.status_code == 200
assert response.json()["message"] == "AiApp prompt request completed successfully"
assert response.json()["retrieval_data"]["prompt"] == {
"entities": {},
}

assert response.json()["retrieval_data"]["response"] == {
"entities": {"us-ssn": 1},
"topics": {},
Expand Down
174 changes: 148 additions & 26 deletions tests/entity_classifier/test_entity_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,30 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 4
assert anonymized_text == input_text1
assert entity_details == {
"us-ssn": [{"location": "17_28", "confidence_score": "HIGH"}],
"us-itin": [{"location": "42_53", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "77_97", "confidence_score": "HIGH"}],
"github-token": [{"location": "120_210", "confidence_score": "HIGH"}],
"us-ssn": [
{
"location": "17_28",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "42_53",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "77_97",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "120_210",
"confidence_score": "HIGH",
}
],
}

(
Expand All @@ -256,10 +276,30 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 4
assert anonymized_text == mock_input_text1_anonymize_snippet_true
assert entity_details == {
"us-ssn": [{"location": "17_25", "confidence_score": "HIGH"}],
"us-itin": [{"location": "39_48", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "72_88", "confidence_score": "HIGH"}],
"github-token": [{"location": "111_125", "confidence_score": "HIGH"}],
"us-ssn": [
{
"location": "17_25",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "39_48",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "72_88",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "111_125",
"confidence_score": "HIGH",
}
],
}

(
Expand All @@ -281,16 +321,57 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 9
assert anonymized_text == input_text2
assert entity_details == {
"credit-card-number": [{"location": "1367_1382", "confidence_score": "HIGH"}],
"iban-code": [{"location": "1406_1434", "confidence_score": "HIGH"}],
"us-ssn": [{"location": "1178_1189", "confidence_score": "HIGH"}],
"us-itin": [{"location": "1450_1461", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "1545_1565", "confidence_score": "HIGH"}],
"aws-secret-key": [{"location": "1587_1628", "confidence_score": "HIGH"}],
"github-token": [{"location": "1646_1736", "confidence_score": "HIGH"}],
"credit-card-number": [
{
"location": "1367_1382",
"confidence_score": "HIGH",
}
],
"iban-code": [
{
"location": "1406_1434",
"confidence_score": "HIGH",
}
],
"us-ssn": [
{
"location": "1178_1189",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "1450_1461",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "1545_1565",
"confidence_score": "HIGH",
}
],
"aws-secret-key": [
{
"location": "1587_1628",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "1646_1736",
"confidence_score": "HIGH",
}
],
"slack-token": [
{"location": "1812_1835", "confidence_score": "HIGH"},
{"location": "1911_1968", "confidence_score": "HIGH"},
{
"location": "1812_1835",
"confidence_score": "HIGH",
},
{
"location": "1911_1968",
"confidence_score": "HIGH",
},
],
}

Expand All @@ -315,16 +396,57 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 9
assert anonymized_text == mock_input_text2_anonymize_snippet_true
assert entity_details == {
"credit-card-number": [{"location": "1178_1186", "confidence_score": "HIGH"}],
"iban-code": [{"location": "1364_1377", "confidence_score": "HIGH"}],
"us-ssn": [{"location": "1401_1412", "confidence_score": "HIGH"}],
"us-itin": [{"location": "1428_1437", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "1521_1537", "confidence_score": "HIGH"}],
"aws-secret-key": [{"location": "1559_1575", "confidence_score": "HIGH"}],
"github-token": [{"location": "1593_1607", "confidence_score": "HIGH"}],
"credit-card-number": [
{
"location": "1178_1186",
"confidence_score": "HIGH",
}
],
"iban-code": [
{
"location": "1364_1377",
"confidence_score": "HIGH",
}
],
"us-ssn": [
{
"location": "1401_1412",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "1428_1437",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "1521_1537",
"confidence_score": "HIGH",
}
],
"aws-secret-key": [
{
"location": "1559_1575",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "1593_1607",
"confidence_score": "HIGH",
}
],
"slack-token": [
{"location": "1683_1696", "confidence_score": "HIGH"},
{"location": "1772_1785", "confidence_score": "HIGH"},
{
"location": "1683_1696",
"confidence_score": "HIGH",
},
{
"location": "1772_1785",
"confidence_score": "HIGH",
},
],
}

Expand Down
Loading