Skip to content

Commit

Permalink
!feat: secure most of api behind api_key
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaisberg authored and Gaisberg committed Oct 18, 2024
1 parent 8501c36 commit c65d817
Show file tree
Hide file tree
Showing 18 changed files with 457 additions and 412 deletions.
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ HARD_RESET=false
# This will attempt to fix broken symlinks in the library, and then exit after running!
REPAIR_SYMLINKS=false

# Manual api key, must be 32 characters long
API_KEY=1234567890qwertyuiopas

# This is the number of workers to use for reindexing symlinks after a database reset.
# More workers = faster symlinking but uses more memory.
# For lower end machines, stick to around 1-3.
Expand Down
14 changes: 14 additions & 0 deletions src/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader
from program.settings.manager import settings_manager

api_key_header = APIKeyHeader(name="x-api-key")

def resolve_api_key(api_key_header: str = Security(api_key_header)):
if api_key_header == settings_manager.settings.api_key:
return True
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid API key"
)
5 changes: 0 additions & 5 deletions src/controllers/models/shared.py

This file was deleted.

25 changes: 5 additions & 20 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,12 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from controllers.default import router as default_router
from controllers.items import router as items_router
from controllers.scrape import router as scrape_router
from controllers.settings import router as settings_router
from controllers.tmdb import router as tmdb_router
from controllers.webhooks import router as webhooks_router
from controllers.ws import router as ws_router
from scalar_fastapi import get_scalar_api_reference
from program import Program
from program.settings.models import get_version
from routers import app_router
from scalar_fastapi import get_scalar_api_reference
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from utils.cli import handle_args
from utils.logger import logger

Expand Down Expand Up @@ -62,7 +55,6 @@ async def scalar_html():
)

app.program = Program()

app.add_middleware(LoguruMiddleware)
app.add_middleware(
CORSMiddleware,
Expand All @@ -72,14 +64,7 @@ async def scalar_html():
allow_headers=["*"],
)

app.include_router(default_router)
app.include_router(settings_router)
app.include_router(items_router)
app.include_router(scrape_router)
app.include_router(webhooks_router)
app.include_router(tmdb_router)
app.include_router(ws_router)

app.include_router(app_router)

class Server(uvicorn.Server):
def install_signal_handlers(self):
Expand Down
18 changes: 5 additions & 13 deletions src/program/settings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from RTN.models import SettingsModel

from program.settings.migratable import MigratableBaseModel
from utils import root_dir
from utils import generate_api_key, get_version

deprecation_warning = "This has been deprecated and will be removed in a future version."

Expand Down Expand Up @@ -311,18 +311,6 @@ class RTNSettingsModel(SettingsModel, Observable):
class IndexerModel(Observable):
update_interval: int = 60 * 60


def get_version() -> str:
with open(root_dir / "pyproject.toml") as file:
pyproject_toml = file.read()

match = re.search(r'version = "(.+)"', pyproject_toml)
if match:
version = match.group(1)
else:
raise ValueError("Could not find version in pyproject.toml")
return version

class LoggingModel(Observable):
...

Expand Down Expand Up @@ -356,6 +344,7 @@ class PostProcessing(Observable):

class AppModel(Observable):
version: str = get_version()
api_key: str = ""
debug: bool = True
log: bool = True
force_refresh: bool = False
Expand All @@ -378,3 +367,6 @@ def __init__(self, **data: Any):
super().__init__(**data)
if existing_version < current_version:
self.version = current_version

if self.api_key == "":
self.api_key = generate_api_key()
30 changes: 30 additions & 0 deletions src/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from auth import resolve_api_key
from fastapi import Depends, Request
from fastapi.routing import APIRouter
from program.settings.manager import settings_manager
from routers.models.shared import RootResponse
from routers.secure.default import router as default_router
from routers.secure.items import router as items_router
from routers.secure.scrape import router as scrape_router
from routers.secure.settings import router as settings_router
from routers.secure.tmdb import router as tmdb_router
from routers.secure.webhooks import router as webooks_router
from routers.secure.ws import router as ws_router

API_VERSION = "v1"

app_router = APIRouter(prefix=f"/api/{API_VERSION}")
@app_router.get("/", operation_id="root")
async def root(_: Request) -> RootResponse:
return {
"message": "Riven is running!",
"version": settings_manager.settings.version,
}

app_router.include_router(default_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(items_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(scrape_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(webooks_router, dependencies=[Depends(resolve_api_key)])
app_router.include_router(ws_router, dependencies=[Depends(resolve_api_key)])
File renamed without changes.
File renamed without changes.
8 changes: 8 additions & 0 deletions src/routers/models/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel


class MessageResponse(BaseModel):
message: str

class RootResponse(MessageResponse):
version: str
Empty file added src/routers/secure/__init__.py
Empty file.
Loading

0 comments on commit c65d817

Please sign in to comment.