Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added changes for new api for entity and topic classification #518

Merged
merged 10 commits into from
Sep 10, 2024
21 changes: 21 additions & 0 deletions pebblo/app/api/v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from fastapi import APIRouter

from pebblo.app.config.config import var_server_config_dict
from pebblo.app.service.classification import Classification

config_details = var_server_config_dict.get()


class APIv1:
"""
Controller Class for all the api endpoints for App resource.
"""

def __init__(self, prefix: str):
self.router = APIRouter(prefix=prefix)

@staticmethod
def classify_data(data: dict):
cls_obj = Classification(data)
response = cls_obj.process_request()
return response
3 changes: 2 additions & 1 deletion pebblo/app/config/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pebblo.app.routers.redirection_router import redirect_router_instance

with redirect_stdout(StringIO()), redirect_stderr(StringIO()):
from pebblo.app.routers.routers import router_instance
from pebblo.app.routers.routers import api_v1_router_instance, router_instance
from pebblo.log import get_logger, get_uvicorn_logconfig

logger = get_logger(__name__)
Expand All @@ -42,6 +42,7 @@ def __init__(self, config_details):
self.app = FastAPI(exception_handlers=exception_handlers)
# Register the router instance with the main app
self.app.include_router(router_instance.router)
self.app.include_router(api_v1_router_instance.router)
self.app.include_router(local_ui_router_instance.router)
self.app.include_router(redirect_router_instance.router)
# Adding cors
Expand Down
4 changes: 2 additions & 2 deletions pebblo/app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class LoaderMetadata(BaseModel):

class AiDataModel(BaseModel):
data: Optional[Union[list, str]] = None
entityCount: int
entityCount: Optional[int] = 0
entities: dict
entityDetails: Optional[dict] = {}
topicCount: Optional[int] = None
topicCount: Optional[int] = 0
topics: Optional[dict] = {}
topicDetails: Optional[dict] = {}

Expand Down
10 changes: 10 additions & 0 deletions pebblo/app/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""

from pebblo.app.api.api import App
from pebblo.app.api.v1 import APIv1

# Create an instance of APp with a specific prefix
router_instance = App(prefix="/v1")
api_v1_router_instance = APIv1(prefix="/api/v1")
srics marked this conversation as resolved.
Show resolved Hide resolved

# Add routes to the class-based router
router_instance.router.add_api_route(
Expand All @@ -32,3 +34,11 @@
response_model=dict,
response_model_exclude_none=True,
)

api_v1_router_instance.router.add_api_route(
"/classify",
APIv1.classify_data,
methods=["POST"],
response_model=dict,
response_model_exclude_none=True,
)
134 changes: 134 additions & 0 deletions pebblo/app/service/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Module for text classification
"""

import traceback
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field, ValidationError

from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.models import AiDataModel
from pebblo.entity_classifier.entity_classifier import EntityClassifier
from pebblo.log import get_logger
from pebblo.topic_classifier.topic_classifier import TopicClassifier


class ClassificationMode(Enum):
ENTITY = "entity"
TOPIC = "topic"
ALL = "all"


class ReqClassifier(BaseModel):
data: str
mode: Optional[ClassificationMode] = Field(default=ClassificationMode.ALL)
shreyas-damle marked this conversation as resolved.
Show resolved Hide resolved
anonymize: Optional[bool] = Field(default=False)

class Config:
extra = "forbid"


logger = get_logger(__name__)
topic_classifier_obj = TopicClassifier()
entity_classifier_obj = EntityClassifier()


class Classification:
"""
Classification wrapper class for Entity and Semantic classification with anonymization
"""

def __init__(self, input: dict):
self.input = input

def _get_classifier_response(self, req: ReqClassifier):
"""
Processes the input prompt through the entity classifier and anonymizer, and returns
the resulting information encapsulated in an AiDataModel object.

