Skip to content

Commit

Permalink
Confidence score changes in Entity Classifier (#473)
Browse files Browse the repository at this point in the history
* Adding confidence score to entities and topic responses (#460)

* Adding confidence score to entities and topic responses

* Adding labels and fixing UTS

* Adding utils

* Fixing UT

* Remove unused imports

* Updating topic classifier

---------

Co-authored-by: dristy.cd <[email protected]>

* Added changes for prompt group

* resolved linting issue

* added changes for confidence score for entity classification

* added changes for confidence score

* review comment changes

* review comment changes

* review comment changes

---------

Co-authored-by: Dristy Srivastava <[email protected]>
Co-authored-by: dristy.cd <[email protected]>
  • Loading branch information
3 people authored Aug 13, 2024
1 parent e3ffe54 commit e4418f8
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 39 deletions.
55 changes: 47 additions & 8 deletions pebblo/entity_classifier/entity_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ConfidenceScore,
Entities,
SecretEntities,
entity_conf_mapping,
)
from pebblo.entity_classifier.utils.utils import (
add_custom_regex_analyzer_registry,
Expand Down Expand Up @@ -41,15 +42,52 @@ def custom_analyze(self):
)

def analyze_response(self, input_text, anonymize_all_entities=True):
# Returns analyzed output
"""
Analyze the given input text to detect and classify entities based on predefined criteria.
Args:
input_text (str): The text to be analyzed for detecting entities.
anonymize_all_entities (bool): Flag to determine if all detected entities should be anonymized.
(Currently not used in the function logic.)
Returns:
list: A list of detected entities that meet the criteria for classification.
"""
# Analyze the text to detect entities using the Presidio analyzer
analyzer_results = self.analyzer.analyze(text=input_text, language="en")
analyzer_results = [
result
for result in analyzer_results
if result.score >= float(ConfidenceScore.Entity.value)
and result.entity_type in self.entities
]
return analyzer_results

# Initialize the list to hold the final classified entities
final_results = []

# Iterate through the detected entities
for entity in analyzer_results:
try:
mapped_entity = None

# Map entity type to predefined entities if it exists in the Entities enumeration
if entity.entity_type in Entities.__members__:
mapped_entity = Entities[entity.entity_type].value

# Check if the entity type exists in SecretEntities enumeration
elif entity.entity_type in SecretEntities.__members__:
mapped_entity = SecretEntities[entity.entity_type].value

# Append entity to final results if it meets the confidence threshold and is in the desired entities list
if (
mapped_entity
and entity.score >= float(entity_conf_mapping[mapped_entity])
and entity.entity_type in self.entities
):
final_results.append(entity)

# Handle any exceptions that occur during entity classification
except Exception as ex:
logger.warning(
f"Error in analyze_response in entity classification. {str(ex)}"
)

# Return the list of classified entities that met the criteria
return final_results

def anonymize_response(self, analyzer_results, input_text):
# Returns anonymized output
Expand All @@ -63,6 +101,7 @@ def anonymize_response(self, analyzer_results, input_text):
def get_analyzed_entities_response(data, anonymized_response=None):
# Returns entities with its location i.e. start to end and confidence score
response = []

for index, value in enumerate(data):
location = f"{value.start}_{value.end}"
if anonymized_response:
Expand Down
21 changes: 21 additions & 0 deletions pebblo/entity_classifier/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,27 @@ class SecretEntities(Enum):
GOOGLE_API_KEY = "google-api-key"


entity_conf_mapping = {
# Identification
Entities.US_SSN.value: 0.8,
Entities.US_PASSPORT.value: 0.4,
Entities.US_DRIVER_LICENSE.value: 0.4,
# Financial
Entities.US_ITIN.value: 0.8,
Entities.CREDIT_CARD.value: 0.8,
Entities.US_BANK_NUMBER.value: 0.4,
Entities.IBAN_CODE.value: 0.8,
# Secret
SecretEntities.GITHUB_TOKEN.value: 0.8,
SecretEntities.SLACK_TOKEN.value: 0.8,
SecretEntities.AWS_ACCESS_KEY.value: 0.45,
SecretEntities.AWS_SECRET_KEY.value: 0.8,
SecretEntities.AZURE_KEY_ID.value: 0.8,
SecretEntities.AZURE_CLIENT_SECRET.value: 0.8,
SecretEntities.GOOGLE_API_KEY.value: 0.8,
}


class ConfidenceScore(Enum):
Entity = "0.8" # based on this score entity output is finalized
EntityMinScore = "0.45" # It denotes the pattern's strength
Expand Down
2 changes: 1 addition & 1 deletion pebblo/entity_classifier/utils/regex_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
"aws-access-key": r"""\b((?:AKIA|ABIA|ACCA|ASIA)[0-9A-Z]{16})\b""",
"aws-secret-key": r"""\b([A-Za-z0-9+/]{40})[ \r\n'"\x60]""",
"azure-key-id": r"""(?i)(%s).{0,20}([a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12})""",
"azure-client-secret": r"""(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})""",
"azure-client-secret": r"""\b(?i)(%s).{0,20}([a-z0-9_\.\-~]{34})\b""",
"google-api-key": r"""(?i)(?:youtube)(?:.|[\n\r]){0,40}\bAIza[0-9A-Za-z\-_]{35}\b""",
}
5 changes: 3 additions & 2 deletions pebblo/entity_classifier/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@


def get_entities(entities_list, response):
entity_groups = dict()
entity_details = dict()
entity_groups: dict = dict()
entity_details: dict = dict()

