Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sridhar-daxa committed Sep 9, 2024
1 parent a5b1264 commit 87e77d6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
4 changes: 2 additions & 2 deletions pebblo/app/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from pebblo.app.config.config import var_server_config_dict
from pebblo.app.service.classification import Classification


config_details = var_server_config_dict.get()


class APIv1:
"""
Controller Class for all the api endpoints for App resource.
Expand All @@ -18,4 +18,4 @@ def __init__(self, prefix: str):
def classify_data(data: dict):
cls_obj = Classification(data)
response = cls_obj.process_request()
return response
return response
26 changes: 16 additions & 10 deletions pebblo/app/service/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
"""

import traceback

from pydantic import ValidationError
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from enum import Enum
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationError

from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.models import AiDataModel
Expand All @@ -16,18 +15,19 @@
from pebblo.topic_classifier.topic_classifier import TopicClassifier



class ClassificationMode(Enum):
ENTITY = "entity"
TOPIC = "topic"
ALL = "all"


class ReqClassifier(BaseModel):
data: str
mode: Optional[ClassificationMode] = Field(default=ClassificationMode.ALL)
anonymize: Optional[bool] = Field(default=False)

model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")


logger = get_logger(__name__)
topic_classifier_obj = TopicClassifier()
Expand All @@ -38,11 +38,11 @@ class Classification:
Class for loader doc related task
"""

def __init__(self, input:dict):
def __init__(self, input: dict):
self.input = input
self.entity_classifier_obj = EntityClassifier()

def _get_classifier_response(self, req:ReqClassifier):
def _get_classifier_response(self, req: ReqClassifier):
"""
Processes the input prompt through the entity classifier and anonymizer, and returns
the resulting information encapsulated in an AiDataModel object.
Expand All @@ -60,7 +60,10 @@ def _get_classifier_response(self, req:ReqClassifier):
topicDetails={},
)
try:
if req.mode == ClassificationMode.ENTITY or req.mode == ClassificationMode.ALL:
if (
req.mode == ClassificationMode.ENTITY
or req.mode == ClassificationMode.ALL
):
(
entities,
entity_count,
Expand All @@ -77,7 +80,10 @@ def _get_classifier_response(self, req:ReqClassifier):
doc_info.data = anonymized_doc
else:
doc_info.data = ""
if req.mode == ClassificationMode.TOPIC or req.mode == ClassificationMode.ALL:
if (
req.mode == ClassificationMode.TOPIC
or req.mode == ClassificationMode.ALL
):
topics, topic_count, topic_details = topic_classifier_obj.predict(
req.data
)
Expand Down

0 comments on commit 87e77d6

Please sign in to comment.