Returns:
AiDataModel: An object containing the anonymized document, entities, and their counts.
"""
doc_info = AiDataModel(
data=None,
entities={},
entityCount=0,
entityDetails={},
topics={},
topicCount=0,
topicDetails={},
)
try:
# Process entity classification
if req.mode in [ClassificationMode.ENTITY, ClassificationMode.ALL]:
(
entities,
entity_count,
anonymized_doc,
entity_details,
) = entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
req.data,
anonymize_snippets=req.anonymize,
)
doc_info.entities = entities
doc_info.entityCount = entity_count
doc_info.entityDetails = entity_details
doc_info.data = anonymized_doc if req.anonymize else ""

# Process topic classification
if req.mode in [ClassificationMode.TOPIC, ClassificationMode.ALL]:
topics, topic_count, topic_details = topic_classifier_obj.predict(
req.data
)
doc_info.topics = topics
doc_info.topicCount = topic_count
doc_info.topicDetails = topic_details
return doc_info
except (KeyError, ValueError, RuntimeError) as e:
logger.error(f"Failed to get classifier response: {e}")
return doc_info
except Exception as e:
logger.error(f"Unexpected error:{e}\n{traceback.format_exc()}")
return doc_info

def process_request(self):
"""
Processes the user request for classification and returns a structured response.

Returns:
PebbloJsonResponse: The response object containing classification results or error details.
"""
try:
req = ReqClassifier.model_validate(self.input)
if not req.data:
return PebbloJsonResponse.build(
body={"error": "Input data is missing"}, status_code=400
)
doc_info = self._get_classifier_response(req)
return PebbloJsonResponse.build(
body=doc_info.model_dump(exclude_none=True), status_code=200
)
except ValidationError as e:
logger.error(
f"Validation error in Classification API process_request:{e}\n{traceback.format_exc()}"
)
return PebbloJsonResponse.build(
body={"error": f"Validation error: {e}"}, status_code=400
)
except Exception:
response = AiDataModel(
data=None,
entities={},
entityCount=0,
topics={},
topicCount=0,
topicDetails={},
)
logger.error(
f"Error in Classification API process_request: {traceback.format_exc()}"
)
return PebbloJsonResponse.build(
body=response.model_dump(exclude_none=True), status_code=500
)
68 changes: 68 additions & 0 deletions tests/app/service/test_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# test_prompt_gov.py
from unittest.mock import patch

import pytest

from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.models import AiDataModel
from pebblo.app.service.classification import Classification


@pytest.fixture
def mock_entity_classifier():
with patch("pebblo.app.service.classification.EntityClassifier") as mock:
yield mock


@pytest.fixture
def mock_topic_classifier():
with patch("pebblo.app.service.classification.TopicClassifier") as mock:
yield mock


def test_process_request_success(mock_entity_classifier, mock_topic_classifier):
mock_entity_classifier_instance = mock_entity_classifier.return_value
mock_entity_classifier_instance.presidio_entity_classifier_and_anonymizer.return_value = (
{"us-ssn": 1},
1,
"anonymized document",
{
"us-ssn": [
{
"location": "16_27",
"confidence_score": "HIGH",
"entity_group": "pii-identification",
}
]
},
)

mock_topic_classifier_instance = mock_topic_classifier.return_value
mock_topic_classifier_instance.predict.return_value = ({}, 0, {})

data = {"data": "Sachin's SSN is 222-85-4836"}
cls_obj = Classification(data)
response = cls_obj.process_request()
expected_response = AiDataModel(
data="",
entities={"us-ssn": 1},
entityCount=1,
entityDetails={
"us-ssn": [
{
"location": "16_27",
"confidence_score": "HIGH",
"entity_group": "pii-identification",
}
]
},
topics={},
topicCount=0,
topicDetails={},
)
expected_response = PebbloJsonResponse.build(
body=expected_response.model_dump(exclude_none=True), status_code=200
)

assert response.status_code == expected_response.status_code
assert response.body == expected_response.body