Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swagger changes in pebblo server APIs #530

Merged
merged 5 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pebblo/app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter, Depends

from pebblo.app.api.req_models import ReqDiscover, ReqLoaderDoc, ReqPrompt, ReqPromptGov
from pebblo.app.config.config import var_server_config_dict
from pebblo.app.service.prompt_gov import PromptGov
from pebblo.app.utils.handler_mapper import get_handler
Expand All @@ -17,34 +18,36 @@ def __init__(self, prefix: str):

@staticmethod
def discover(
data: dict, discover_obj=Depends(lambda: get_handler(handler_name="discover"))
data: ReqDiscover,
discover_obj=Depends(lambda: get_handler(handler_name="discover")),
):
# "/app/discover" API entrypoint
# Execute discover object based on a storage type
response = discover_obj.process_request(data)
response = discover_obj.process_request(data.model_dump())
return response

@staticmethod
def loader_doc(
data: dict, loader_doc_obj=Depends(lambda: get_handler(handler_name="loader"))
data: ReqLoaderDoc,
loader_doc_obj=Depends(lambda: get_handler(handler_name="loader")),
):
# "/loader/doc" API entrypoint
# Execute loader doc object based on a storage type
response = loader_doc_obj.process_request(data)
response = loader_doc_obj.process_request(data.model_dump())
return response

@staticmethod
def prompt(
data: dict, prompt_obj=Depends(lambda: get_handler(handler_name="prompt"))
data: ReqPrompt, prompt_obj=Depends(lambda: get_handler(handler_name="prompt"))
):
# "/prompt" API entrypoint
# Execute a prompt object based on a storage type
response = prompt_obj.process_request(data)
response = prompt_obj.process_request(data.model_dump())
return response

