Skip to content

Commit

Permalink
added changes for new api for entity and topic classification (#518)
Browse files Browse the repository at this point in the history
* added changes for new api for entity and topic classification
* added model dump changes
* Move classify api under /api/v1 router
* Use field name data instead of inputs
* Add classifier mode for entity-only, topic-only and all the above


---------

Co-authored-by: Sridhar Ramaswamy <[email protected]>
  • Loading branch information
gr8nishan and sridhar-daxa authored Sep 10, 2024
1 parent f6b795f commit df01788
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 3 deletions.
21 changes: 21 additions & 0 deletions pebblo/app/api/v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from fastapi import APIRouter

from pebblo.app.config.config import var_server_config_dict
from pebblo.app.service.classification import Classification

# Server configuration, read once at import time via var_server_config_dict.get().
# NOTE(review): not referenced anywhere else in the visible part of this module —
# confirm whether it is still needed here.
config_details = var_server_config_dict.get()


class APIv1:
    """
    Class-based controller that owns the router for the versioned API endpoints.

    The instance only holds an ``APIRouter`` built with the given prefix; the
    endpoint handlers are static methods that get attached to that router by
    the routing module.
    """

    def __init__(self, prefix: str):
        # Every route added to this router is served under ``prefix``.
        self.router = APIRouter(prefix=prefix)

    @staticmethod
    def classify_data(data: dict):
        """
        Handle a classification request.

        Args:
            data: Raw request body; it is validated by the Classification service.

        Returns:
            Whatever ``Classification.process_request()`` returns for this body.
        """
        return Classification(data).process_request()
3 changes: 2 additions & 1 deletion pebblo/app/config/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pebblo.app.routers.redirection_router import redirect_router_instance

with redirect_stdout(StringIO()), redirect_stderr(StringIO()):
from pebblo.app.routers.routers import router_instance
from pebblo.app.routers.routers import api_v1_router_instance, router_instance
from pebblo.log import get_logger, get_uvicorn_logconfig

logger = get_logger(__name__)
Expand All @@ -42,6 +42,7 @@ def __init__(self, config_details):
self.app = FastAPI(exception_handlers=exception_handlers)
# Register the router instance with the main app
self.app.include_router(router_instance.router)
self.app.include_router(api_v1_router_instance.router)
self.app.include_router(local_ui_router_instance.router)
self.app.include_router(redirect_router_instance.router)
# Adding cors
Expand Down
4 changes: 2 additions & 2 deletions pebblo/app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class LoaderMetadata(BaseModel):

class AiDataModel(BaseModel):
data: Optional[Union[list, str]] = None
entityCount: int
entityCount: Optional[int] = 0
entities: dict
entityDetails: Optional[dict] = {}
topicCount: Optional[int] = None
topicCount: Optional[int] = 0
topics: Optional[dict] = {}
topicDetails: Optional[dict] = {}

Expand Down
10 changes: 10 additions & 0 deletions pebblo/app/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""

from pebblo.app.api.api import App
from pebblo.app.api.v1 import APIv1

# Create the class-based router instances, each with its own URL prefix
router_instance = App(prefix="/v1")
api_v1_router_instance = APIv1(prefix="/api/v1")

# Add routes to the class-based router
router_instance.router.add_api_route(
Expand All @@ -32,3 +34,11 @@
response_model=dict,
response_model_exclude_none=True,
)

# Register APIv1.classify_data as the POST handler for "/classify" on the
# APIv1 router (served under that router's prefix). Fields whose value is None
# are dropped from the serialized response (response_model_exclude_none=True).
api_v1_router_instance.router.add_api_route(
"/classify",
APIv1.classify_data,
methods=["POST"],
response_model=dict,
response_model_exclude_none=True,
)
134 changes: 134 additions & 0 deletions pebblo/app/service/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Module for text classification
"""

import traceback
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field, ValidationError

from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.models import AiDataModel
from pebblo.entity_classifier.entity_classifier import EntityClassifier
from pebblo.log import get_logger
from pebblo.topic_classifier.topic_classifier import TopicClassifier


class ClassificationMode(Enum):
    """
    Selects which classifiers are run for a classification request.
    """

    # Run the entity classifier only.
    ENTITY = "entity"
    # Run the topic classifier only.
    TOPIC = "topic"
    # Run both the entity and the topic classifier.
    ALL = "all"


class ReqClassifier(BaseModel):
    """
    Request payload accepted by the classification service.

    Unknown fields are rejected, so a misspelled key results in a validation
    error instead of being silently ignored.
    """

    # Pydantic v2 style configuration. The nested ``class Config`` form is
    # deprecated in v2 (this module already relies on the v2 API through
    # ``model_validate`` / ``model_dump``) and emits a deprecation warning.
    # ``model_config`` accepts a plain dict, so no extra import is required.
    model_config = {"extra": "forbid"}

    # Text to classify.
    data: str
    # Which classifiers to run; both are run when the field is omitted.
    mode: Optional[ClassificationMode] = Field(default=ClassificationMode.ALL)
    # Forwarded to the entity classifier as ``anonymize_snippets``; also decides
    # whether the anonymized text is returned in the response.
    anonymize: Optional[bool] = Field(default=False)


logger = get_logger(__name__)
# Classifier instances are created once, when this module is imported, and are
# shared by every Classification object. Because they are bound at import time,
# replacing the TopicClassifier / EntityClassifier names afterwards does not
# affect these objects.
# NOTE(review): presumably construction is expensive (model loading) — confirm.
topic_classifier_obj = TopicClassifier()
entity_classifier_obj = EntityClassifier()


class Classification:
    """
    Service wrapper that validates a classification request and runs the entity
    and/or topic classifier on its text, optionally returning anonymized text.
    """

    def __init__(self, input: dict):
        # Raw request body; it is validated later, in process_request().
        self.input = input

    def _get_classifier_response(self, req: ReqClassifier):
        """
        Run the classifiers selected by ``req.mode`` on ``req.data``.

        This is best-effort: if a classifier raises, the error is logged and the
        result built so far (possibly still the empty default) is returned.

        Args:
            req: The validated request.

        Returns:
            AiDataModel: Entities, topics, their counts and details. ``data``
            holds the anonymized text when entity classification ran with
            ``req.anonymize`` set, an empty string when it ran without it, and
            stays ``None`` when entity classification did not run.
        """
        result = AiDataModel(
            data=None,
            entities={},
            entityCount=0,
            entityDetails={},
            topics={},
            topicCount=0,
            topicDetails={},
        )
        entity_modes = [ClassificationMode.ENTITY, ClassificationMode.ALL]
        topic_modes = [ClassificationMode.TOPIC, ClassificationMode.ALL]
        try:
            # Entity classification (and anonymization, when requested).
            if req.mode in entity_modes:
                entity_output = (
                    entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
                        req.data,
                        anonymize_snippets=req.anonymize,
                    )
                )
                entities, entity_count, anonymized_doc, entity_details = entity_output
                result.entities = entities
                result.entityCount = entity_count
                result.entityDetails = entity_details
                if req.anonymize:
                    result.data = anonymized_doc
                else:
                    result.data = ""

            # Topic classification.
            if req.mode in topic_modes:
                topic_output = topic_classifier_obj.predict(req.data)
                topics, topic_count, topic_details = topic_output
                result.topics = topics
                result.topicCount = topic_count
                result.topicDetails = topic_details
        except (KeyError, ValueError, RuntimeError) as e:
            logger.error(f"Failed to get classifier response: {e}")
        except Exception as e:
            logger.error(f"Unexpected error:{e}\n{traceback.format_exc()}")
        # Reached on success and after a logged failure alike.
        return result

    def process_request(self):
        """
        Validate the stored request body, classify it and build the response.

        Returns:
            PebbloJsonResponse: Status 200 with the classification result,
            status 400 when the body fails validation or ``data`` is empty, or
            status 500 with an empty result on any other error.
        """
        try:
            req = ReqClassifier.model_validate(self.input)
            if not req.data:
                missing_body = {"error": "Input data is missing"}
                return PebbloJsonResponse.build(body=missing_body, status_code=400)

            classified = self._get_classifier_response(req)
            payload = classified.model_dump(exclude_none=True)
            return PebbloJsonResponse.build(body=payload, status_code=200)
        except ValidationError as e:
            logger.error(
                f"Validation error in Classification API process_request:{e}\n{traceback.format_exc()}"
            )
            invalid_body = {"error": f"Validation error: {e}"}
            return PebbloJsonResponse.build(body=invalid_body, status_code=400)
        except Exception:
            logger.error(
                f"Error in Classification API process_request: {traceback.format_exc()}"
            )
            # Fall back to an empty result so the error body keeps the usual shape.
            fallback = AiDataModel(
                data=None,
                entities={},
                entityCount=0,
                topics={},
                topicCount=0,
                topicDetails={},
            )
            return PebbloJsonResponse.build(
                body=fallback.model_dump(exclude_none=True), status_code=500
            )
68 changes: 68 additions & 0 deletions tests/app/service/test_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# test_classification.py
from unittest.mock import patch

import pytest

from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.models import AiDataModel
from pebblo.app.service.classification import Classification


@pytest.fixture
def mock_entity_classifier():
    """
    Replace the entity classifier used by the classification service with a mock.

    The service module creates ``entity_classifier_obj`` once, at import time,
    so patching only the ``EntityClassifier`` class leaves the real instance in
    use. The module-level instance is therefore patched as well, with the mock
    class's ``return_value``, so stubs configured through
    ``mock_entity_classifier.return_value`` are the ones the service calls.

    Yields:
        The mock standing in for the ``EntityClassifier`` class.
    """
    with patch("pebblo.app.service.classification.EntityClassifier") as mock:
        with patch(
            "pebblo.app.service.classification.entity_classifier_obj",
            mock.return_value,
        ):
            yield mock


@pytest.fixture
def mock_topic_classifier():
    """
    Replace the topic classifier used by the classification service with a mock.

    The service module creates ``topic_classifier_obj`` once, at import time,
    so patching only the ``TopicClassifier`` class leaves the real instance in
    use. The module-level instance is therefore patched as well, with the mock
    class's ``return_value``, so stubs configured through
    ``mock_topic_classifier.return_value`` are the ones the service calls.

    Yields:
        The mock standing in for the ``TopicClassifier`` class.
    """
    with patch("pebblo.app.service.classification.TopicClassifier") as mock:
        with patch(
            "pebblo.app.service.classification.topic_classifier_obj",
            mock.return_value,
        ):
            yield mock


def test_process_request_success(mock_entity_classifier, mock_topic_classifier):
    """
    A request containing an SSN yields a 200 response with one ``us-ssn`` entity.

    Stubs are configured on the ``return_value`` of both classifier fixtures;
    the request uses the default mode and does not ask for anonymization, so
    the expected ``data`` field is an empty string.
    """
    ssn_details = {
        "us-ssn": [
            {
                "location": "16_27",
                "confidence_score": "HIGH",
                "entity_group": "pii-identification",
            }
        ]
    }

    # Arrange: canned classifier outputs.
    entity_stub = mock_entity_classifier.return_value
    entity_stub.presidio_entity_classifier_and_anonymizer.return_value = (
        {"us-ssn": 1},
        1,
        "anonymized document",
        ssn_details,
    )
    topic_stub = mock_topic_classifier.return_value
    topic_stub.predict.return_value = ({}, 0, {})

    # Act
    request_body = {"data": "Sachin's SSN is 222-85-4836"}
    response = Classification(request_body).process_request()

    # Assert: same status and serialized body as the expected model.
    expected_model = AiDataModel(
        data="",
        entities={"us-ssn": 1},
        entityCount=1,
        entityDetails=ssn_details,
        topics={},
        topicCount=0,
        topicDetails={},
    )
    expected_response = PebbloJsonResponse.build(
        body=expected_model.model_dump(exclude_none=True), status_code=200
    )

    assert response.status_code == expected_response.status_code
    assert response.body == expected_response.body

0 comments on commit df01788

Please sign in to comment.