diff --git a/.env.example b/.env.example index c5e00580..35a872cb 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,4 @@ API_BASE_URL=https://rag.superagent.sh COHERE_API_KEY= -HUGGINGFACE_API_KEY= \ No newline at end of file +HUGGINGFACE_API_KEY= +JWT_SECRET= \ No newline at end of file diff --git a/api/delete.py b/api/delete.py index 731bbbbe..dfa57b39 100644 --- a/api/delete.py +++ b/api/delete.py @@ -1,12 +1,13 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends from models.delete import RequestPayload, ResponsePayload from service.vector_database import get_vector_service, VectorService +from auth.user import get_current_api_user router = APIRouter() @router.post("/delete", response_model=ResponsePayload) -async def delete(payload: RequestPayload): +async def delete(payload: RequestPayload, _api_user=Depends(get_current_api_user)): vector_service: VectorService = get_vector_service( index_name=payload.index_name, credentials=payload.vector_database ) diff --git a/api/ingest.py b/api/ingest.py index 43ade6b3..e984072f 100644 --- a/api/ingest.py +++ b/api/ingest.py @@ -1,13 +1,16 @@ from typing import Dict -from fastapi import APIRouter +from fastapi import APIRouter, Depends from models.ingest import RequestPayload from service.embedding import EmbeddingService +from auth.user import get_current_api_user router = APIRouter() @router.post("/ingest") -async def ingest(payload: RequestPayload) -> Dict: +async def ingest( + payload: RequestPayload, _api_user=Depends(get_current_api_user) +) -> Dict: embedding_service = EmbeddingService( files=payload.files, index_name=payload.index_name, diff --git a/api/query.py b/api/query.py index a1e9c8aa..05c8460e 100644 --- a/api/query.py +++ b/api/query.py @@ -1,12 +1,13 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends from models.query import RequestPayload, ResponsePayload from service.vector_database import get_vector_service, VectorService +from auth.user import get_current_api_user router = APIRouter() @router.post("/query", response_model=ResponsePayload) -async def query(payload: RequestPayload): +async def query(payload: RequestPayload, _api_user=Depends(get_current_api_user)): vector_service: VectorService = get_vector_service( index_name=payload.index_name, credentials=payload.vector_database ) diff --git a/auth/__init__.py b/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/auth/user.py b/auth/user.py new file mode 100644 index 00000000..fe960266 --- /dev/null +++ b/auth/user.py @@ -0,0 +1,33 @@ +import logging +import jwt + +from decouple import config +from fastapi import HTTPException, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from superagent.client import AsyncSuperagent + +logger = logging.getLogger(__name__) +security = HTTPBearer() + + +def generate_jwt(data: dict): + token = jwt.encode({**data}, config("JWT_SECRET"), algorithm="HS256") + return token + + +def decode_jwt(token: str): + return jwt.decode(token, config("JWT_SECRET"), algorithms=["HS256"]) + + +async def get_current_api_user( + authorization: HTTPAuthorizationCredentials = Security(security), +): + token = authorization.credentials + decoded_token = decode_jwt(token) + superagent = AsyncSuperagent( + base_url="https://api.beta.superagent.sh", token=decoded_token + ) + api_user = superagent.api_user.get() + if not api_user: + raise HTTPException(status_code=401, detail="Invalid token or expired token") + return api_user diff --git a/requirements.txt b/requirements.txt index f18174fd..13e2aef7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,8 +45,10 @@ joblib==1.3.2 litellm==1.17.5 llama-index==0.9.30 loguru==0.7.2 +lxml==5.1.0 MarkupSafe==2.1.3 marshmallow==3.20.2 +mpmath==1.3.0 multidict==6.0.4 mypy-extensions==1.0.0 nest-asyncio==1.5.8 @@ -57,34 +59,41 @@ openai==1.7.2 packaging==23.2 pandas==2.1.4 pathspec==0.12.1 +pillow==10.2.0 pinecone-client==3.0.0 platformdirs==4.1.0 portalocker==2.8.2 protobuf==4.25.2 pycparser==2.21 -pydantic==2.5.3 -pydantic_core==2.14.6 +pydantic==2.4.2 +pydantic_core==2.10.1 +PyJWT==2.8.0 pypdf==3.17.4 python-dateutil==2.8.2 python-decouple==3.8 python-dotenv==1.0.0 +python-pptx==0.6.23 pytz==2023.3.post1 PyYAML==6.0.1 qdrant-client==1.7.0 regex==2023.12.25 requests==2.31.0 ruff==0.1.13 -setuptools==69.0.3 +safetensors==0.4.1 six==1.16.0 sniffio==1.3.0 soupsieve==2.5 SQLAlchemy==2.0.25 starlette==0.35.1 +superagent-py==0.1.55 +sympy==1.12 tenacity==8.2.3 tiktoken==0.5.2 tokenizers==0.15.0 toml==0.10.2 +torch==2.1.2 tqdm==4.66.1 +transformers==4.36.2 typing-inspect==0.9.0 typing_extensions==4.9.0 tzdata==2023.4 @@ -97,5 +106,6 @@ watchfiles==0.21.0 weaviate-client==3.26.0 websockets==12.0 wrapt==1.16.0 +XlsxWriter==3.1.9 yarl==1.9.4 zipp==3.17.0