Adding confidence score to entities and topic responses (#460)
* Adding confidence score to entities and topic responses

* Adding labels and fixing UTS

* Adding utils

* Fixing UT

* Remove unused imports

* Updating topic classifier

---------

Co-authored-by: dristy.cd <[email protected]>
2 people authored and shreyas-damle committed Aug 21, 2024
1 parent 28493cd commit 2465160
Showing 13 changed files with 293 additions and 98 deletions.
5 changes: 4 additions & 1 deletion pebblo/app/service/doc_helper.py
@@ -182,11 +182,14 @@ def _get_classifier_response(self, doc):
)
try:
if doc_info.data:
topics, topic_count = topic_classifier_obj.predict(doc_info.data)
topics, topic_count, topic_details = topic_classifier_obj.predict(
doc_info.data
)
(
entities,
entity_count,
anonymized_doc,
entity_details,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
doc_info.data,
anonymize_snippets=ClassifierConstants.anonymize_snippets.value,
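doc_helper.py now unpacks three values from the topic classifier and four from the entity classifier; prompt_gov.py and prompt_service.py below follow the same pattern. A minimal sketch of the widened tuples, using hypothetical values rather than real classifier output:

```python
# Hypothetical return values, illustrating the tuple shapes the service layer
# now unpacks (labels, offsets and topics are illustrative, not real output).
topic_result = (
    {"medical-advice": 1},                               # topics
    1,                                                   # topic_count
    {"medical-advice": [{"confidence_score": "HIGH"}]},  # topic_details
)
entity_result = (
    {"us-ssn": 1},                # entities
    1,                            # entity_count
    "My SSN is 222-85-4836.",     # anonymized_doc (plain text when anonymization is off)
    {"us-ssn": [{"location": "10_21", "confidence_score": "HIGH"}]},  # entity_details
)

topics, topic_count, topic_details = topic_result
entities, entity_count, anonymized_doc, entity_details = entity_result
print(topics, topic_count, topic_details)
print(entities, entity_count, anonymized_doc, entity_details)
```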
1 change: 1 addition & 0 deletions pebblo/app/service/prompt_gov.py
@@ -44,6 +44,7 @@ def _get_classifier_response(self):
entities,
entity_count,
anonymized_doc,
entity_details,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
self.input.get("prompt"),
anonymize_snippets=False,
5 changes: 4 additions & 1 deletion pebblo/app/service/prompt_service.py
@@ -45,6 +45,7 @@ def _fetch_classified_data(self, input_data, input_type=""):
entities,
entity_count,
_,
_,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
input_data
)
@@ -53,7 +54,9 @@

# Topic classification is performed only for the response.
if input_type == "response":
topics, topic_count = self.topic_classifier_obj.predict(input_data)
topics, topic_count, topic_details = self.topic_classifier_obj.predict(
input_data
)
data["topicCount"] = topic_count
data["topics"] = topics

2 changes: 1 addition & 1 deletion pebblo/entity_classifier/README.md
@@ -24,7 +24,7 @@ from pebblo.entity_classifier.entity_classifier import EntityClassifier
text = <Input Data>
entity_classifier_obj = EntityClassifier()
entities, total_count, anonymized_text = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(text,anonymize_snippets)
entities, total_count, anonymized_text, entity_details = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(text,anonymize_snippets)
print(f"Entity Group: {entity_groups}")
print(f"Entity Count: {total_entity_count}")
print(f"Anonymized Text: {anonymized_text}")
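A hedged usage sketch of the four-value signature above, assuming pebblo and its presidio dependencies are installed; the commented outputs mirror the values used in this repository's tests rather than guaranteed classifier output:

```python
from pebblo.entity_classifier.entity_classifier import EntityClassifier

text = "Sachin's SSN is 222-85-4836"
entity_classifier_obj = EntityClassifier()
entities, total_count, anonymized_text, entity_details = (
    entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
        text, anonymize_snippets=False
    )
)
print(f"Entity Group: {entities}")      # e.g. {'us-ssn': 1}
print(f"Entity Count: {total_count}")   # e.g. 1
print(f"Anonymized Text: {anonymized_text}")
print(f"Entity Details: {entity_details}")
# e.g. {'us-ssn': [{'location': '16_27', 'confidence_score': 'HIGH'}]}
```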
49 changes: 36 additions & 13 deletions pebblo/entity_classifier/entity_classifier.py
@@ -47,13 +47,8 @@ def analyze_response(self, input_text, anonymize_all_entities=True):
result
for result in analyzer_results
if result.score >= float(ConfidenceScore.Entity.value)
and result.entity_type in self.entities
]
if not anonymize_all_entities: # Condition for anonymized document
analyzer_results = [
result
for result in analyzer_results
if result.entity_type in self.entities
]
return analyzer_results

def anonymize_response(self, analyzer_results, input_text):
@@ -64,17 +59,36 @@ def anonymize_response(self, analyzer_results, input_text):

return anonymized_text.items, anonymized_text.text

    @staticmethod
    def get_analyzed_entities_response(data, anonymized_response=None):
        # Returns each entity with its location (start_end) and confidence score
        response = []
        for index, value in enumerate(data):
            location = f"{value.start}_{value.end}"
            if anonymized_response:
                anonymized_data = anonymized_response[len(data) - index - 1]
                location = f"{anonymized_data.start}_{anonymized_data.end}"
            response.append(
                {
                    "entity_type": value.entity_type,
                    "location": location,
                    "confidence_score": value.score,
                }
            )
        return response

def presidio_entity_classifier_and_anonymizer(
self, input_text, anonymize_snippets=False
):
"""
Perform classification on the input data and return a dictionary with the count of each entity group.
Also returns the input text, anonymized when anonymize_snippets is True.
:param anonymize_snippets: Flag whether to anonymize snippets in report.
:param input_text: Input string / document snippet
:param anonymize_snippets: Flag whether to anonymize snippets in report.
:return: entities: containing the entity group Name as key and its count as value.
total_count: Total count of entity groups.
anonymized_text: Input text in anonymized form.
entity_details: Entities with their details, such as location and confidence score.
Example:
input_text = " My SSN is 222-85-4836.
@@ -89,21 +103,30 @@ def presidio_entity_classifier_and_anonymizer(
"""
entities = {}
total_count = 0
anonymized_text = ""
try:
logger.debug("Presidio Entity Classifier and Anonymizer Started.")

analyzer_results = self.analyze_response(input_text)
anonymized_response, anonymized_text = self.anonymize_response(
analyzer_results, input_text
)

if anonymize_snippets: # If Document snippet needs to be anonymized
anonymized_response, anonymized_text = self.anonymize_response(
analyzer_results, input_text
)
input_text = anonymized_text.replace("<", "&lt;").replace(">", "&gt;")
entities, total_count = get_entities(self.entities, anonymized_response)
entities_response = self.get_analyzed_entities_response(
analyzer_results, anonymized_response
)
else:
entities_response = self.get_analyzed_entities_response(
analyzer_results
)
entities, entity_details, total_count = get_entities(
self.entities, entities_response
)
logger.debug("Presidio Entity Classifier and Anonymizer Finished")
logger.debug(f"Entities: {entities}")
logger.debug(f"Entity Total count: {total_count}")
return entities, total_count, input_text
return entities, total_count, input_text, entity_details
except Exception as e:
logger.error(
f"Presidio Entity Classifier and Anonymizer Failed, Exception: {e}"
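get_analyzed_entities_response pairs each analyzer hit with a location string and its raw confidence score, reading the anonymizer's items in reverse so that locations can refer to the anonymized text. A self-contained sketch of that pairing, using hypothetical stand-in objects (not presidio's real result classes) that carry only the fields the method reads:

```python
from dataclasses import dataclass


# Hypothetical stand-ins for presidio's analyzer/anonymizer result items.
@dataclass
class AnalyzerItem:
    entity_type: str
    start: int
    end: int
    score: float


@dataclass
class AnonymizedItem:
    start: int
    end: int


analyzer_results = [AnalyzerItem("US_SSN", 16, 27, 0.85)]
# Offsets of the "<US_SSN>" placeholder in the anonymized text (illustrative).
anonymized_response = [AnonymizedItem(16, 24)]

entity_details = []
for index, value in enumerate(analyzer_results):
    # Anonymizer items are read back-to-front, mirroring
    # anonymized_response[len(data) - index - 1] in the method above.
    anonymized = anonymized_response[len(analyzer_results) - index - 1]
    entity_details.append(
        {
            "entity_type": value.entity_type,
            "location": f"{anonymized.start}_{anonymized.end}",
            "confidence_score": value.score,
        }
    )
print(entity_details)
# [{'entity_type': 'US_SSN', 'location': '16_24', 'confidence_score': 0.85}]
```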
25 changes: 19 additions & 6 deletions pebblo/entity_classifier/utils/utils.py
@@ -11,21 +11,34 @@
secret_entities_context_mapping,
)
from pebblo.entity_classifier.utils.regex_pattern import regex_secrets_patterns
from pebblo.utils import get_confidence_score_label


def get_entities(entities_list, response):
entity_groups = dict()
entity_details = dict()
mapped_entity = None
total_count = 0
for entity in response:
if entity.entity_type in entities_list:
if entity.entity_type in Entities.__members__:
mapped_entity = Entities[entity.entity_type].value
elif entity.entity_type in SecretEntities.__members__:
mapped_entity = SecretEntities[entity.entity_type].value
if entity["entity_type"] in entities_list:
if entity["entity_type"] in Entities.__members__:
mapped_entity = Entities[entity["entity_type"]].value
elif entity["entity_type"] in SecretEntities.__members__:
mapped_entity = SecretEntities[entity["entity_type"]].value
entity_groups[mapped_entity] = entity_groups.get(mapped_entity, 0) + 1
entity_data = {
"location": entity["location"],
"confidence_score": get_confidence_score_label(
entity["confidence_score"]
),
}
if mapped_entity in entity_details.keys():
entity_details[mapped_entity].append(entity_data)
else:
entity_details[mapped_entity] = [entity_data]
total_count += 1

return entity_groups, total_count
return entity_groups, entity_details, total_count


def add_custom_regex_analyzer_registry():
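get_entities now builds a third structure, entity_details, keyed by report label with each hit's location and a labelled confidence score. A sketch of that aggregation over a hypothetical entities_response; ENTITY_LABELS below stands in for the Entities/SecretEntities enums and is an assumption for illustration only:

```python
from pebblo.utils import get_confidence_score_label

# Assumed mapping from recognizer names to report labels, for illustration.
ENTITY_LABELS = {"US_SSN": "us-ssn"}

entities_response = [
    {"entity_type": "US_SSN", "location": "16_27", "confidence_score": 0.85},
]

entity_groups, entity_details, total_count = {}, {}, 0
for entity in entities_response:
    label = ENTITY_LABELS[entity["entity_type"]]
    entity_groups[label] = entity_groups.get(label, 0) + 1
    entity_details.setdefault(label, []).append(
        {
            "location": entity["location"],
            "confidence_score": get_confidence_score_label(entity["confidence_score"]),
        }
    )
    total_count += 1

print(entity_groups, total_count)  # {'us-ssn': 1} 1
print(entity_details)  # {'us-ssn': [{'location': '16_27', 'confidence_score': 'HIGH'}]}
```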
2 changes: 1 addition & 1 deletion pebblo/topic_classifier/README.md
@@ -31,7 +31,7 @@ from pebblo.topic_classifier.topic_classifier import TopicClassifier
text = "Your sample text here."
topic_classifier_obj = TopicClassifier()
topics, total_topic_count = topic_classifier_obj.predict(text)
topics, total_topic_count, topic_details = topic_classifier_obj.predict(text)
print(f"Topic Response: {topics}")
print(f"Topic Count: {total_topic_count}")
```
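A hedged extension of the snippet above, printing the new third value as well; the commented shape is illustrative and the topic name is a placeholder, not one of the classifier's actual labels:

```python
from pebblo.topic_classifier.topic_classifier import TopicClassifier

text = "Your sample text here."
topic_classifier_obj = TopicClassifier()
topics, total_topic_count, topic_details = topic_classifier_obj.predict(text)
print(f"Topic Response: {topics}")
print(f"Topic Count: {total_topic_count}")
print(f"Topic Details: {topic_details}")
# e.g. {'<some-topic>': [{'confidence_score': 'HIGH'}]}
```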
21 changes: 16 additions & 5 deletions pebblo/topic_classifier/topic_classifier.py
@@ -18,6 +18,7 @@
TOPICS_TO_EXCLUDE,
)
from pebblo.topic_classifier.enums.constants import topic_display_names
from pebblo.utils import get_confidence_score_label

logger = get_logger(__name__)

@@ -63,15 +64,15 @@ def predict(self, input_text):
f"Text length is below {TOPIC_MIN_TEXT_LENGTH} characters. "
f"Classification not performed."
)
return {}, 0
return {}, 0, {}

topic_model_response = self.classifier(input_text)
topics, total_count = self._get_topics(topic_model_response)
topics, total_count, topic_details = self._get_topics(topic_model_response)
logger.debug(f"Topics: {topics}")
return topics, total_count
return topics, total_count, topic_details
except Exception as e:
logger.error(f"Error in topic_classifier. Exception: {e}")
return {}, 0
return {}, 0, {}

@staticmethod
def _get_topics(topic_model_response):
@@ -89,7 +90,17 @@ def _get_topics(topic_model_response):
topics[mapped_topic] = topic["score"]

final_topic = {}
topic_details = {}
if len(topics) > 0:
most_possible_advice = max(topics, key=lambda t: topics[t])
final_topic = {most_possible_advice: 1}
return final_topic, len(final_topic.keys())
topic_details = {
most_possible_advice: [
{
"confidence_score": get_confidence_score_label(
(topics[most_possible_advice])
)
}
]
}
return final_topic, len(final_topic.keys()), topic_details
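_get_topics keeps only the highest-scoring topic and attaches a confidence label to it. A short sketch of that selection using hypothetical topic names and scores (the real label set comes from topic_display_names):

```python
from pebblo.utils import get_confidence_score_label

# Hypothetical mapped topic scores; names are illustrative only.
topics = {"medical-advice": 0.91, "financial-advice": 0.42}

final_topic, topic_details = {}, {}
if topics:
    winner = max(topics, key=lambda t: topics[t])
    final_topic = {winner: 1}
    topic_details = {
        winner: [{"confidence_score": get_confidence_score_label(topics[winner])}]
    }

print(final_topic, len(final_topic), topic_details)
# {'medical-advice': 1} 1 {'medical-advice': [{'confidence_score': 'HIGH'}]}
```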
20 changes: 20 additions & 0 deletions pebblo/utils.py
@@ -0,0 +1,20 @@
"""
Copyright (c) 2024 Cloud Defense, Inc. All rights reserved.
"""

from enum import Enum


class ConfidenceScoreLabel(Enum):
    HIGH = "HIGH"
    MEDIUM = "MEDIUM"
    LOW = "LOW"


def get_confidence_score_label(confidence_score):
    if float(confidence_score) >= 0.8:
        return ConfidenceScoreLabel.HIGH.value
    elif 0.4 <= float(confidence_score) < 0.8:
        return ConfidenceScoreLabel.MEDIUM.value
    else:
        return ConfidenceScoreLabel.LOW.value
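A quick check of the thresholds encoded above (scores of 0.8 and higher map to HIGH, 0.4 up to but not including 0.8 to MEDIUM, everything else to LOW):

```python
from pebblo.utils import get_confidence_score_label

print(get_confidence_score_label(0.85))  # HIGH
print(get_confidence_score_label(0.55))  # MEDIUM
print(get_confidence_score_label(0.10))  # LOW
```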
1 change: 1 addition & 0 deletions tests/app/service/test_prompt_gov.py
@@ -20,6 +20,7 @@ def test_process_request_success(mock_entity_classifier):
{"us-ssn": 1},
1,
"anonymized document",
{"us-ssn": [{"location": "16_27", "confidence_score": "HIGH"}]},
)

data = {"prompt": "Sachin's SSN is 222-85-4836"}