-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added changes for new api for entity and topic classification (#518)
* added changes for new api for entity and topic classification * added model dump changes * Move classify api under /api/v1 router * Use field name data instead of inputs * Add classifier mode for entity-only, topic-only and all the above --------- Co-authored-by: Sridhar Ramaswamy <[email protected]>
- Loading branch information
1 parent
f6b795f
commit df01788
Showing
6 changed files
with
237 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from fastapi import APIRouter | ||
|
||
from pebblo.app.config.config import var_server_config_dict | ||
from pebblo.app.service.classification import Classification | ||
|
||
config_details = var_server_config_dict.get() | ||
|
||
|
||
class APIv1: | ||
""" | ||
Controller Class for all the api endpoints for App resource. | ||
""" | ||
|
||
def __init__(self, prefix: str): | ||
self.router = APIRouter(prefix=prefix) | ||
|
||
@staticmethod | ||
def classify_data(data: dict): | ||
cls_obj = Classification(data) | ||
response = cls_obj.process_request() | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
""" | ||
Module for text classification | ||
""" | ||
|
||
import traceback | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field, ValidationError | ||
|
||
from pebblo.app.libs.responses import PebbloJsonResponse | ||
from pebblo.app.models.models import AiDataModel | ||
from pebblo.entity_classifier.entity_classifier import EntityClassifier | ||
from pebblo.log import get_logger | ||
from pebblo.topic_classifier.topic_classifier import TopicClassifier | ||
|
||
|
||
class ClassificationMode(Enum): | ||
ENTITY = "entity" | ||
TOPIC = "topic" | ||
ALL = "all" | ||
|
||
|
||
class ReqClassifier(BaseModel): | ||
data: str | ||
mode: Optional[ClassificationMode] = Field(default=ClassificationMode.ALL) | ||
anonymize: Optional[bool] = Field(default=False) | ||
|
||
class Config: | ||
extra = "forbid" | ||
|
||
|
||
logger = get_logger(__name__) | ||
topic_classifier_obj = TopicClassifier() | ||
entity_classifier_obj = EntityClassifier() | ||
|
||
|
||
class Classification: | ||
""" | ||
Classification wrapper class for Entity and Semantic classification with anonymization | ||
""" | ||
|
||
def __init__(self, input: dict): | ||
self.input = input | ||
|
||
def _get_classifier_response(self, req: ReqClassifier): | ||
""" | ||
Processes the input prompt through the entity classifier and anonymizer, and returns | ||
the resulting information encapsulated in an AiDataModel object. | ||
Returns: | ||
AiDataModel: An object containing the anonymized document, entities, and their counts. | ||
""" | ||
doc_info = AiDataModel( | ||
data=None, | ||
entities={}, | ||
entityCount=0, | ||
entityDetails={}, | ||
topics={}, | ||
topicCount=0, | ||
topicDetails={}, | ||
) | ||
try: | ||
# Process entity classification | ||
if req.mode in [ClassificationMode.ENTITY, ClassificationMode.ALL]: | ||
( | ||
entities, | ||
entity_count, | ||
anonymized_doc, | ||
entity_details, | ||
) = entity_classifier_obj.presidio_entity_classifier_and_anonymizer( | ||
req.data, | ||
anonymize_snippets=req.anonymize, | ||
) | ||
doc_info.entities = entities | ||
doc_info.entityCount = entity_count | ||
doc_info.entityDetails = entity_details | ||
doc_info.data = anonymized_doc if req.anonymize else "" | ||
|
||
# Process topic classification | ||
if req.mode in [ClassificationMode.TOPIC, ClassificationMode.ALL]: | ||
topics, topic_count, topic_details = topic_classifier_obj.predict( | ||
req.data | ||
) | ||
doc_info.topics = topics | ||
doc_info.topicCount = topic_count | ||
doc_info.topicDetails = topic_details | ||
return doc_info | ||
except (KeyError, ValueError, RuntimeError) as e: | ||
logger.error(f"Failed to get classifier response: {e}") | ||
return doc_info | ||
except Exception as e: | ||
logger.error(f"Unexpected error:{e}\n{traceback.format_exc()}") | ||
return doc_info | ||
|
||
def process_request(self): | ||
""" | ||
Processes the user request for classification and returns a structured response. | ||
Returns: | ||
PebbloJsonResponse: The response object containing classification results or error details. | ||
""" | ||
try: | ||
req = ReqClassifier.model_validate(self.input) | ||
if not req.data: | ||
return PebbloJsonResponse.build( | ||
body={"error": "Input data is missing"}, status_code=400 | ||
) | ||
doc_info = self._get_classifier_response(req) | ||
return PebbloJsonResponse.build( | ||
body=doc_info.model_dump(exclude_none=True), status_code=200 | ||
) | ||
except ValidationError as e: | ||
logger.error( | ||
f"Validation error in Classification API process_request:{e}\n{traceback.format_exc()}" | ||
) | ||
return PebbloJsonResponse.build( | ||
body={"error": f"Validation error: {e}"}, status_code=400 | ||
) | ||
except Exception: | ||
response = AiDataModel( | ||
data=None, | ||
entities={}, | ||
entityCount=0, | ||
topics={}, | ||
topicCount=0, | ||
topicDetails={}, | ||
) | ||
logger.error( | ||
f"Error in Classification API process_request: {traceback.format_exc()}" | ||
) | ||
return PebbloJsonResponse.build( | ||
body=response.model_dump(exclude_none=True), status_code=500 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# test_prompt_gov.py | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from pebblo.app.libs.responses import PebbloJsonResponse | ||
from pebblo.app.models.models import AiDataModel | ||
from pebblo.app.service.classification import Classification | ||
|
||
|
||
@pytest.fixture | ||
def mock_entity_classifier(): | ||
with patch("pebblo.app.service.classification.EntityClassifier") as mock: | ||
yield mock | ||
|
||
|
||
@pytest.fixture | ||
def mock_topic_classifier(): | ||
with patch("pebblo.app.service.classification.TopicClassifier") as mock: | ||
yield mock | ||
|
||
|
||
def test_process_request_success(mock_entity_classifier, mock_topic_classifier): | ||
mock_entity_classifier_instance = mock_entity_classifier.return_value | ||
mock_entity_classifier_instance.presidio_entity_classifier_and_anonymizer.return_value = ( | ||
{"us-ssn": 1}, | ||
1, | ||
"anonymized document", | ||
{ | ||
"us-ssn": [ | ||
{ | ||
"location": "16_27", | ||
"confidence_score": "HIGH", | ||
"entity_group": "pii-identification", | ||
} | ||
] | ||
}, | ||
) | ||
|
||
mock_topic_classifier_instance = mock_topic_classifier.return_value | ||
mock_topic_classifier_instance.predict.return_value = ({}, 0, {}) | ||
|
||
data = {"data": "Sachin's SSN is 222-85-4836"} | ||
cls_obj = Classification(data) | ||
response = cls_obj.process_request() | ||
expected_response = AiDataModel( | ||
data="", | ||
entities={"us-ssn": 1}, | ||
entityCount=1, | ||
entityDetails={ | ||
"us-ssn": [ | ||
{ | ||
"location": "16_27", | ||
"confidence_score": "HIGH", | ||
"entity_group": "pii-identification", | ||
} | ||
] | ||
}, | ||
topics={}, | ||
topicCount=0, | ||
topicDetails={}, | ||
) | ||
expected_response = PebbloJsonResponse.build( | ||
body=expected_response.model_dump(exclude_none=True), status_code=200 | ||
) | ||
|
||
assert response.status_code == expected_response.status_code | ||
assert response.body == expected_response.body |