mapped_entity = None
total_count = 0
for entity in response:
Expand Down
9 changes: 8 additions & 1 deletion tests/app/service/test_prompt_gov.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ def test_process_request_success(mock_entity_classifier):
{"us-ssn": 1},
1,
"anonymized document",
{"us-ssn": [{"location": "16_27", "confidence_score": "HIGH"}]},
{
"us-ssn": [
{
"location": "16_27",
"confidence_score": "HIGH",
}
]
},
)

data = {"prompt": "Sachin's SSN is 222-85-4836"}
Expand Down
2 changes: 1 addition & 1 deletion tests/app/test_prompt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ def test_app_prompt_success(mock_write_json_to_file):
}

response = client.post("/v1/prompt", json=test_payload)

assert response.status_code == 200
assert response.json()["message"] == "AiApp prompt request completed successfully"
assert response.json()["retrieval_data"]["prompt"] == {
"entities": {},
}

assert response.json()["retrieval_data"]["response"] == {
"entities": {"us-ssn": 1},
"topics": {},
Expand Down
174 changes: 148 additions & 26 deletions tests/entity_classifier/test_entity_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,30 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 4
assert anonymized_text == input_text1
assert entity_details == {
"us-ssn": [{"location": "17_28", "confidence_score": "HIGH"}],
"us-itin": [{"location": "42_53", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "77_97", "confidence_score": "HIGH"}],
"github-token": [{"location": "120_210", "confidence_score": "HIGH"}],
"us-ssn": [
{
"location": "17_28",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "42_53",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "77_97",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "120_210",
"confidence_score": "HIGH",
}
],
}

(
Expand All @@ -256,10 +276,30 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 4
assert anonymized_text == mock_input_text1_anonymize_snippet_true
assert entity_details == {
"us-ssn": [{"location": "17_25", "confidence_score": "HIGH"}],
"us-itin": [{"location": "39_48", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "72_88", "confidence_score": "HIGH"}],
"github-token": [{"location": "111_125", "confidence_score": "HIGH"}],
"us-ssn": [
{
"location": "17_25",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "39_48",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "72_88",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "111_125",
"confidence_score": "HIGH",
}
],
}

(
Expand All @@ -281,16 +321,57 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 9
assert anonymized_text == input_text2
assert entity_details == {
"credit-card-number": [{"location": "1367_1382", "confidence_score": "HIGH"}],
"iban-code": [{"location": "1406_1434", "confidence_score": "HIGH"}],
"us-ssn": [{"location": "1178_1189", "confidence_score": "HIGH"}],
"us-itin": [{"location": "1450_1461", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "1545_1565", "confidence_score": "HIGH"}],
"aws-secret-key": [{"location": "1587_1628", "confidence_score": "HIGH"}],
"github-token": [{"location": "1646_1736", "confidence_score": "HIGH"}],
"credit-card-number": [
{
"location": "1367_1382",
"confidence_score": "HIGH",
}
],
"iban-code": [
{
"location": "1406_1434",
"confidence_score": "HIGH",
}
],
"us-ssn": [
{
"location": "1178_1189",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "1450_1461",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "1545_1565",
"confidence_score": "HIGH",
}
],
"aws-secret-key": [
{
"location": "1587_1628",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "1646_1736",
"confidence_score": "HIGH",
}
],
"slack-token": [
{"location": "1812_1835", "confidence_score": "HIGH"},
{"location": "1911_1968", "confidence_score": "HIGH"},
{
"location": "1812_1835",
"confidence_score": "HIGH",
},
{
"location": "1911_1968",
"confidence_score": "HIGH",
},
],
}

Expand All @@ -315,16 +396,57 @@ def test_presidio_entity_classifier_and_anonymizer(
assert total_count == 9
assert anonymized_text == mock_input_text2_anonymize_snippet_true
assert entity_details == {
"credit-card-number": [{"location": "1178_1186", "confidence_score": "HIGH"}],
"iban-code": [{"location": "1364_1377", "confidence_score": "HIGH"}],
"us-ssn": [{"location": "1401_1412", "confidence_score": "HIGH"}],
"us-itin": [{"location": "1428_1437", "confidence_score": "HIGH"}],
"aws-access-key": [{"location": "1521_1537", "confidence_score": "HIGH"}],
"aws-secret-key": [{"location": "1559_1575", "confidence_score": "HIGH"}],
"github-token": [{"location": "1593_1607", "confidence_score": "HIGH"}],
"credit-card-number": [
{
"location": "1178_1186",
"confidence_score": "HIGH",
}
],
"iban-code": [
{
"location": "1364_1377",
"confidence_score": "HIGH",
}
],
"us-ssn": [
{
"location": "1401_1412",
"confidence_score": "HIGH",
}
],
"us-itin": [
{
"location": "1428_1437",
"confidence_score": "HIGH",
}
],
"aws-access-key": [
{
"location": "1521_1537",
"confidence_score": "HIGH",
}
],
"aws-secret-key": [
{
"location": "1559_1575",
"confidence_score": "HIGH",
}
],
"github-token": [
{
"location": "1593_1607",
"confidence_score": "HIGH",
}
],
"slack-token": [
{"location": "1683_1696", "confidence_score": "HIGH"},
{"location": "1772_1785", "confidence_score": "HIGH"},
{
"location": "1683_1696",
"confidence_score": "HIGH",
},
{
"location": "1772_1785",
"confidence_score": "HIGH",
},
],
}

Expand Down

0 comments on commit e4418f8

Please sign in to comment.