From edcb60705c1008647fd977ae5b9d26ee7f2febb8 Mon Sep 17 00:00:00 2001 From: Gaisberg Date: Fri, 18 Oct 2024 17:06:16 +0300 Subject: [PATCH 1/6] !feat: secure most of api behind api_key --- .env.example | 3 + src/auth.py | 14 + src/controllers/models/shared.py | 5 - src/main.py | 25 +- src/program/settings/models.py | 18 +- src/routers/__init__.py | 30 ++ .../models/overseerr.py | 0 src/{controllers => routers}/models/plex.py | 0 src/routers/models/shared.py | 8 + src/routers/secure/__init__.py | 0 .../secure}/default.py | 457 +++++++++--------- src/{controllers => routers/secure}/items.py | 12 +- src/{controllers => routers/secure}/scrape.py | 2 +- .../secure}/settings.py | 259 +++++----- src/{controllers => routers/secure}/tmdb.py | 0 .../secure}/webhooks.py | 2 +- src/{controllers => routers/secure}/ws.py | 0 src/utils/__init__.py | 34 +- 18 files changed, 457 insertions(+), 412 deletions(-) create mode 100644 src/auth.py delete mode 100644 src/controllers/models/shared.py create mode 100644 src/routers/__init__.py rename src/{controllers => routers}/models/overseerr.py (100%) rename src/{controllers => routers}/models/plex.py (100%) create mode 100644 src/routers/models/shared.py create mode 100644 src/routers/secure/__init__.py rename src/{controllers => routers/secure}/default.py (92%) rename src/{controllers => routers/secure}/items.py (98%) rename src/{controllers => routers/secure}/scrape.py (99%) rename src/{controllers => routers/secure}/settings.py (95%) rename src/{controllers => routers/secure}/tmdb.py (100%) rename src/{controllers => routers/secure}/webhooks.py (98%) rename src/{controllers => routers/secure}/ws.py (100%) diff --git a/.env.example b/.env.example index cba5b9b9..2c5bcdc6 100644 --- a/.env.example +++ b/.env.example @@ -15,6 +15,9 @@ HARD_RESET=false # This will attempt to fix broken symlinks in the library, and then exit after running! 
REPAIR_SYMLINKS=false +# Manual api key, must be 32 characters long +API_KEY=1234567890qwertyuiopas + # This is the number of workers to use for reindexing symlinks after a database reset. # More workers = faster symlinking but uses more memory. # For lower end machines, stick to around 1-3. diff --git a/src/auth.py b/src/auth.py new file mode 100644 index 00000000..db3ac969 --- /dev/null +++ b/src/auth.py @@ -0,0 +1,14 @@ +from fastapi import HTTPException, Security, status +from fastapi.security import APIKeyHeader +from program.settings.manager import settings_manager + +api_key_header = APIKeyHeader(name="x-api-key") + +def resolve_api_key(api_key_header: str = Security(api_key_header)): + if api_key_header == settings_manager.settings.api_key: + return True + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing or invalid API key" + ) \ No newline at end of file diff --git a/src/controllers/models/shared.py b/src/controllers/models/shared.py deleted file mode 100644 index 53b5fefc..00000000 --- a/src/controllers/models/shared.py +++ /dev/null @@ -1,5 +0,0 @@ -from pydantic import BaseModel - - -class MessageResponse(BaseModel): - message: str \ No newline at end of file diff --git a/src/main.py b/src/main.py index ea321b96..d266d9b3 100644 --- a/src/main.py +++ b/src/main.py @@ -8,19 +8,12 @@ import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request - -from controllers.default import router as default_router -from controllers.items import router as items_router -from controllers.scrape import router as scrape_router -from controllers.settings import router as settings_router -from controllers.tmdb import router as tmdb_router -from controllers.webhooks import router as webhooks_router -from controllers.ws import router as ws_router -from scalar_fastapi import get_scalar_api_reference from 
program import Program from program.settings.models import get_version +from routers import app_router +from scalar_fastapi import get_scalar_api_reference +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request from utils.cli import handle_args from utils.logger import logger @@ -62,7 +55,6 @@ async def scalar_html(): ) app.program = Program() - app.add_middleware(LoguruMiddleware) app.add_middleware( CORSMiddleware, @@ -72,14 +64,7 @@ async def scalar_html(): allow_headers=["*"], ) -app.include_router(default_router) -app.include_router(settings_router) -app.include_router(items_router) -app.include_router(scrape_router) -app.include_router(webhooks_router) -app.include_router(tmdb_router) -app.include_router(ws_router) - +app.include_router(app_router) class Server(uvicorn.Server): def install_signal_handlers(self): diff --git a/src/program/settings/models.py b/src/program/settings/models.py index 2af11b9e..de1a6864 100644 --- a/src/program/settings/models.py +++ b/src/program/settings/models.py @@ -7,7 +7,7 @@ from RTN.models import SettingsModel from program.settings.migratable import MigratableBaseModel -from utils import root_dir +from utils import generate_api_key, get_version deprecation_warning = "This has been deprecated and will be removed in a future version." @@ -311,18 +311,6 @@ class RTNSettingsModel(SettingsModel, Observable): class IndexerModel(Observable): update_interval: int = 60 * 60 - -def get_version() -> str: - with open(root_dir / "pyproject.toml") as file: - pyproject_toml = file.read() - - match = re.search(r'version = "(.+)"', pyproject_toml) - if match: - version = match.group(1) - else: - raise ValueError("Could not find version in pyproject.toml") - return version - class LoggingModel(Observable): ... 
@@ -356,6 +344,7 @@ class PostProcessing(Observable): class AppModel(Observable): version: str = get_version() + api_key: str = "" debug: bool = True log: bool = True force_refresh: bool = False @@ -378,3 +367,6 @@ def __init__(self, **data: Any): super().__init__(**data) if existing_version < current_version: self.version = current_version + + if self.api_key == "": + self.api_key = generate_api_key() diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 00000000..24ec207e --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1,30 @@ +from auth import resolve_api_key +from fastapi import Depends, Request +from fastapi.routing import APIRouter +from program.settings.manager import settings_manager +from routers.models.shared import RootResponse +from routers.secure.default import router as default_router +from routers.secure.items import router as items_router +from routers.secure.scrape import router as scrape_router +from routers.secure.settings import router as settings_router +from routers.secure.tmdb import router as tmdb_router +from routers.secure.webhooks import router as webhooks_router +from routers.secure.ws import router as ws_router + +API_VERSION = "v1" + +app_router = APIRouter(prefix=f"/api/{API_VERSION}") +@app_router.get("/", operation_id="root") +async def root(_: Request) -> RootResponse: + return { + "message": "Riven is running!", + "version": settings_manager.settings.version, + } + +app_router.include_router(default_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(items_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(scrape_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(webhooks_router, dependencies=[Depends(resolve_api_key)]) +app_router.include_router(ws_router,
dependencies=[Depends(resolve_api_key)]) \ No newline at end of file diff --git a/src/controllers/models/overseerr.py b/src/routers/models/overseerr.py similarity index 100% rename from src/controllers/models/overseerr.py rename to src/routers/models/overseerr.py diff --git a/src/controllers/models/plex.py b/src/routers/models/plex.py similarity index 100% rename from src/controllers/models/plex.py rename to src/routers/models/plex.py diff --git a/src/routers/models/shared.py b/src/routers/models/shared.py new file mode 100644 index 00000000..62b8ccad --- /dev/null +++ b/src/routers/models/shared.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class MessageResponse(BaseModel): + message: str + +class RootResponse(MessageResponse): + version: str \ No newline at end of file diff --git a/src/routers/secure/__init__.py b/src/routers/secure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/controllers/default.py b/src/routers/secure/default.py similarity index 92% rename from src/controllers/default.py rename to src/routers/secure/default.py index 9d349cb6..143c68b8 100644 --- a/src/controllers/default.py +++ b/src/routers/secure/default.py @@ -1,234 +1,223 @@ -from typing import Literal - -import requests -from controllers.models.shared import MessageResponse -from fastapi import APIRouter, HTTPException, Request -from loguru import logger -from program.content.trakt import TraktContent -from program.db.db import db -from program.media.item import Episode, MediaItem, Movie, Season, Show -from program.media.state import States -from program.settings.manager import settings_manager -from pydantic import BaseModel, Field -from sqlalchemy import func, select -from utils.event_manager import EventUpdate - -router = APIRouter( - responses={404: {"description": "Not found"}}, -) - - -class RootResponse(MessageResponse): - version: str - - -@router.get("/", operation_id="root") -async def root() -> RootResponse: - return { - "message": "Riven 
is running!", - "version": settings_manager.settings.version, - } - - -@router.get("/health", operation_id="health") -async def health(request: Request) -> MessageResponse: - return { - "message": request.app.program.initialized, - } - - -class RDUser(BaseModel): - id: int - username: str - email: str - points: int = Field(description="User's RD points") - locale: str - avatar: str = Field(description="URL to the user's avatar") - type: Literal["free", "premium"] - premium: int = Field(description="Premium subscription left in seconds") - - -@router.get("/rd", operation_id="rd") -async def get_rd_user() -> RDUser: - api_key = settings_manager.settings.downloaders.real_debrid.api_key - headers = {"Authorization": f"Bearer {api_key}"} - - proxy = ( - settings_manager.settings.downloaders.real_debrid.proxy_url - if settings_manager.settings.downloaders.real_debrid.proxy_enabled - else None - ) - - response = requests.get( - "https://api.real-debrid.com/rest/1.0/user", - headers=headers, - proxies=proxy if proxy else None, - timeout=10, - ) - - if response.status_code != 200: - return {"success": False, "message": response.json()} - - return response.json() - - -@router.get("/torbox", operation_id="torbox") -async def get_torbox_user(): - api_key = settings_manager.settings.downloaders.torbox.api_key - headers = {"Authorization": f"Bearer {api_key}"} - response = requests.get( - "https://api.torbox.app/v1/api/user/me", headers=headers, timeout=10 - ) - return response.json() - - -@router.get("/services", operation_id="services") -async def get_services(request: Request) -> dict[str, bool]: - data = {} - if hasattr(request.app.program, "services"): - for service in request.app.program.all_services.values(): - data[service.key] = service.initialized - if not hasattr(service, "services"): - continue - for sub_service in service.services.values(): - data[sub_service.key] = sub_service.initialized - return data - - -class TraktOAuthInitiateResponse(BaseModel): - auth_url: 
str - - -@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") -async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: - trakt = request.app.program.services.get(TraktContent) - if trakt is None: - raise HTTPException(status_code=404, detail="Trakt service not found") - auth_url = trakt.perform_oauth_flow() - return {"auth_url": auth_url} - - -@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") -async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: - trakt = request.app.program.services.get(TraktContent) - if trakt is None: - raise HTTPException(status_code=404, detail="Trakt service not found") - success = trakt.handle_oauth_callback(code) - if success: - return {"message": "OAuth token obtained successfully"} - else: - raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") - - -class StatsResponse(BaseModel): - total_items: int - total_movies: int - total_shows: int - total_seasons: int - total_episodes: int - total_symlinks: int - incomplete_items: int - incomplete_retries: dict[int, int] = Field( - description="Media item log string: number of retries" - ) - states: dict[States, int] - - -@router.get("/stats", operation_id="stats") -async def get_stats(_: Request) -> StatsResponse: - payload = {} - with db.Session() as session: - # Ensure the connection is open for the entire duration of the session - with session.connection().execution_options(stream_results=True) as conn: - movies_symlinks = conn.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one() - episodes_symlinks = conn.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one() - total_symlinks = movies_symlinks + episodes_symlinks - - total_movies = conn.execute(select(func.count(Movie._id))).scalar_one() - total_shows = conn.execute(select(func.count(Show._id))).scalar_one() - total_seasons = 
conn.execute(select(func.count(Season._id))).scalar_one() - total_episodes = conn.execute(select(func.count(Episode._id))).scalar_one() - total_items = conn.execute(select(func.count(MediaItem._id))).scalar_one() - - # Use a server-side cursor for batch processing - incomplete_retries = {} - batch_size = 1000 - - result = conn.execute( - select(MediaItem._id, MediaItem.scraped_times) - .where(MediaItem.last_state != States.Completed) - ) - - while True: - batch = result.fetchmany(batch_size) - if not batch: - break - - for media_item_id, scraped_times in batch: - incomplete_retries[media_item_id] = scraped_times - - states = {} - for state in States: - states[state] = conn.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one() - - payload["total_items"] = total_items - payload["total_movies"] = total_movies - payload["total_shows"] = total_shows - payload["total_seasons"] = total_seasons - payload["total_episodes"] = total_episodes - payload["total_symlinks"] = total_symlinks - payload["incomplete_items"] = len(incomplete_retries) - payload["incomplete_retries"] = incomplete_retries - payload["states"] = states - - return payload - -class LogsResponse(BaseModel): - logs: str - - -@router.get("/logs", operation_id="logs") -async def get_logs() -> str: - log_file_path = None - for handler in logger._core.handlers.values(): - if ".log" in handler._name: - log_file_path = handler._sink._path - break - - if not log_file_path: - return {"success": False, "message": "Log file handler not found"} - - try: - with open(log_file_path, "r") as log_file: - log_contents = log_file.read() - return {"logs": log_contents} - except Exception as e: - logger.error(f"Failed to read log file: {e}") - raise HTTPException(status_code=500, detail="Failed to read log file") - - -@router.get("/events", operation_id="events") -async def get_events( - request: Request, -) -> dict[str, list[EventUpdate]]: - return request.app.program.em.get_event_updates() 
- - -@router.get("/mount", operation_id="mount") -async def get_rclone_files() -> dict[str, str]: - """Get all files in the rclone mount.""" - import os - - rclone_dir = settings_manager.settings.symlink.rclone_path - file_map = {} - - def scan_dir(path): - with os.scandir(path) as entries: - for entry in entries: - if entry.is_file(): - file_map[entry.name] = entry.path - elif entry.is_dir(): - scan_dir(entry.path) - - scan_dir(rclone_dir) # dict of `filename: filepath`` - return file_map +from typing import Literal + +import requests +from fastapi import APIRouter, HTTPException, Request +from loguru import logger +from program.content.trakt import TraktContent +from program.db.db import db +from program.media.item import Episode, MediaItem, Movie, Season, Show +from program.media.state import States +from program.settings.manager import settings_manager +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from utils.event_manager import EventUpdate + +from ..models.shared import MessageResponse + +router = APIRouter( + responses={404: {"description": "Not found"}}, +) + + +@router.get("/health", operation_id="health") +async def health(request: Request) -> MessageResponse: + return { + "message": str(request.app.program.initialized), + } + + +class RDUser(BaseModel): + id: int + username: str + email: str + points: int = Field(description="User's RD points") + locale: str + avatar: str = Field(description="URL to the user's avatar") + type: Literal["free", "premium"] + premium: int = Field(description="Premium subscription left in seconds") + + +@router.get("/rd", operation_id="rd") +async def get_rd_user() -> RDUser: + api_key = settings_manager.settings.downloaders.real_debrid.api_key + headers = {"Authorization": f"Bearer {api_key}"} + + proxy = ( + settings_manager.settings.downloaders.real_debrid.proxy_url + if settings_manager.settings.downloaders.real_debrid.proxy_enabled + else None + ) + + response = requests.get( + 
"https://api.real-debrid.com/rest/1.0/user", + headers=headers, + proxies=proxy if proxy else None, + timeout=10, + ) + + if response.status_code != 200: + return {"success": False, "message": response.json()} + + return response.json() + + +@router.get("/torbox", operation_id="torbox") +async def get_torbox_user(): + api_key = settings_manager.settings.downloaders.torbox.api_key + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get( + "https://api.torbox.app/v1/api/user/me", headers=headers, timeout=10 + ) + return response.json() + + +@router.get("/services", operation_id="services") +async def get_services(request: Request) -> dict[str, bool]: + data = {} + if hasattr(request.app.program, "services"): + for service in request.app.program.all_services.values(): + data[service.key] = service.initialized + if not hasattr(service, "services"): + continue + for sub_service in service.services.values(): + data[sub_service.key] = sub_service.initialized + return data + + +class TraktOAuthInitiateResponse(BaseModel): + auth_url: str + + +@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") +async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: + trakt = request.app.program.services.get(TraktContent) + if trakt is None: + raise HTTPException(status_code=404, detail="Trakt service not found") + auth_url = trakt.perform_oauth_flow() + return {"auth_url": auth_url} + + +@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") +async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: + trakt = request.app.program.services.get(TraktContent) + if trakt is None: + raise HTTPException(status_code=404, detail="Trakt service not found") + success = trakt.handle_oauth_callback(code) + if success: + return {"message": "OAuth token obtained successfully"} + else: + raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") + + +class StatsResponse(BaseModel): 
+ total_items: int + total_movies: int + total_shows: int + total_seasons: int + total_episodes: int + total_symlinks: int + incomplete_items: int + incomplete_retries: dict[int, int] = Field( + description="Media item log string: number of retries" + ) + states: dict[States, int] + + +@router.get("/stats", operation_id="stats") +async def get_stats(_: Request) -> StatsResponse: + payload = {} + with db.Session() as session: + # Ensure the connection is open for the entire duration of the session + with session.connection().execution_options(stream_results=True) as conn: + movies_symlinks = conn.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one() + episodes_symlinks = conn.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one() + total_symlinks = movies_symlinks + episodes_symlinks + + total_movies = conn.execute(select(func.count(Movie._id))).scalar_one() + total_shows = conn.execute(select(func.count(Show._id))).scalar_one() + total_seasons = conn.execute(select(func.count(Season._id))).scalar_one() + total_episodes = conn.execute(select(func.count(Episode._id))).scalar_one() + total_items = conn.execute(select(func.count(MediaItem._id))).scalar_one() + + # Use a server-side cursor for batch processing + incomplete_retries = {} + batch_size = 1000 + + result = conn.execute( + select(MediaItem._id, MediaItem.scraped_times) + .where(MediaItem.last_state != States.Completed) + ) + + while True: + batch = result.fetchmany(batch_size) + if not batch: + break + + for media_item_id, scraped_times in batch: + incomplete_retries[media_item_id] = scraped_times + + states = {} + for state in States: + states[state] = conn.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one() + + payload["total_items"] = total_items + payload["total_movies"] = total_movies + payload["total_shows"] = total_shows + payload["total_seasons"] = total_seasons + payload["total_episodes"] = 
total_episodes + payload["total_symlinks"] = total_symlinks + payload["incomplete_items"] = len(incomplete_retries) + payload["incomplete_retries"] = incomplete_retries + payload["states"] = states + + return payload + +class LogsResponse(BaseModel): + logs: str + + +@router.get("/logs", operation_id="logs") +async def get_logs() -> str: + log_file_path = None + for handler in logger._core.handlers.values(): + if ".log" in handler._name: + log_file_path = handler._sink._path + break + + if not log_file_path: + return {"success": False, "message": "Log file handler not found"} + + try: + with open(log_file_path, "r") as log_file: + log_contents = log_file.read() + return {"logs": log_contents} + except Exception as e: + logger.error(f"Failed to read log file: {e}") + raise HTTPException(status_code=500, detail="Failed to read log file") + + +@router.get("/events", operation_id="events") +async def get_events( + request: Request, +) -> dict[str, list[EventUpdate]]: + return request.app.program.em.get_event_updates() + + +@router.get("/mount", operation_id="mount") +async def get_rclone_files() -> dict[str, str]: + """Get all files in the rclone mount.""" + import os + + rclone_dir = settings_manager.settings.symlink.rclone_path + file_map = {} + + def scan_dir(path): + with os.scandir(path) as entries: + for entry in entries: + if entry.is_file(): + file_map[entry.name] = entry.path + elif entry.is_dir(): + scan_dir(entry.path) + + scan_dir(rclone_dir) # dict of `filename: filepath`` + return file_map diff --git a/src/controllers/items.py b/src/routers/secure/items.py similarity index 98% rename from src/controllers/items.py rename to src/routers/secure/items.py index 458cb68f..b4b1f203 100644 --- a/src/controllers/items.py +++ b/src/routers/secure/items.py @@ -3,12 +3,6 @@ from typing import Literal, Optional import Levenshtein -from RTN import Torrent -from fastapi import APIRouter, HTTPException, Request -from sqlalchemy import func, select -from sqlalchemy.exc 
import NoResultFound - -from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException, Request from program.content import Overseerr from program.db.db import db @@ -20,12 +14,12 @@ get_parent_ids, reset_media_item, ) +from program.downloaders import Downloader, get_needed_media from program.media.item import MediaItem from program.media.state import States -from program.symlink import Symlinker -from program.downloaders import Downloader, get_needed_media from program.media.stream import Stream from program.scrapers.shared import rtn +from program.symlink import Symlinker from program.types import Event from pydantic import BaseModel from RTN import Torrent @@ -33,6 +27,8 @@ from sqlalchemy.exc import NoResultFound from utils.logger import logger +from ..models.shared import MessageResponse + router = APIRouter( prefix="/items", tags=["items"], diff --git a/src/controllers/scrape.py b/src/routers/secure/scrape.py similarity index 99% rename from src/controllers/scrape.py rename to src/routers/secure/scrape.py index e975705b..0985221b 100644 --- a/src/controllers/scrape.py +++ b/src/routers/secure/scrape.py @@ -94,7 +94,7 @@ async def scrape( for stream in results.values() ] - except StopIteration as e: + except StopIteration: raise HTTPException(status_code=204, detail="Media item not found") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/controllers/settings.py b/src/routers/secure/settings.py similarity index 95% rename from src/controllers/settings.py rename to src/routers/secure/settings.py index d02375c4..bfc78a6c 100644 --- a/src/controllers/settings.py +++ b/src/routers/secure/settings.py @@ -1,129 +1,130 @@ -from copy import copy -from typing import Any, Dict, List - -from controllers.models.shared import MessageResponse -from fastapi import APIRouter, HTTPException -from program.settings.manager import settings_manager -from program.settings.models import AppModel -from 
pydantic import BaseModel, ValidationError - - -class SetSettings(BaseModel): - key: str - value: Any - - -router = APIRouter( - prefix="/settings", - tags=["settings"], - responses={404: {"description": "Not found"}}, -) - - -@router.get("/schema", operation_id="get_settings_schema") -async def get_settings_schema() -> dict[str, Any]: - """ - Get the JSON schema for the settings. - """ - return settings_manager.settings.model_json_schema() - -@router.get("/load", operation_id="load_settings") -async def load_settings() -> MessageResponse: - settings_manager.load() - return { - "message": "Settings loaded!", - } - -@router.post("/save", operation_id="save_settings") -async def save_settings() -> MessageResponse: - settings_manager.save() - return { - "message": "Settings saved!", - } - - -@router.get("/get/all", operation_id="get_all_settings") -async def get_all_settings() -> AppModel: - return copy(settings_manager.settings) - - -@router.get("/get/{paths}", operation_id="get_settings") -async def get_settings(paths: str) -> dict[str, Any]: - current_settings = settings_manager.settings.model_dump() - data = {} - for path in paths.split(","): - keys = path.split(".") - current_obj = current_settings - - for k in keys: - if k not in current_obj: - return None - current_obj = current_obj[k] - - data[path] = current_obj - return data - - -@router.post("/set/all", operation_id="set_all_settings") -async def set_all_settings(new_settings: Dict[str, Any]) -> MessageResponse: - current_settings = settings_manager.settings.model_dump() - - def update_settings(current_obj, new_obj): - for key, value in new_obj.items(): - if isinstance(value, dict) and key in current_obj: - update_settings(current_obj[key], value) - else: - current_obj[key] = value - - update_settings(current_settings, new_settings) - - # Validate and save the updated settings - try: - updated_settings = settings_manager.settings.model_validate(current_settings) - 
settings_manager.load(settings_dict=updated_settings.model_dump()) - settings_manager.save() # Ensure the changes are persisted - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - return { - "message": "All settings updated successfully!", - } - -@router.post("/set", operation_id="set_settings") -async def set_settings(settings: List[SetSettings]) -> MessageResponse: - current_settings = settings_manager.settings.model_dump() - - for setting in settings: - keys = setting.key.split(".") - current_obj = current_settings - - # Navigate to the last key's parent object, ensuring all keys exist. - for k in keys[:-1]: - if k not in current_obj: - raise HTTPException( - status_code=400, - detail=f"Path '{'.'.join(keys[:-1])}' does not exist.", - ) - current_obj = current_obj[k] - - # Ensure the final key exists before setting the value. - if keys[-1] in current_obj: - current_obj[keys[-1]] = setting.value - else: - raise HTTPException( - status_code=400, - detail=f"Key '{keys[-1]}' does not exist in path '{'.'.join(keys[:-1])}'.", - ) - - # Validate and apply the updated settings to the AppModel instance - try: - updated_settings = settings_manager.settings.__class__(**current_settings) - settings_manager.load(settings_dict=updated_settings.model_dump()) - settings_manager.save() # Ensure the changes are persisted - except ValidationError as e: - raise HTTPException from e( - status_code=400, - detail=f"Failed to update settings: {str(e)}", - ) - - return {"message": "Settings updated successfully."} +from copy import copy +from typing import Any, Dict, List + +from fastapi import APIRouter, HTTPException +from program.settings.manager import settings_manager +from program.settings.models import AppModel +from pydantic import BaseModel, ValidationError + +from ..models.shared import MessageResponse + + +class SetSettings(BaseModel): + key: str + value: Any + + +router = APIRouter( + prefix="/settings", + tags=["settings"], + responses={404: 
{"description": "Not found"}}, +) + + +@router.get("/schema", operation_id="get_settings_schema") +async def get_settings_schema() -> dict[str, Any]: + """ + Get the JSON schema for the settings. + """ + return settings_manager.settings.model_json_schema() + +@router.get("/load", operation_id="load_settings") +async def load_settings() -> MessageResponse: + settings_manager.load() + return { + "message": "Settings loaded!", + } + +@router.post("/save", operation_id="save_settings") +async def save_settings() -> MessageResponse: + settings_manager.save() + return { + "message": "Settings saved!", + } + + +@router.get("/get/all", operation_id="get_all_settings") +async def get_all_settings() -> AppModel: + return copy(settings_manager.settings) + + +@router.get("/get/{paths}", operation_id="get_settings") +async def get_settings(paths: str) -> dict[str, Any]: + current_settings = settings_manager.settings.model_dump() + data = {} + for path in paths.split(","): + keys = path.split(".") + current_obj = current_settings + + for k in keys: + if k not in current_obj: + return None + current_obj = current_obj[k] + + data[path] = current_obj + return data + + +@router.post("/set/all", operation_id="set_all_settings") +async def set_all_settings(new_settings: Dict[str, Any]) -> MessageResponse: + current_settings = settings_manager.settings.model_dump() + + def update_settings(current_obj, new_obj): + for key, value in new_obj.items(): + if isinstance(value, dict) and key in current_obj: + update_settings(current_obj[key], value) + else: + current_obj[key] = value + + update_settings(current_settings, new_settings) + + # Validate and save the updated settings + try: + updated_settings = settings_manager.settings.model_validate(current_settings) + settings_manager.load(settings_dict=updated_settings.model_dump()) + settings_manager.save() # Ensure the changes are persisted + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + return { + 
"message": "All settings updated successfully!", + } + +@router.post("/set", operation_id="set_settings") +async def set_settings(settings: List[SetSettings]) -> MessageResponse: + current_settings = settings_manager.settings.model_dump() + + for setting in settings: + keys = setting.key.split(".") + current_obj = current_settings + + # Navigate to the last key's parent object, ensuring all keys exist. + for k in keys[:-1]: + if k not in current_obj: + raise HTTPException( + status_code=400, + detail=f"Path '{'.'.join(keys[:-1])}' does not exist.", + ) + current_obj = current_obj[k] + + # Ensure the final key exists before setting the value. + if keys[-1] in current_obj: + current_obj[keys[-1]] = setting.value + else: + raise HTTPException( + status_code=400, + detail=f"Key '{keys[-1]}' does not exist in path '{'.'.join(keys[:-1])}'.", + ) + + # Validate and apply the updated settings to the AppModel instance + try: + updated_settings = settings_manager.settings.__class__(**current_settings) + settings_manager.load(settings_dict=updated_settings.model_dump()) + settings_manager.save() # Ensure the changes are persisted + except ValidationError as e: + raise HTTPException( + status_code=400, + detail=f"Failed to update settings: {str(e)}", + ) from e + + return {"message": "Settings updated successfully."} diff --git a/src/controllers/tmdb.py b/src/routers/secure/tmdb.py similarity index 100% rename from src/controllers/tmdb.py rename to src/routers/secure/tmdb.py diff --git a/src/controllers/webhooks.py b/src/routers/secure/webhooks.py similarity index 98% rename from src/controllers/webhooks.py rename to src/routers/secure/webhooks.py index 097f6c1a..3f6dd500 100644 --- a/src/controllers/webhooks.py +++ b/src/routers/secure/webhooks.py @@ -9,7 +9,7 @@ from requests import RequestException from utils.logger import logger -from .models.overseerr import OverseerrWebhook +from ..models.overseerr import OverseerrWebhook router = APIRouter( prefix="/webhook", diff --git
a/src/controllers/ws.py b/src/routers/secure/ws.py similarity index 100% rename from src/controllers/ws.py rename to src/routers/secure/ws.py diff --git a/src/utils/__init__.py b/src/utils/__init__.py index e2f85f0e..f5faed3f 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,6 +1,38 @@ +import os +import re +import secrets +import string from pathlib import Path +from loguru import logger + root_dir = Path(__file__).resolve().parents[2] data_dir_path = root_dir / "data" -alembic_dir = data_dir_path / "alembic" \ No newline at end of file +alembic_dir = data_dir_path / "alembic" + +def get_version() -> str: + with open(root_dir / "pyproject.toml") as file: + pyproject_toml = file.read() + + match = re.search(r'version = "(.+)"', pyproject_toml) + if match: + version = match.group(1) + else: + raise ValueError("Could not find version in pyproject.toml") + return version + +def generate_api_key(): + """Generate a secure API key of the specified length.""" + API_KEY = os.getenv("API_KEY") + if len(API_KEY) != 32: + logger.warning("env.API_KEY is not 32 characters long, generating a new one...") + characters = string.ascii_letters + string.digits + + # Generate the API key + api_key = "".join(secrets.choice(characters) for _ in range(32)) + logger.warning(f"New api key: {api_key}") + else: + api_key = API_KEY + + return api_key \ No newline at end of file From 1575499d778004c3bb97a2a77f34782b384bd416 Mon Sep 17 00:00:00 2001 From: Gaisberg Date: Sat, 19 Oct 2024 15:58:48 +0300 Subject: [PATCH 2/6] chore: add generateapikey endpoint, use loguru imports for logging --- src/main.py | 2 +- src/program/content/listrr.py | 2 +- src/program/content/mdblist.py | 2 +- src/program/content/overseerr.py | 2 +- src/program/content/plex_watchlist.py | 2 +- src/program/content/trakt.py | 2 +- src/program/db/db.py | 2 +- src/program/db/db_functions.py | 2 +- src/program/downloaders/torbox.py | 2 +- src/program/indexers/tmdb.py | 2 +- src/program/indexers/trakt.py | 2 
+- src/program/libraries/symlink.py | 2 +- src/program/media/item.py | 2 +- src/program/media/stream.py | 2 +- src/program/program.py | 2 +- src/program/scrapers/__init__.py | 2 +- src/program/scrapers/annatar.py | 2 +- src/program/scrapers/comet.py | 2 +- src/program/scrapers/jackett.py | 2 +- src/program/scrapers/knightcrawler.py | 2 +- src/program/scrapers/mediafusion.py | 2 +- src/program/scrapers/orionoid.py | 2 +- src/program/scrapers/prowlarr.py | 2 +- src/program/scrapers/shared.py | 2 +- src/program/scrapers/torbox.py | 2 +- src/program/scrapers/torrentio.py | 2 +- src/program/scrapers/zilean.py | 2 +- src/program/settings/versions.py | 2 +- src/program/state_transition.py | 2 +- src/program/symlink.py | 2 +- src/program/updaters/__init__.py | 2 +- src/program/updaters/emby.py | 2 +- src/program/updaters/jellyfin.py | 2 +- src/program/updaters/plex.py | 2 +- src/routers/secure/default.py | 8 + src/routers/secure/items.py | 2 +- src/routers/secure/webhooks.py | 2 +- src/tests/test_symlink_creation.py | 2 +- src/utils/cli.py | 2 +- src/utils/{logger.py => logging.py} | 302 +++++++++++++------------- src/utils/notifications.py | 2 +- 41 files changed, 198 insertions(+), 190 deletions(-) rename src/utils/{logger.py => logging.py} (97%) diff --git a/src/main.py b/src/main.py index d266d9b3..5bdf12b7 100644 --- a/src/main.py +++ b/src/main.py @@ -15,7 +15,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from utils.cli import handle_args -from utils.logger import logger +from loguru import logger class LoguruMiddleware(BaseHTTPMiddleware): diff --git a/src/program/content/listrr.py b/src/program/content/listrr.py index e8e2ccd3..241452a6 100644 --- a/src/program/content/listrr.py +++ b/src/program/content/listrr.py @@ -6,7 +6,7 @@ from program.indexers.trakt import get_imdbid_from_tmdb from program.media.item import MediaItem from program.settings.manager import settings_manager -from utils.logger import logger 
+from loguru import logger from utils.request import get, ping diff --git a/src/program/content/mdblist.py b/src/program/content/mdblist.py index a5288825..d3c226d8 100644 --- a/src/program/content/mdblist.py +++ b/src/program/content/mdblist.py @@ -4,7 +4,7 @@ from program.media.item import MediaItem from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get, ping diff --git a/src/program/content/overseerr.py b/src/program/content/overseerr.py index 40760055..94f50f12 100644 --- a/src/program/content/overseerr.py +++ b/src/program/content/overseerr.py @@ -8,7 +8,7 @@ from program.indexers.trakt import get_imdbid_from_tmdb from program.media.item import MediaItem from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.request import delete, get, ping, post diff --git a/src/program/content/plex_watchlist.py b/src/program/content/plex_watchlist.py index e8ef7f06..d70596d3 100644 --- a/src/program/content/plex_watchlist.py +++ b/src/program/content/plex_watchlist.py @@ -7,7 +7,7 @@ from program.media.item import Episode, MediaItem, Movie, Season, Show from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.request import get, ping diff --git a/src/program/content/trakt.py b/src/program/content/trakt.py index 468976e8..b280cea1 100644 --- a/src/program/content/trakt.py +++ b/src/program/content/trakt.py @@ -8,7 +8,7 @@ from program.media.item import MediaItem from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter from utils.request import get, post diff --git a/src/program/db/db.py b/src/program/db/db.py index c15463a7..73105ef7 100644 --- a/src/program/db/db.py +++ b/src/program/db/db.py @@ 
-8,7 +8,7 @@ from program.settings.manager import settings_manager from utils import data_dir_path -from utils.logger import logger +from loguru import logger engine_options = { "pool_size": 25, # Prom: Set to 1 when debugging sql queries diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py index 6aaf3597..dbcfc290 100644 --- a/src/program/db/db_functions.py +++ b/src/program/db/db_functions.py @@ -11,7 +11,7 @@ from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation from program.settings.manager import settings_manager from utils import alembic_dir -from utils.logger import logger +from loguru import logger from .db import alembic, db diff --git a/src/program/downloaders/torbox.py b/src/program/downloaders/torbox.py index df6be758..f59310b9 100644 --- a/src/program/downloaders/torbox.py +++ b/src/program/downloaders/torbox.py @@ -14,7 +14,7 @@ from program.media.state import States from program.media.stream import Stream from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.request import get, post API_URL = "https://api.torbox.app/v1/api" diff --git a/src/program/indexers/tmdb.py b/src/program/indexers/tmdb.py index 2ff9d816..2d213661 100644 --- a/src/program/indexers/tmdb.py +++ b/src/program/indexers/tmdb.py @@ -3,7 +3,7 @@ from typing import Generic, Literal, Optional, TypeVar from pydantic import BaseModel -from utils.logger import logger +from loguru import logger from utils.request import get TMDB_READ_ACCESS_TOKEN = "eyJhbGciOiJIUzI1NiJ9.eyJhdWQiOiJlNTkxMmVmOWFhM2IxNzg2Zjk3ZTE1NWY1YmQ3ZjY1MSIsInN1YiI6IjY1M2NjNWUyZTg5NGE2MDBmZjE2N2FmYyIsInNjb3BlcyI6WyJhcGlfcmVhZCJdLCJ2ZXJzaW9uIjoxfQ.xrIXsMFJpI1o1j5g2QpQcFP1X3AfRjFA5FlBFO5Naw8" # noqa: S105 diff --git a/src/program/indexers/trakt.py b/src/program/indexers/trakt.py index 5cb9b223..ae6541c4 100644 --- a/src/program/indexers/trakt.py +++ b/src/program/indexers/trakt.py @@ -7,7 +7,7 @@ from 
program.db.db import db from program.media.item import Episode, MediaItem, Movie, Season, Show from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.request import get CLIENT_ID = "0183a05ad97098d87287fe46da4ae286f434f32e8e951caad4cc147c947d79a3" diff --git a/src/program/libraries/symlink.py b/src/program/libraries/symlink.py index f64095d8..09c2f51a 100644 --- a/src/program/libraries/symlink.py +++ b/src/program/libraries/symlink.py @@ -10,7 +10,7 @@ from program.db.db import db from program.media.subtitle import Subtitle from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger if TYPE_CHECKING: from program.media.item import Episode, MediaItem, Movie, Show diff --git a/src/program/media/item.py b/src/program/media/item.py index 955222e8..4cec0767 100644 --- a/src/program/media/item.py +++ b/src/program/media/item.py @@ -13,7 +13,7 @@ from program.db.db import db from program.media.state import States from program.media.subtitle import Subtitle -from utils.logger import logger +from loguru import logger from ..db.db_functions import blacklist_stream, reset_streams from .stream import Stream diff --git a/src/program/media/stream.py b/src/program/media/stream.py index 4182ce7e..2e803f13 100644 --- a/src/program/media/stream.py +++ b/src/program/media/stream.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from program.db.db import db -from utils.logger import logger +from loguru import logger if TYPE_CHECKING: from program.media.item import MediaItem diff --git a/src/program/program.py b/src/program/program.py index 09ca84e2..d03394d1 100644 --- a/src/program/program.py +++ b/src/program/program.py @@ -25,7 +25,7 @@ from program.updaters import Updater from utils import data_dir_path from utils.event_manager import EventManager -from utils.logger import create_progress_bar, log_cleaner, logger +from 
utils.logging import create_progress_bar, log_cleaner, logger from .state_transition import process_event from .symlink import Symlinker diff --git a/src/program/scrapers/__init__.py b/src/program/scrapers/__init__.py index 3400794c..16f87672 100644 --- a/src/program/scrapers/__init__.py +++ b/src/program/scrapers/__init__.py @@ -17,7 +17,7 @@ from program.scrapers.torrentio import Torrentio from program.scrapers.zilean import Zilean from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger class Scraping: diff --git a/src/program/scrapers/annatar.py b/src/program/scrapers/annatar.py index 92ce090c..98170613 100644 --- a/src/program/scrapers/annatar.py +++ b/src/program/scrapers/annatar.py @@ -7,7 +7,7 @@ from program.media.item import MediaItem from program.settings.manager import settings_manager from program.scrapers.shared import _get_stremio_identifier -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get diff --git a/src/program/scrapers/comet.py b/src/program/scrapers/comet.py index a28dd148..74ce26ce 100644 --- a/src/program/scrapers/comet.py +++ b/src/program/scrapers/comet.py @@ -10,7 +10,7 @@ from program.media.item import MediaItem, Show from program.settings.manager import settings_manager from program.scrapers.shared import _get_stremio_identifier -from utils.logger import logger +from loguru import logger from utils.request import RateLimiter, RateLimitExceeded, get, ping diff --git a/src/program/scrapers/jackett.py b/src/program/scrapers/jackett.py index eb80bbbf..a12271d6 100644 --- a/src/program/scrapers/jackett.py +++ b/src/program/scrapers/jackett.py @@ -12,7 +12,7 @@ from program.media.item import Episode, MediaItem, Movie, Season, Show from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, 
RateLimitExceeded diff --git a/src/program/scrapers/knightcrawler.py b/src/program/scrapers/knightcrawler.py index 4274ca75..c4bd765f 100644 --- a/src/program/scrapers/knightcrawler.py +++ b/src/program/scrapers/knightcrawler.py @@ -7,7 +7,7 @@ from program.media.item import Episode, MediaItem from program.scrapers.shared import _get_stremio_identifier from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get, ping diff --git a/src/program/scrapers/mediafusion.py b/src/program/scrapers/mediafusion.py index 0eea129c..b0f74e5c 100644 --- a/src/program/scrapers/mediafusion.py +++ b/src/program/scrapers/mediafusion.py @@ -10,7 +10,7 @@ from program.scrapers.shared import _get_stremio_identifier from program.settings.manager import settings_manager from program.settings.models import AppModel -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get, ping diff --git a/src/program/scrapers/orionoid.py b/src/program/scrapers/orionoid.py index fbd91118..8e1a8bcc 100644 --- a/src/program/scrapers/orionoid.py +++ b/src/program/scrapers/orionoid.py @@ -3,7 +3,7 @@ from program.media.item import MediaItem from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get diff --git a/src/program/scrapers/prowlarr.py b/src/program/scrapers/prowlarr.py index 5f8ee8c2..47e1750b 100644 --- a/src/program/scrapers/prowlarr.py +++ b/src/program/scrapers/prowlarr.py @@ -13,7 +13,7 @@ from program.media.item import Episode, MediaItem, Movie, Season, Show from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, 
RateLimitExceeded diff --git a/src/program/scrapers/shared.py b/src/program/scrapers/shared.py index ef98bb75..8833573a 100644 --- a/src/program/scrapers/shared.py +++ b/src/program/scrapers/shared.py @@ -9,7 +9,7 @@ from program.media.stream import Stream from program.settings.manager import settings_manager from program.settings.versions import models -from utils.logger import logger +from loguru import logger enable_aliases = settings_manager.settings.scraping.enable_aliases settings_model = settings_manager.settings.ranking diff --git a/src/program/scrapers/torbox.py b/src/program/scrapers/torbox.py index 53e4ecd1..c8e3f8b7 100644 --- a/src/program/scrapers/torbox.py +++ b/src/program/scrapers/torbox.py @@ -5,7 +5,7 @@ from program.media.item import Episode, MediaItem, Movie, Season, Show from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get, ping diff --git a/src/program/scrapers/torrentio.py b/src/program/scrapers/torrentio.py index 565b5f47..67395a49 100644 --- a/src/program/scrapers/torrentio.py +++ b/src/program/scrapers/torrentio.py @@ -8,7 +8,7 @@ from program.settings.manager import settings_manager from program.settings.models import TorrentioConfig from program.scrapers.shared import _get_stremio_identifier -from utils.logger import logger +from loguru import logger from utils.ratelimiter import RateLimiter, RateLimitExceeded from utils.request import get, ping diff --git a/src/program/scrapers/zilean.py b/src/program/scrapers/zilean.py index a3d67705..3d250f96 100644 --- a/src/program/scrapers/zilean.py +++ b/src/program/scrapers/zilean.py @@ -8,7 +8,7 @@ from program.media.item import Episode, MediaItem, Season, Show from program.settings.manager import settings_manager from program.settings.models import AppModel -from utils.logger import logger +from loguru import logger from utils.ratelimiter 
import RateLimiter, RateLimitExceeded from utils.request import get, ping diff --git a/src/program/settings/versions.py b/src/program/settings/versions.py index d2f0ecc6..167fed7e 100644 --- a/src/program/settings/versions.py +++ b/src/program/settings/versions.py @@ -1,6 +1,6 @@ from RTN.models import BaseRankingModel, BestRanking, DefaultRanking -from utils.logger import logger +from loguru import logger class RankModels: diff --git a/src/program/state_transition.py b/src/program/state_transition.py index 6abbf048..de0bf009 100644 --- a/src/program/state_transition.py +++ b/src/program/state_transition.py @@ -11,7 +11,7 @@ from program.symlink import Symlinker from program.types import ProcessedEvent, Service from program.updaters import Updater -from utils.logger import logger +from loguru import logger def process_event(existing_item: MediaItem | None, emitted_by: Service, item: MediaItem) -> ProcessedEvent: diff --git a/src/program/symlink.py b/src/program/symlink.py index 9ff41971..68a07aff 100644 --- a/src/program/symlink.py +++ b/src/program/symlink.py @@ -15,7 +15,7 @@ from program.media.state import States from program.media.stream import Stream from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger class Symlinker: diff --git a/src/program/updaters/__init__.py b/src/program/updaters/__init__.py index c99480ce..8432891a 100644 --- a/src/program/updaters/__init__.py +++ b/src/program/updaters/__init__.py @@ -3,7 +3,7 @@ from program.updaters.plex import PlexUpdater from program.updaters.jellyfin import JellyfinUpdater from program.updaters.emby import EmbyUpdater -from utils.logger import logger +from loguru import logger class Updater: diff --git a/src/program/updaters/emby.py b/src/program/updaters/emby.py index 23149ca3..f54a866d 100644 --- a/src/program/updaters/emby.py +++ b/src/program/updaters/emby.py @@ -5,7 +5,7 @@ from program.settings.manager import settings_manager from 
program.media.item import MediaItem from utils.request import get, post -from utils.logger import logger +from loguru import logger class EmbyUpdater: diff --git a/src/program/updaters/jellyfin.py b/src/program/updaters/jellyfin.py index 040df562..f4dc28ac 100644 --- a/src/program/updaters/jellyfin.py +++ b/src/program/updaters/jellyfin.py @@ -5,7 +5,7 @@ from program.settings.manager import settings_manager from program.media.item import MediaItem from utils.request import get, post -from utils.logger import logger +from loguru import logger class JellyfinUpdater: diff --git a/src/program/updaters/plex.py b/src/program/updaters/plex.py index f3bc5ac6..bf14f4ad 100644 --- a/src/program/updaters/plex.py +++ b/src/program/updaters/plex.py @@ -10,7 +10,7 @@ from program.media.item import Episode, Movie, Season, Show from program.settings.manager import settings_manager -from utils.logger import logger +from loguru import logger class PlexUpdater: diff --git a/src/routers/secure/default.py b/src/routers/secure/default.py index 143c68b8..ae95997d 100644 --- a/src/routers/secure/default.py +++ b/src/routers/secure/default.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field from sqlalchemy import func, select from utils.event_manager import EventUpdate +from utils import generate_api_key from ..models.shared import MessageResponse @@ -60,6 +61,13 @@ async def get_rd_user() -> RDUser: return response.json() +@router.post("/generateapikey", operation_id="generateapikey") +async def generate_apikey() -> MessageResponse: + new_key = generate_api_key() + settings_manager.settings.api_key = new_key + settings_manager.save() + return { "message": new_key} + @router.get("/torbox", operation_id="torbox") async def get_torbox_user(): diff --git a/src/routers/secure/items.py b/src/routers/secure/items.py index b4b1f203..6bd4c41f 100644 --- a/src/routers/secure/items.py +++ b/src/routers/secure/items.py @@ -25,7 +25,7 @@ from RTN import Torrent from sqlalchemy import and_, 
func, or_, select from sqlalchemy.exc import NoResultFound -from utils.logger import logger +from loguru import logger from ..models.shared import MessageResponse diff --git a/src/routers/secure/webhooks.py b/src/routers/secure/webhooks.py index 3f6dd500..9dbba841 100644 --- a/src/routers/secure/webhooks.py +++ b/src/routers/secure/webhooks.py @@ -7,7 +7,7 @@ from program.indexers.trakt import get_imdbid_from_tmdb, get_imdbid_from_tvdb from program.media.item import MediaItem from requests import RequestException -from utils.logger import logger +from loguru import logger from ..models.overseerr import OverseerrWebhook diff --git a/src/tests/test_symlink_creation.py b/src/tests/test_symlink_creation.py index 72fb35f1..f2a28b72 100644 --- a/src/tests/test_symlink_creation.py +++ b/src/tests/test_symlink_creation.py @@ -16,7 +16,7 @@ from sqlalchemy.engine import URL from sqlalchemy.orm import declarative_base, relationship, sessionmaker -from utils.logger import logger +from loguru import logger logger.disable("program") # Suppress diff --git a/src/utils/cli.py b/src/utils/cli.py index 928e4a4c..1d2cabe9 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -3,7 +3,7 @@ from program.db.db_functions import hard_reset_database, resolve_duplicates from program.libraries.symlink import fix_broken_symlinks from program.settings.manager import settings_manager -from utils.logger import log_cleaner, logger +from utils.logging import log_cleaner, logger def handle_args(): diff --git a/src/utils/logger.py b/src/utils/logging.py similarity index 97% rename from src/utils/logger.py rename to src/utils/logging.py index 88a48009..5306f26b 100644 --- a/src/utils/logger.py +++ b/src/utils/logging.py @@ -1,152 +1,152 @@ -"""Logging utils""" - -import asyncio -import os -import sys -from datetime import datetime - -from loguru import logger -from rich.console import Console -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TextColumn, - TimeRemainingColumn, 
-) - -from program.settings.manager import settings_manager -from utils import data_dir_path -from utils.websockets.logging_handler import Handler as WebSocketHandler - -LOG_ENABLED: bool = settings_manager.settings.log - -def setup_logger(level): - """Setup the logger""" - logs_dir_path = data_dir_path / "logs" - os.makedirs(logs_dir_path, exist_ok=True) - timestamp = datetime.now().strftime("%Y%m%d-%H%M") - log_filename = logs_dir_path / f"riven-{timestamp}.log" - - # Helper function to get log settings from environment or use default - def get_log_settings(name, default_color, default_icon): - color = os.getenv(f"RIVEN_LOGGER_{name}_FG", default_color) - icon = os.getenv(f"RIVEN_LOGGER_{name}_ICON", default_icon) - return f"", icon - - # Define log levels and their default settings - log_levels = { - "PROGRAM": (36, "cc6600", "🤖"), - "DATABASE": (37, "d834eb", "🛢️"), - "DEBRID": (38, "cc3333", "🔗"), - "SYMLINKER": (39, "F9E79F", "🔗"), - "SCRAPER": (40, "3D5A80", "👻"), - "COMPLETED": (41, "FFFFFF", "🟢"), - "CACHE": (42, "527826", "📜"), - "NOT_FOUND": (43, "818589", "🤷‍"), - "NEW": (44, "e63946", "✨"), - "FILES": (45, "FFFFE0", "🗃️ "), - "ITEM": (46, "92a1cf", "🗃️ "), - "DISCOVERY": (47, "e56c49", "🔍"), - "API": (47, "006989", "👾"), - "PLEX": (47, "DAD3BE", "📽️ "), - "LOCAL": (48, "DAD3BE", "📽️ "), - "JELLYFIN": (48, "DAD3BE", "📽️ "), - "EMBY": (48, "DAD3BE", "📽️ "), - "TRAKT": (48, "1DB954", "🎵"), - } - - # Set log levels - for name, (no, default_color, default_icon) in log_levels.items(): - color, icon = get_log_settings(name, default_color, default_icon) - logger.level(name, no=no, color=color, icon=icon) - - # Default log levels - debug_color, debug_icon = get_log_settings("DEBUG", "98C1D9", "🐞") - info_color, info_icon = get_log_settings("INFO", "818589", "📰") - warning_color, warning_icon = get_log_settings("WARNING", "ffcc00", "⚠️ ") - critical_color, critical_icon = get_log_settings("CRITICAL", "ff0000", "") - success_color, success_icon = 
get_log_settings("SUCCESS", "00ff00", "✔️ ") - - logger.level("DEBUG", color=debug_color, icon=debug_icon) - logger.level("INFO", color=info_color, icon=info_icon) - logger.level("WARNING", color=warning_color, icon=warning_icon) - logger.level("CRITICAL", color=critical_color, icon=critical_icon) - logger.level("SUCCESS", color=success_color, icon=success_icon) - - # Log format to match the old log format, but with color - log_format = ( - "{time:YY-MM-DD} {time:HH:mm:ss} | " - "{level.icon} {level: <9} | " - "{module}.{function} - {message}" - ) - - logger.configure(handlers=[ - { - "sink": sys.stderr, - "level": level.upper() or "INFO", - "format": log_format, - "backtrace": False, - "diagnose": False, - "enqueue": True, - }, - { - "sink": log_filename, - "level": level.upper(), - "format": log_format, - "rotation": "25 MB", - "retention": "24 hours", - "compression": None, - "backtrace": False, - "diagnose": True, - "enqueue": True, - }, - # maybe later - # { - # "sink": manager.send_log_message, - # "level": level.upper() or "INFO", - # "format": log_format, - # "backtrace": False, - # "diagnose": False, - # "enqueue": True, - # } - ]) - - logger.add(WebSocketHandler(), format=log_format) - - -def log_cleaner(): - """Remove old log files based on retention settings.""" - cleaned = False - try: - logs_dir_path = data_dir_path / "logs" - for log_file in logs_dir_path.glob("riven-*.log"): - # remove files older than 8 hours - if (datetime.now() - datetime.fromtimestamp(log_file.stat().st_mtime)).total_seconds() / 3600 > 8: - log_file.unlink() - cleaned = True - if cleaned: - logger.log("COMPLETED", "Cleaned up old logs that were older than 8 hours.") - except Exception as e: - logger.error(f"Failed to clean old logs: {e}") - - -def create_progress_bar(total_items: int) -> tuple[Progress, Console]: - console = Console() - progress = Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - 
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeRemainingColumn(), - TextColumn("[progress.completed]{task.completed}/{task.total}", justify="right"), - TextColumn("[progress.log]{task.fields[log]}", justify="right"), - console=console, - transient=True - ) - return progress, console - - -console = Console() -log_level = "DEBUG" if settings_manager.settings.debug else "INFO" +"""Logging utils""" + +import asyncio +import os +import sys +from datetime import datetime + +from loguru import logger +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeRemainingColumn, +) + +from program.settings.manager import settings_manager +from utils import data_dir_path +from utils.websockets.logging_handler import Handler as WebSocketHandler + +LOG_ENABLED: bool = settings_manager.settings.log + +def setup_logger(level): + """Setup the logger""" + logs_dir_path = data_dir_path / "logs" + os.makedirs(logs_dir_path, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d-%H%M") + log_filename = logs_dir_path / f"riven-{timestamp}.log" + + # Helper function to get log settings from environment or use default + def get_log_settings(name, default_color, default_icon): + color = os.getenv(f"RIVEN_LOGGER_{name}_FG", default_color) + icon = os.getenv(f"RIVEN_LOGGER_{name}_ICON", default_icon) + return f"", icon + + # Define log levels and their default settings + log_levels = { + "PROGRAM": (36, "cc6600", "🤖"), + "DATABASE": (37, "d834eb", "🛢️"), + "DEBRID": (38, "cc3333", "🔗"), + "SYMLINKER": (39, "F9E79F", "🔗"), + "SCRAPER": (40, "3D5A80", "👻"), + "COMPLETED": (41, "FFFFFF", "🟢"), + "CACHE": (42, "527826", "📜"), + "NOT_FOUND": (43, "818589", "🤷‍"), + "NEW": (44, "e63946", "✨"), + "FILES": (45, "FFFFE0", "🗃️ "), + "ITEM": (46, "92a1cf", "🗃️ "), + "DISCOVERY": (47, "e56c49", "🔍"), + "API": (47, "006989", "👾"), + "PLEX": (47, "DAD3BE", "📽️ "), + "LOCAL": (48, "DAD3BE", "📽️ "), + "JELLYFIN": 
(48, "DAD3BE", "📽️ "), + "EMBY": (48, "DAD3BE", "📽️ "), + "TRAKT": (48, "1DB954", "🎵"), + } + + # Set log levels + for name, (no, default_color, default_icon) in log_levels.items(): + color, icon = get_log_settings(name, default_color, default_icon) + logger.level(name, no=no, color=color, icon=icon) + + # Default log levels + debug_color, debug_icon = get_log_settings("DEBUG", "98C1D9", "🐞") + info_color, info_icon = get_log_settings("INFO", "818589", "📰") + warning_color, warning_icon = get_log_settings("WARNING", "ffcc00", "⚠️ ") + critical_color, critical_icon = get_log_settings("CRITICAL", "ff0000", "") + success_color, success_icon = get_log_settings("SUCCESS", "00ff00", "✔️ ") + + logger.level("DEBUG", color=debug_color, icon=debug_icon) + logger.level("INFO", color=info_color, icon=info_icon) + logger.level("WARNING", color=warning_color, icon=warning_icon) + logger.level("CRITICAL", color=critical_color, icon=critical_icon) + logger.level("SUCCESS", color=success_color, icon=success_icon) + + # Log format to match the old log format, but with color + log_format = ( + "{time:YY-MM-DD} {time:HH:mm:ss} | " + "{level.icon} {level: <9} | " + "{module}.{function} - {message}" + ) + + logger.configure(handlers=[ + { + "sink": sys.stderr, + "level": level.upper() or "INFO", + "format": log_format, + "backtrace": False, + "diagnose": False, + "enqueue": True, + }, + { + "sink": log_filename, + "level": level.upper(), + "format": log_format, + "rotation": "25 MB", + "retention": "24 hours", + "compression": None, + "backtrace": False, + "diagnose": True, + "enqueue": True, + }, + # maybe later + # { + # "sink": manager.send_log_message, + # "level": level.upper() or "INFO", + # "format": log_format, + # "backtrace": False, + # "diagnose": False, + # "enqueue": True, + # } + ]) + + logger.add(WebSocketHandler(), format=log_format) + + +def log_cleaner(): + """Remove old log files based on retention settings.""" + cleaned = False + try: + logs_dir_path = data_dir_path 
/ "logs" + for log_file in logs_dir_path.glob("riven-*.log"): + # remove files older than 8 hours + if (datetime.now() - datetime.fromtimestamp(log_file.stat().st_mtime)).total_seconds() / 3600 > 8: + log_file.unlink() + cleaned = True + if cleaned: + logger.log("COMPLETED", "Cleaned up old logs that were older than 8 hours.") + except Exception as e: + logger.error(f"Failed to clean old logs: {e}") + + +def create_progress_bar(total_items: int) -> tuple[Progress, Console]: + console = Console() + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + TextColumn("[progress.completed]{task.completed}/{task.total}", justify="right"), + TextColumn("[progress.log]{task.fields[log]}", justify="right"), + console=console, + transient=True + ) + return progress, console + + +console = Console() +log_level = "DEBUG" if settings_manager.settings.debug else "INFO" setup_logger(log_level) \ No newline at end of file diff --git a/src/utils/notifications.py b/src/utils/notifications.py index 105ed415..f0255353 100644 --- a/src/utils/notifications.py +++ b/src/utils/notifications.py @@ -8,7 +8,7 @@ from program.settings.manager import settings_manager from program.settings.models import NotificationsModel from utils import root_dir -from utils.logger import logger +from loguru import logger ntfy = Apprise() settings: NotificationsModel = settings_manager.settings.notifications From ac0f3b4354e6e9dfdb5aab22f8b390bba0696ed4 Mon Sep 17 00:00:00 2001 From: Ayush Sehrawat <69469790+AyushSehrawat@users.noreply.github.com> Date: Sat, 19 Oct 2024 14:26:16 +0000 Subject: [PATCH 3/6] fix: add python-dotenv to load .env variables --- poetry.lock | 2 +- pyproject.toml | 1 + src/main.py | 3 +++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 82caf444..32b66a1b 100644 --- a/poetry.lock +++ b/poetry.lock 
@@ -3351,4 +3351,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "818b191e87dc4f7cd4785399ce8b4487eccdbcef55c7195ea4d2ca461d926fdb" +content-hash = "6e0edfa871718836f16c13f2b47ebd76ebe52df2cc345bff4deba04d2fb1a19f" diff --git a/pyproject.toml b/pyproject.toml index 5b59cd61..0292f3b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ rank-torrent-name = "^1.0.2" jsonschema = "^4.23.0" scalar-fastapi = "^1.0.3" psutil = "^6.0.0" +python-dotenv = "^1.0.1" [tool.poetry.group.dev.dependencies] pyright = "^1.1.352" diff --git a/src/main.py b/src/main.py index 5bdf12b7..c021406f 100644 --- a/src/main.py +++ b/src/main.py @@ -16,6 +16,9 @@ from starlette.requests import Request from utils.cli import handle_args from loguru import logger +from dotenv import load_dotenv + +load_dotenv() class LoguruMiddleware(BaseHTTPMiddleware): From e25e0886c9be7da74c7d9aa81095228053100b88 Mon Sep 17 00:00:00 2001 From: Ayush Sehrawat <69469790+AyushSehrawat@users.noreply.github.com> Date: Sat, 19 Oct 2024 14:35:40 +0000 Subject: [PATCH 4/6] fix: add default value for API_KEY --- src/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index f5faed3f..1159014f 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -24,7 +24,7 @@ def get_version() -> str: def generate_api_key(): """Generate a secure API key of the specified length.""" - API_KEY = os.getenv("API_KEY") + API_KEY = os.getenv("API_KEY", "") if len(API_KEY) != 32: logger.warning("env.API_KEY is not 32 characters long, generating a new one...") characters = string.ascii_letters + string.digits From 4164372365cea36e6cc997b85b3d6b268408a1e2 Mon Sep 17 00:00:00 2001 From: Ayush Sehrawat <69469790+AyushSehrawat@users.noreply.github.com> Date: Sat, 19 Oct 2024 14:54:53 +0000 Subject: [PATCH 5/6] chore: remove unused tmdb api router --- src/routers/__init__.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/src/routers/__init__.py b/src/routers/__init__.py index 24ec207e..1f370d05 100644 --- a/src/routers/__init__.py +++ b/src/routers/__init__.py @@ -7,7 +7,7 @@ from routers.secure.items import router as items_router from routers.secure.scrape import router as scrape_router from routers.secure.settings import router as settings_router -from routers.secure.tmdb import router as tmdb_router +# from routers.secure.tmdb import router as tmdb_router from routers.secure.webhooks import router as webooks_router from routers.secure.ws import router as ws_router @@ -25,6 +25,6 @@ async def root(_: Request) -> RootResponse: app_router.include_router(items_router, dependencies=[Depends(resolve_api_key)]) app_router.include_router(scrape_router, dependencies=[Depends(resolve_api_key)]) app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)]) -app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)]) +# app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)]) app_router.include_router(webooks_router, dependencies=[Depends(resolve_api_key)]) app_router.include_router(ws_router, dependencies=[Depends(resolve_api_key)]) \ No newline at end of file From 0e86e119a415bab0154e2ccbeb89793380b047f0 Mon Sep 17 00:00:00 2001 From: Gaisberg Date: Mon, 21 Oct 2024 09:29:41 +0300 Subject: [PATCH 6/6] feat: we now serve sse via /stream --- src/program/media/item.py | 19 ++++---- src/program/media/stream.py | 4 +- src/program/program.py | 3 -- src/routers/__init__.py | 4 +- src/routers/secure/stream.py | 39 ++++++++++++++++ src/utils/event_manager.py | 37 ++++++---------- src/utils/logging.py | 33 ++++-------- src/utils/sse_manager.py | 26 +++++++++++ src/utils/websockets/logging_handler.py | 12 ----- src/utils/websockets/manager.py | 59 ------------------- 10 files changed, 100 insertions(+), 136 deletions(-) create mode 100644 src/routers/secure/stream.py create
mode 100644 src/utils/sse_manager.py delete mode 100644 src/utils/websockets/logging_handler.py delete mode 100644 src/utils/websockets/manager.py diff --git a/src/program/media/item.py b/src/program/media/item.py index 4cec0767..e00ffd3a 100644 --- a/src/program/media/item.py +++ b/src/program/media/item.py @@ -6,10 +6,10 @@ import sqlalchemy from RTN import parse -from sqlalchemy import Index, UniqueConstraint +from sqlalchemy import Index from sqlalchemy.orm import Mapped, mapped_column, object_session, relationship -import utils.websockets.manager as ws_manager +from utils.sse_manager import sse_manager from program.db.db import db from program.media.state import States from program.media.subtitle import Subtitle @@ -133,9 +133,10 @@ def __init__(self, item: dict | None) -> None: self.subtitles = item.get("subtitles", []) def store_state(self) -> None: - if self.last_state and self.last_state != self._determine_state(): - ws_manager.send_item_update(json.dumps(self.to_dict())) - self.last_state = self._determine_state() + new_state = self._determine_state() + if self.last_state and self.last_state != new_state: + sse_manager.publish_event("item_update", {"last_state": self.last_state, "new_state": new_state, "item_id": self._id}) + self.last_state = new_state def is_stream_blacklisted(self, stream: Stream): """Check if a stream is blacklisted for this item.""" @@ -458,9 +459,7 @@ def _determine_state(self): def store_state(self) -> None: for season in self.seasons: season.store_state() - if self.last_state and self.last_state != self._determine_state(): - ws_manager.send_item_update(json.dumps(self.to_dict())) - self.last_state = self._determine_state() + super().store_state() def __repr__(self): return f"Show:{self.log_string}:{self.state.name}" @@ -531,9 +530,7 @@ class Season(MediaItem): def store_state(self) -> None: for episode in self.episodes: episode.store_state() - if self.last_state and self.last_state != self._determine_state(): - 
ws_manager.send_item_update(json.dumps(self.to_dict())) - self.last_state = self._determine_state() + super().store_state() def __init__(self, item): self.type = "season" diff --git a/src/program/media/stream.py b/src/program/media/stream.py index 2e803f13..4a8fcc81 100644 --- a/src/program/media/stream.py +++ b/src/program/media/stream.py @@ -24,7 +24,7 @@ class StreamRelation(db.Model): Index('ix_streamrelation_parent_id', 'parent_id'), Index('ix_streamrelation_child_id', 'child_id'), ) - + class StreamBlacklistRelation(db.Model): __tablename__ = "StreamBlacklistRelation" @@ -66,6 +66,6 @@ def __init__(self, torrent: Torrent): def __hash__(self): return self.infohash - + def __eq__(self, other): return isinstance(other, Stream) and self.infohash == other.infohash \ No newline at end of file diff --git a/src/program/program.py b/src/program/program.py index d03394d1..153d88cb 100644 --- a/src/program/program.py +++ b/src/program/program.py @@ -5,12 +5,10 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from queue import Empty -from typing import Iterator, List from apscheduler.schedulers.background import BackgroundScheduler from rich.live import Live -import utils.websockets.manager as ws_manager from program.content import Listrr, Mdblist, Overseerr, PlexWatchlist, TraktContent from program.downloaders import Downloader from program.indexers.trakt import TraktIndexer @@ -170,7 +168,6 @@ def start(self): super().start() self.scheduler.start() logger.success("Riven is running!") - ws_manager.send_health_update("running") self.initialized = True def _retry_library(self) -> None: diff --git a/src/routers/__init__.py b/src/routers/__init__.py index 1f370d05..34e0816b 100644 --- a/src/routers/__init__.py +++ b/src/routers/__init__.py @@ -9,7 +9,7 @@ from routers.secure.settings import router as settings_router # from routers.secure.tmdb import router as tmdb_router from routers.secure.webhooks import router as 
webooks_router -from routers.secure.ws import router as ws_router +from routers.secure.stream import router as stream_router API_VERSION = "v1" @@ -27,4 +27,4 @@ async def root(_: Request) -> RootResponse: app_router.include_router(settings_router, dependencies=[Depends(resolve_api_key)]) # app_router.include_router(tmdb_router, dependencies=[Depends(resolve_api_key)]) app_router.include_router(webooks_router, dependencies=[Depends(resolve_api_key)]) -app_router.include_router(ws_router, dependencies=[Depends(resolve_api_key)]) \ No newline at end of file +app_router.include_router(stream_router, dependencies=[Depends(resolve_api_key)]) \ No newline at end of file diff --git a/src/routers/secure/stream.py b/src/routers/secure/stream.py new file mode 100644 index 00000000..65452a83 --- /dev/null +++ b/src/routers/secure/stream.py @@ -0,0 +1,39 @@ +from datetime import datetime +import json +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from loguru import logger + +import logging + +from pydantic import BaseModel +from utils.sse_manager import sse_manager + + +router = APIRouter( + responses={404: {"description": "Not found"}}, + prefix="/stream", + tags=["stream"], +) + +class EventResponse(BaseModel): + data: dict + +class SSELogHandler(logging.Handler): + def emit(self, record: logging.LogRecord): + log_entry = { + "time": datetime.fromtimestamp(record.created).isoformat(), + "level": record.levelname, + "message": record.msg + } + sse_manager.publish_event("logging", json.dumps(log_entry)) + +logger.add(SSELogHandler()) + +@router.get("/event_types") +async def get_event_types(): + return {"message": list(sse_manager.event_queues.keys())} + +@router.get("/{event_type}") +async def stream_events(_: Request, event_type: str) -> EventResponse: + return StreamingResponse(sse_manager.subscribe(event_type), media_type="text/event-stream") \ No newline at end of file diff --git a/src/utils/event_manager.py 
b/src/utils/event_manager.py index 97290a64..32c1c442 100644 --- a/src/utils/event_manager.py +++ b/src/utils/event_manager.py @@ -1,16 +1,17 @@ import os import traceback - + from datetime import datetime from queue import Empty from threading import Lock +from typing import Dict, List from loguru import logger from pydantic import BaseModel from sqlalchemy.orm.exc import StaleDataError from concurrent.futures import CancelledError, Future, ThreadPoolExecutor -import utils.websockets.manager as ws_manager +from utils.sse_manager import sse_manager from program.db.db import db from program.db.db_functions import ( ensure_item_exists_in_db, @@ -73,7 +74,7 @@ def _process_future(self, future, service): result = next(future.result(), None) if future in self._futures: self._futures.remove(future) - ws_manager.send_event_update([future.event for future in self._futures if hasattr(future, "event")]) + sse_manager.publish_event("event_update", self.get_event_updates()) if isinstance(result, tuple): item_id, timestamp = result else: @@ -170,7 +171,7 @@ def submit_job(self, service, program, event=None): if event: future.event = event self._futures.append(future) - ws_manager.send_event_update([future.event for future in self._futures if hasattr(future, "event")]) + sse_manager.publish_event("event_update", self.get_event_updates()) future.add_done_callback(lambda f:self._process_future(f, service)) def cancel_job(self, item_id: int, suppress_logs=False): @@ -310,24 +311,14 @@ def add_item(self, item, service="Manual"): logger.debug(f"Added item with ID {item_id} to the queue.") - def get_event_updates(self) -> dict[str, list[EventUpdate]]: - """ - Returns a formatted list of event updates. - - Returns: - list: The list of formatted event updates. 
- """ + def get_event_updates(self) -> Dict[str, List[int]]: events = [future.event for future in self._futures if hasattr(future, "event")] event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"] - return { - event_type.lower(): [ - EventUpdate.model_validate( - { - "item_id": event.item_id, - "emitted_by": event.emitted_by if isinstance(event.emitted_by, str) else event.emitted_by.__name__, - "run_at": event.run_at.isoformat() - }) - for event in events if event.emitted_by == event_type - ] - for event_type in event_types - } \ No newline at end of file + + updates = {event_type: [] for event_type in event_types} + for event in events: + table = updates.get(event.emitted_by.__name__, None) + if table is not None: + table.append(event.item_id) + + return updates \ No newline at end of file diff --git a/src/utils/logging.py b/src/utils/logging.py index 5306f26b..5c355c53 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,6 +1,5 @@ """Logging utils""" -import asyncio import os import sys from datetime import datetime @@ -17,7 +16,6 @@ from program.settings.manager import settings_manager from utils import data_dir_path -from utils.websockets.logging_handler import Handler as WebSocketHandler LOG_ENABLED: bool = settings_manager.settings.log @@ -67,7 +65,7 @@ def get_log_settings(name, default_color, default_icon): warning_color, warning_icon = get_log_settings("WARNING", "ffcc00", "⚠️ ") critical_color, critical_icon = get_log_settings("CRITICAL", "ff0000", "") success_color, success_icon = get_log_settings("SUCCESS", "00ff00", "✔️ ") - + logger.level("DEBUG", color=debug_color, icon=debug_icon) logger.level("INFO", color=info_color, icon=info_icon) logger.level("WARNING", color=warning_color, icon=warning_icon) @@ -91,30 +89,18 @@ def get_log_settings(name, default_color, default_icon): "enqueue": True, }, { - "sink": log_filename, - "level": level.upper(), - "format": log_format, - "rotation": "25 MB", - "retention": "24 
hours", - "compression": None, - "backtrace": False, + "sink": log_filename, + "level": level.upper(), + "format": log_format, + "rotation": "25 MB", + "retention": "24 hours", + "compression": None, + "backtrace": False, "diagnose": True, "enqueue": True, - }, - # maybe later - # { - # "sink": manager.send_log_message, - # "level": level.upper() or "INFO", - # "format": log_format, - # "backtrace": False, - # "diagnose": False, - # "enqueue": True, - # } + } ]) - logger.add(WebSocketHandler(), format=log_format) - - def log_cleaner(): """Remove old log files based on retention settings.""" cleaned = False @@ -130,7 +116,6 @@ def log_cleaner(): except Exception as e: logger.error(f"Failed to clean old logs: {e}") - def create_progress_bar(total_items: int) -> tuple[Progress, Console]: console = Console() progress = Progress( diff --git a/src/utils/sse_manager.py b/src/utils/sse_manager.py new file mode 100644 index 00000000..1b284d06 --- /dev/null +++ b/src/utils/sse_manager.py @@ -0,0 +1,26 @@ +import asyncio +from typing import Dict, Any + +class ServerSentEventManager: + def __init__(self): + self.event_queues: Dict[str, asyncio.Queue] = {} + + def publish_event(self, event_type: str, data: Any): + if not data: + return + if event_type not in self.event_queues: + self.event_queues[event_type] = asyncio.Queue() + self.event_queues[event_type].put_nowait(data) + + async def subscribe(self, event_type: str): + if event_type not in self.event_queues: + self.event_queues[event_type] = asyncio.Queue() + + while True: + try: + data = await asyncio.wait_for(self.event_queues[event_type].get(), timeout=1.0) + yield f"data: {data}\n\n" + except asyncio.TimeoutError: + pass + +sse_manager = ServerSentEventManager() \ No newline at end of file diff --git a/src/utils/websockets/logging_handler.py b/src/utils/websockets/logging_handler.py deleted file mode 100644 index 163c624e..00000000 --- a/src/utils/websockets/logging_handler.py +++ /dev/null @@ -1,12 +0,0 @@ -import 
logging - -from utils.websockets import manager - - -class Handler(logging.Handler): - def emit(self, record: logging.LogRecord): - try: - message = self.format(record) - manager.send_log_message(message) - except Exception: - self.handleError(record) \ No newline at end of file diff --git a/src/utils/websockets/manager.py b/src/utils/websockets/manager.py deleted file mode 100644 index fe20e14d..00000000 --- a/src/utils/websockets/manager.py +++ /dev/null @@ -1,59 +0,0 @@ -import asyncio -import json - -from fastapi import WebSocket -from loguru import logger - -active_connections = [] - -async def connect(websocket: WebSocket): - await websocket.accept() - existing_connection = next((connection for connection in active_connections if connection.app == websocket.app), None) - if not existing_connection: - logger.debug("Frontend connected!") - active_connections.append(websocket) - if websocket.app.program.initialized: - status = "running" - else: - status = "paused" - await websocket.send_json({"type": "health", "status": status}) - -def disconnect(websocket: WebSocket): - logger.debug("Frontend disconnected!") - existing_connection = next((connection for connection in active_connections if connection.app == websocket.app), None) - active_connections.remove(existing_connection) - -async def _send_json(message: json, websocket: WebSocket): - try: - await websocket.send_json(message) - except Exception: - pass - -def send_event_update(events: list): - event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"] - message = {event_type.lower(): [event.item_id for event in events if event.emitted_by == event_type] for event_type in event_types} - broadcast({"type": "event_update", "message": message}) - -def send_health_update(status: str): - broadcast({"type": "health", "status": status}) - -def send_log_message(message: str): - broadcast({"type": "log", "message": message}) - -def send_item_update(item: json): - broadcast({"type": 
"item_update", "item": item}) - -def broadcast(message: json): - for connection in active_connections: - event_loop = None - try: - event_loop = asyncio.get_event_loop() - except RuntimeError: - pass - try: - if event_loop and event_loop.is_running(): - asyncio.create_task(_send_json(message, connection)) - else: - asyncio.run(_send_json(message, connection)) - except Exception: - pass