@staticmethod
def promptgov(data: dict):
def promptgov(data: ReqPromptGov):
# "/prompt/governance" API entrypoint
prompt_obj = PromptGov(data=data)
prompt_obj = PromptGov(data=data.model_dump())
response = prompt_obj.process_request()
return response
94 changes: 94 additions & 0 deletions pebblo/app/api/req_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""API Request Model Class"""

from typing import List, Optional, Union

from pydantic import BaseModel


class Runtime(BaseModel):
type: str = "local"
host: str
path: str
ip: Optional[str] = None
platform: str
os: str
os_version: str
language: str
language_version: str
runtime: str = "local"


class Framework(BaseModel):
name: str
version: str


class VectorDB(BaseModel):
name: Optional[str] = None
version: Optional[str] = None
location: Optional[str] = None
embedding_model: Optional[str] = None


class Model(BaseModel):
vendor: Optional[str] = None
name: Optional[str] = None


class ChainInfo(BaseModel):
name: str
model: Optional[Model] = None
vector_dbs: Optional[List[VectorDB]] = None


class ReqDiscover(BaseModel):
name: str
owner: str
description: Optional[str] = None
load_id: Optional[str] = None
runtime: Runtime
framework: Framework
chains: Optional[List[ChainInfo]] = None
plugin_version: str
client_version: Framework


class ReqLoaderDoc(BaseModel):
name: str
owner: str
docs: list[dict] = None
plugin_version: str
load_id: str
loader_details: dict
loading_end: bool
source_owner: str
classifier_location: str


class Context(BaseModel):
retrieved_from: Optional[str] = None
doc: Optional[str] = None
vector_db: str
pb_checksum: Optional[str] = None


class Prompt(BaseModel):
data: Optional[Union[list, str]] = None
entityCount: Optional[int] = None
entities: Optional[dict] = None
prompt_gov_enabled: Optional[bool] = None


class ReqPrompt(BaseModel):
name: str
context: Optional[List[Context]] = None
prompt: Optional[Prompt] = None
response: Optional[Prompt] = None
prompt_time: str
user: str
user_identities: Optional[List[str]] = None
classifier_location: str


class ReqPromptGov(BaseModel):
prompt: str
55 changes: 28 additions & 27 deletions pebblo/app/service/discovery_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,35 +124,36 @@ def _fetch_chain_details(self, app_metadata) -> list[Chain]:
logger.debug(f"Existing Chains : {chains}")

logger.debug(f"Input chains : {self.data.get('chains', [])}")
for chain in self.data.get("chains", []):
name = chain["name"]
model = chain["model"]
# vector db details
vector_db_details = []
for vector_db in chain.get("vector_dbs", []):
vector_db_obj = VectorDB(
name=vector_db.get("name"),
version=vector_db.get("version"),
location=vector_db.get("location"),
embeddingModel=vector_db.get("embedding_model"),
pkgInfo=None,
)

package_info = vector_db.get("pkg_info")
if package_info:
pkg_info_obj = PackageInfo(
projectHomePage=package_info.get("project_home_page"),
documentationUrl=package_info.get("documentation_url"),
pypiUrl=package_info.get("pypi_url"),
licenceType=package_info.get("licence_type"),
installedVia=package_info.get("installed_via"),
location=package_info.get("location"),
if self.data.get("chains") not in [None, []]:
for chain in self.data.get("chains", []):
name = chain["name"]
model = chain["model"]
# vector db details
vector_db_details = []
for vector_db in chain.get("vector_dbs", []):
vector_db_obj = VectorDB(
name=vector_db.get("name"),
version=vector_db.get("version"),
location=vector_db.get("location"),
embeddingModel=vector_db.get("embedding_model"),
pkgInfo=None,
)
vector_db_obj.pkgInfo = pkg_info_obj

vector_db_details.append(vector_db_obj)
chain_obj = Chain(name=name, model=model, vectorDbs=vector_db_details)
chains.append(chain_obj.model_dump())
package_info = vector_db.get("pkg_info")
if package_info:
pkg_info_obj = PackageInfo(
projectHomePage=package_info.get("project_home_page"),
documentationUrl=package_info.get("documentation_url"),
pypiUrl=package_info.get("pypi_url"),
licenceType=package_info.get("licence_type"),
installedVia=package_info.get("installed_via"),
location=package_info.get("location"),
)
vector_db_obj.pkgInfo = pkg_info_obj

vector_db_details.append(vector_db_obj)
chain_obj = Chain(name=name, model=model, vectorDbs=vector_db_details)
chains.append(chain_obj.model_dump())

logger.debug(f"Application Name [{self.application_name}]: Chains: {chains}")
return chains
Expand Down
55 changes: 33 additions & 22 deletions tests/app/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@
client = TestClient(app)


app_discover_payload = {
"name": "Test App",
"owner": "Test owner",
"description": "This is a test app.",
"runtime": {
"type": "desktop",
"host": "MacBook-Pro.local",
"path": "Test/Path",
"ip": "127.0.0.1",
"platform": "macOS-14.6.1-arm64-i386-64bit",
"os": "Darwin",
"os_version": "Darwin Kernel Version 23.6.0",
"language": "python",
"language_version": "3.11.9",
"runtime": "Mac OSX",
},
"framework": {"name": "langchain", "version": "0.2.35"},
"plugin_version": "0.1",
"client_version": {"name": "langchain_community", "version": "0.2.12"},
}


@pytest.fixture(scope="module")
def mocked_objects():
with (
Expand Down Expand Up @@ -96,13 +118,7 @@ def test_app_discover_success(mock_write_json_to_file, mock_pebblo_server_versio
Test the app discover endpoint.
"""
mock_write_json_to_file.return_value = None
app_payload = {
"name": "Test App",
"owner": "Test owner",
"description": "This is a test app.",
"plugin_version": "0.1",
}
response = client.post("/v1/app/discover", json=app_payload)
response = client.post("/v1/app/discover", json=app_discover_payload)

# Assertions
assert response.status_code == 200
Expand All @@ -115,28 +131,22 @@ def test_app_discover_validation_errors(mock_write_json_to_file):
Test the app discover endpoint with validation errors.
"""
mock_write_json_to_file.return_value = None
app = {
"owner": "Test owner",
"description": "This is a test app.",
"plugin_version": "0.1",
}
response = client.post("/v1/app/discover", json=app)
assert response.status_code == 400
assert "1 validation error for AiApp" in response.json()["message"]
app_payload = app_discover_payload.copy()
app_payload.pop("name")

response = client.post("/v1/app/discover", json=app_payload)
assert response.status_code == 422
assert "'type': 'missing', 'loc': ['body', 'name'], 'msg': 'Field required'" in str(
response.json()["detail"]
)


def test_app_discover_server_error(mock_write_json_to_file):
"""
Test the app discover endpoint with server error.
"""
mock_write_json_to_file.side_effect = Exception("Mocked exception")
app_payload = {
"name": "Test App",
"owner": "Test owner",
"description": "This is a test app.",
"plugin_version": "0.1",
}
response = client.post("/v1/app/discover", json=app_payload)
response = client.post("/v1/app/discover", json=app_discover_payload)

# Assertions
assert response.status_code == 500
Expand Down Expand Up @@ -186,6 +196,7 @@ def test_loader_doc_success(
"source_aggr_size": 306,
},
"plugin_version": "0.1.0",
"classifier_location": "local",
}
response = client.post("/v1/loader/doc", json=loader_doc)
assert response.status_code == 200
Expand Down
Loading