From 09d69626b98550eadee4545816b9a458a223a59a Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 17:47:26 +0200 Subject: [PATCH 01/20] feat: add response models to items.py --- src/controllers/items.py | 258 +++++++++++++++++++++++++++++---------- 1 file changed, 192 insertions(+), 66 deletions(-) diff --git a/src/controllers/items.py b/src/controllers/items.py index e213f562..500abb50 100644 --- a/src/controllers/items.py +++ b/src/controllers/items.py @@ -3,11 +3,7 @@ from typing import Optional import Levenshtein -from RTN import RTN, Torrent from fastapi import APIRouter, HTTPException, Request -from sqlalchemy import func, select -from sqlalchemy.exc import NoResultFound - from program.content import Overseerr from program.db.db import db from program.db.db_functions import ( @@ -17,16 +13,21 @@ get_parent_items_by_ids, reset_media_item, ) +from program.downloaders import Downloader, get_needed_media +from program.downloaders.realdebrid import ( + add_torrent_magnet, + torrent_info, +) from program.media.item import MediaItem from program.media.state import States -from program.symlink import Symlinker -from program.downloaders import Downloader, get_needed_media -from program.downloaders.realdebrid import RealDebridDownloader, add_torrent_magnet, torrent_info -from program.settings.versions import models -from program.settings.manager import settings_manager from program.media.stream import Stream from program.scrapers.shared import rtn +from program.symlink import Symlinker from program.types import Event +from pydantic import BaseModel +from RTN import Torrent +from sqlalchemy import func, select +from sqlalchemy.exc import NoResultFound from utils.logger import logger router = APIRouter( @@ -35,23 +36,41 @@ responses={404: {"description": "Not found"}}, ) + def handle_ids(ids: str) -> list[int]: ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)] if not ids: raise HTTPException(status_code=400, detail="No item ID provided") return ids -@router.get("/states") -async def get_states(): + +class StateResponse(BaseModel): + success: bool + states: list[str] + + +@router.get("/states", operation_id="get_states") +async def get_states() -> StateResponse: return { "success": True, "states": [state for state in States], } + +class ItemsResponse(BaseModel): + success: bool + items: list[dict] + page: int + limit: int + total_items: int + total_pages: int + + @router.get( "", summary="Retrieve Media Items", description="Fetch media items with optional filters and pagination", + operation_id="get_items", ) async def get_items( _: Request, @@ -62,7 +81,7 @@ async def get_items( sort: Optional[str] = "date_desc", search: Optional[str] = None, extended: Optional[bool] = False, -): +) -> ItemsResponse: if page < 1: raise HTTPException(status_code=400, detail="Page number must be 1 or greater.") @@ -77,8 +96,8 @@ async def get_items( query = query.where(MediaItem.imdb_id == search_lower) else: query = query.where( - (func.lower(MediaItem.title).like(f"%{search_lower}%")) | - (func.lower(MediaItem.imdb_id).like(f"%{search_lower}%")) + (func.lower(MediaItem.title).like(f"%{search_lower}%")) + | (func.lower(MediaItem.imdb_id).like(f"%{search_lower}%")) ) if state: @@ -104,9 +123,10 @@ async def get_items( if type not in ["movie", "show", "season", "episode"]: raise HTTPException( status_code=400, - detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode']") + detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode']", + ) else: - types=[type] + types = [type] query = query.where(MediaItem.type.in_(types)) if sort and not search: @@ -126,14 +146,24 @@ async def get_items( ) with db.Session() as session: - total_items = session.execute(select(func.count()).select_from(query.subquery())).scalar_one() - items = session.execute(query.offset((page - 1) * limit).limit(limit)).unique().scalars().all() + total_items = session.execute( + select(func.count()).select_from(query.subquery()) + ).scalar_one() + items = ( + session.execute(query.offset((page - 1) * limit).limit(limit)) + .unique() + .scalars() + .all() + ) total_pages = (total_items + limit - 1) // limit return { "success": True, - "items": [item.to_extended_dict() if extended else item.to_dict() for item in items], + "items": [ + item.to_extended_dict() if extended else item.to_dict() + for item in items + ], "page": page, "limit": limit, "total_items": total_items, @@ -141,15 +171,18 @@ async def get_items( } +class AddItemsResponse(BaseModel): + success: bool + message: str + + @router.post( - "/add", - summary="Add Media Items", - description="Add media items with bases on imdb IDs", + "/add", + summary="Add Media Items", + description="Add media items with bases on imdb IDs", + operation_id="add_items", ) -async def add_items( - request: Request, imdb_ids: str = None -): - +async def add_items(request: Request, imdb_ids: str = None) -> AddItemsResponse: if not imdb_ids: raise HTTPException(status_code=400, detail="No IMDb ID(s) provided") @@ -167,47 +200,77 @@ async def add_items( with db.Session() as _: for id in valid_ids: - item = MediaItem({"imdb_id": id, "requested_by": "riven", "requested_at": datetime.now()}) + item = MediaItem( + {"imdb_id": id, "requested_by": "riven", "requested_at": datetime.now()} + ) request.app.program.em.add_item(item) return {"success": True, "message": f"Added {len(valid_ids)} item(s) to the queue"} + +class ItemResponse(BaseModel): + success: bool + item: dict + + @router.get( "/{id}", summary="Retrieve Media Item", description="Fetch a single media item by ID", + operation_id="get_item", ) -async def get_item(request: Request, id: int): +async def get_item(request: Request, id: int) -> ItemResponse: with db.Session() as session: try: - item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one() + item = ( + session.execute(select(MediaItem).where(MediaItem._id == id)) + .unique() + .scalar_one() + ) except NoResultFound: raise HTTPException(status_code=404, detail="Item not found") return {"success": True, "item": item.to_extended_dict()} + +class ItemsByImdbResponse(BaseModel): + success: bool + items: list[dict] + + @router.get( "/{imdb_ids}", summary="Retrieve Media Items By IMDb IDs", description="Fetch media items by IMDb IDs", + operation_id="get_items_by_imdb_ids", ) -async def get_items_by_imdb_ids(request: Request, imdb_ids: str): +async def get_items_by_imdb_ids(request: Request, imdb_ids: str) -> ItemsByImdbResponse: ids = imdb_ids.split(",") with db.Session() as session: items = [] for id in ids: - item = session.execute(select(MediaItem).where(MediaItem.imdb_id == id)).unique().scalar_one() + item = ( + session.execute(select(MediaItem).where(MediaItem.imdb_id == id)) + .unique() + .scalar_one() + ) if item: items.append(item) return {"success": True, "items": [item.to_extended_dict() for item in items]} + +class ResetResponse(BaseModel): + success: bool + message: str + ids: list[str] + + @router.post( - "/reset", - summary="Reset Media Items", - description="Reset media items with bases on item IDs", + "/reset", + summary="Reset Media Items", + description="Reset media items with bases on item IDs", + operation_id="reset_items", ) -async def reset_items( - request: Request, ids: str -): +async def reset_items(request: Request, ids: str) -> ResetResponse: ids = handle_ids(ids) try: media_items = get_media_items_by_ids(ids) @@ -222,15 +285,23 @@ async def reset_items( logger.error(f"Failed to reset item with id {media_item._id}: {str(e)}") continue except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Reset items with id {ids}"} + raise HTTPException(status_code=400, detail=str(e)) from e + return {"success": True, "message": f"Reset items with id {ids}", "ids": ids} + + +class RetryResponse(BaseModel): + success: bool + message: str + ids: list[str] + @router.post( - "/retry", - summary="Retry Media Items", - description="Retry media items with bases on item IDs", + "/retry", + summary="Retry Media Items", + description="Retry media items with bases on item IDs", + operation_id="retry_items", ) -async def retry_items(request: Request, ids: str): +async def retry_items(request: Request, ids: str) -> RetryResponse: ids = handle_ids(ids) try: media_items = get_media_items_by_ids(ids) @@ -238,19 +309,27 @@ async def retry_items(request: Request, ids: str): raise ValueError("Invalid item ID(s) provided. Some items may not exist.") for media_item in media_items: request.app.program.em.cancel_job(media_item) - await asyncio.sleep(0.1) # Ensure cancellation is processed + await asyncio.sleep(0.1) # Ensure cancellation is processed request.app.program.em.add_item(media_item) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Retried items with ids {ids}"} + return {"success": True, "message": f"Retried items with ids {ids}", "ids": ids} + + +class RemoveResponse(BaseModel): + success: bool + message: str + ids: list[str] + @router.delete( "/remove", summary="Remove Media Items", description="Remove media items based on item IDs", + operation_id="remove_item", ) -async def remove_item(request: Request, ids: str): +async def remove_item(request: Request, ids: str) -> RemoveResponse: ids = handle_ids(ids) try: media_items = get_parent_items_by_ids(ids) @@ -259,30 +338,46 @@ async def remove_item(request: Request, ids: str): for media_item in media_items: logger.debug(f"Removing item {media_item.title} with ID {media_item._id}") request.app.program.em.cancel_job(media_item) - await asyncio.sleep(0.1) # Ensure cancellation is processed + await asyncio.sleep(0.1) # Ensure cancellation is processed clear_streams(media_item) symlink_service = request.app.program.services.get(Symlinker) if symlink_service: symlink_service.delete_item_symlinks(media_item) if media_item.requested_by == "overseerr" and media_item.requested_id: - logger.debug(f"Item was originally requested by Overseerr, deleting request within Overseerr...") + logger.debug( + f"Item was originally requested by Overseerr, deleting request within Overseerr..." + ) Overseerr.delete_request(media_item.requested_id) delete_media_item(media_item) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Removed items with ids {ids}"} + return {"success": True, "message": f"Removed items with ids {ids}", "ids": ids} + + +class SetTorrentRDResponse(BaseModel): + success: bool + message: str + item_id: int + torrent_id: str -@router.post("/{id}/set_torrent_rd_magnet", description="Set a torrent for a media item using a magnet link.") -def add_torrent(request: Request, id: int, magnet: str): + +@router.post( + "/{id}/set_torrent_rd_magnet", + name="Set torrent RD magnet", + description="Set a torrent for a media item using a magnet link.", + operation_id="set_torrent_rd_magnet", +) +def add_torrent(request: Request, id: int, magnet: str) -> SetTorrentRDResponse: torrent_id = "" try: torrent_id = add_torrent_magnet(magnet) except Exception: raise HTTPException(status_code=500, detail="Failed to add torrent.") from None - + return set_torrent_rd(request, id, torrent_id) + def reset_item_to_scraped(item: MediaItem): item.last_state = States.Scraped item.symlinked = False @@ -293,30 +388,48 @@ def reset_item_to_scraped(item: MediaItem): item.file = None item.folder = None + def create_stream(hash, torrent_info): try: torrent: Torrent = rtn.rank( - raw_title=torrent_info["filename"], - infohash=hash, - remove_trash=False + raw_title=torrent_info["filename"], infohash=hash, remove_trash=False ) return Stream(torrent) except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to rank torrent: {e}") from e + raise HTTPException( + status_code=500, detail=f"Failed to rank torrent: {e}" + ) from e -@router.post("/{id}/set_torrent_rd", description="Set a torrent for a media item using RD torrent ID.") -def set_torrent_rd(request: Request, id: int, torrent_id: str): +@router.post( + "/{id}/set_torrent_rd", + description="Set a torrent for a media item using RD torrent ID.", +) +def set_torrent_rd(request: Request, id: int, torrent_id: str): downloader: Downloader = request.app.program.services.get(Downloader) with db.Session() as session: - item: MediaItem = session.execute(select(MediaItem).where(MediaItem._id == id).outerjoin(MediaItem.streams)).unique().scalar_one_or_none() + item: MediaItem = ( + session.execute( + select(MediaItem) + .where(MediaItem._id == id) + .outerjoin(MediaItem.streams) + ) + .unique() + .scalar_one_or_none() + ) if item is None: raise HTTPException(status_code=404, detail="Item not found") fetched_torrent_info = torrent_info(torrent_id) - stream = session.execute(select(Stream).where(Stream.infohash == fetched_torrent_info["hash"])).scalars().first() + stream = ( + session.execute( + select(Stream).where(Stream.infohash == fetched_torrent_info["hash"]) + ) + .scalars() + .first() + ) hash = fetched_torrent_info["hash"] # Create stream if it doesn't exist @@ -325,10 +438,12 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): item.streams.append(stream) # check if the stream exists in the item - stream_exists_in_item = next((stream for stream in item.streams if stream.infohash == hash), None) + stream_exists_in_item = next( + (stream for stream in item.streams if stream.infohash == hash), None + ) if stream_exists_in_item is None: item.streams.append(stream) - + reset_item_to_scraped(item) # reset episodes if it's a season @@ -342,7 +457,10 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): if len(cached_streams) == 0: session.rollback() - raise HTTPException(status_code=400, detail=f"No cached torrents found for {item.log_string}") + raise HTTPException( + status_code=400, + detail=f"No cached torrents found for {item.log_string}", + ) item.active_stream = cached_streams[0] try: @@ -352,13 +470,21 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): if item.active_stream.get("infohash", None): downloader._delete_and_reset_active_stream(item) session.rollback() - raise HTTPException(status_code=500, detail=f"Failed to download {item.log_string}: {e}") from e + raise HTTPException( + status_code=500, detail=f"Failed to download {item.log_string}: {e}" + ) from e session.commit() request.app.program.em.add_event(Event("Symlinker", item)) - return {"success": True, "message": f"Set torrent for {item.title} to {torrent_id}"} + return { + "success": True, + "message": f"Set torrent for {item.title} to {torrent_id}", + "item_id": item._id, + "torrent_id": torrent_id, + } + # These require downloaders to be refactored @@ -397,4 +523,4 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): # item.reset(True) # downloader.download_cached(item, hash) # request.app.program.add_to_queue(item) -# return {"success": True, "message": f"Downloading {item.title} with hash {hash}"} \ No newline at end of file +# return {"success": True, "message": f"Downloading {item.title} with hash {hash}"} From d17be82c4ec519d72eca49e89be78837da0360c3 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 17:49:30 +0200 Subject: [PATCH 02/20] feat: ignore ruff linter raising exceptions in except block --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cc590509..74cfc826 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,8 @@ ignore = [ "S101", # ruff: Ignore assert warnings on tests "RET505", # "RET503", # ruff: Ignore required explicit returns (is this desired?) - "SLF001" # private member accessing from pickle + "SLF001", # private member accessing from pickle + "B904" # ruff: ignore raising exceptions from except for the API ] extend-select = [ "I", # isort From cbdc66259c42d43dc8f15d24a365514a54191da8 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 17:51:21 +0200 Subject: [PATCH 03/20] refactor: ruff format all the api files --- src/controllers/default.py | 56 ++++++++++++++++++++----------- src/controllers/scrape.py | 67 ++++++++++++++++++++++++++----------- src/controllers/settings.py | 3 +- src/controllers/tmdb.py | 1 - src/controllers/webhooks.py | 3 +- src/controllers/ws.py | 1 - 6 files changed, 87 insertions(+), 44 deletions(-) diff --git a/src/controllers/default.py b/src/controllers/default.py index 94d96d8c..e185be43 100644 --- a/src/controllers/default.py +++ b/src/controllers/default.py @@ -1,13 +1,12 @@ import requests from fastapi import APIRouter, HTTPException, Request from loguru import logger -from sqlalchemy import func, select - from program.content.trakt import TraktContent from program.db.db import db from program.media.item import Episode, MediaItem, Movie, Season, Show from program.media.state import States from program.settings.manager import settings_manager +from sqlalchemy import func, select router = APIRouter( responses={404: {"description": "Not found"}}, @@ -36,13 +35,17 @@ async def get_rd_user(): api_key = settings_manager.settings.downloaders.real_debrid.api_key headers = {"Authorization": f"Bearer {api_key}"} - proxy = settings_manager.settings.downloaders.real_debrid.proxy_url if settings_manager.settings.downloaders.real_debrid.proxy_enabled else None + proxy = ( + settings_manager.settings.downloaders.real_debrid.proxy_url + if settings_manager.settings.downloaders.real_debrid.proxy_enabled + else None + ) response = requests.get( "https://api.real-debrid.com/rest/1.0/user", headers=headers, proxies=proxy if proxy else None, - timeout=10 + timeout=10, ) if response.status_code != 200: @@ -102,9 +105,12 @@ async def trakt_oauth_callback(code: str, request: Request): async def get_stats(_: Request): payload = {} with db.Session() as session: - - movies_symlinks = session.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one() - episodes_symlinks = session.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one() + movies_symlinks = session.execute( + select(func.count(Movie._id)).where(Movie.symlinked == True) + ).scalar_one() + episodes_symlinks = session.execute( + select(func.count(Episode._id)).where(Episode.symlinked == True) + ).scalar_one() total_symlinks = movies_symlinks + episodes_symlinks total_movies = session.execute(select(func.count(Movie._id))).scalar_one() @@ -113,21 +119,30 @@ async def get_stats(_: Request): total_episodes = session.execute(select(func.count(Episode._id))).scalar_one() total_items = session.execute(select(func.count(MediaItem._id))).scalar_one() - # Select only the IDs of incomplete items - _incomplete_items = session.execute( - select(MediaItem._id) - .where(MediaItem.last_state != States.Completed) - ).scalars().all() - + # Select only the IDs of incomplete items + _incomplete_items = ( + session.execute( + select(MediaItem._id).where(MediaItem.last_state != States.Completed) + ) + .scalars() + .all() + ) + incomplete_retries = {} if _incomplete_items: - media_items = session.query(MediaItem).filter(MediaItem._id.in_(_incomplete_items)).all() + media_items = ( + session.query(MediaItem) + .filter(MediaItem._id.in_(_incomplete_items)) + .all() + ) for media_item in media_items: incomplete_retries[media_item.log_string] = media_item.scraped_times states = {} for state in States: - states[state] = session.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one() + states[state] = session.execute( + select(func.count(MediaItem._id)).where(MediaItem.last_state == state) + ).scalar_one() payload["total_items"] = total_items payload["total_movies"] = total_movies @@ -141,6 +156,7 @@ async def get_stats(_: Request): return {"success": True, "data": payload} + @router.get("/logs", operation_id="logs") async def get_logs(): log_file_path = None @@ -153,13 +169,14 @@ async def get_logs(): return {"success": False, "message": "Log file handler not found"} try: - with open(log_file_path, 'r') as log_file: + with open(log_file_path, "r") as log_file: log_contents = log_file.read() return {"success": True, "logs": log_contents} except Exception as e: logger.error(f"Failed to read log file: {e}") return {"success": False, "message": "Failed to read log file"} - + + @router.get("/events", operation_id="events") async def get_events(request: Request): return {"success": True, "data": request.app.program.em.get_event_updates()} @@ -169,8 +186,10 @@ async def get_events(request: Request): async def get_rclone_files(): """Get all files in the rclone mount.""" import os + rclone_dir = settings_manager.settings.symlink.rclone_path file_map = {} + def scan_dir(path): with os.scandir(path) as entries: for entry in entries: @@ -179,6 +198,5 @@ def scan_dir(path): elif entry.is_dir(): scan_dir(entry.path) - scan_dir(rclone_dir) # dict of `filename: filepath`` + scan_dir(rclone_dir) # dict of `filename: filepath`` return {"success": True, "data": file_map} - diff --git a/src/controllers/scrape.py b/src/controllers/scrape.py index deb833da..d0cdada8 100644 --- a/src/controllers/scrape.py +++ b/src/controllers/scrape.py @@ -1,29 +1,29 @@ """Scrape controller.""" + from fastapi import APIRouter, HTTPException, Request from sqlalchemy import select from program.scrapers import Scraping -from program.indexers.trakt import TraktIndexer from program.media.item import MediaItem from program.db.db import db +from program.downloaders.realdebrid import get_torrents +from program.indexers.trakt import TraktIndexer +from sqlalchemy import select -router = APIRouter( - prefix="/scrape", - tags=["scrape"] -) +router = APIRouter(prefix="/scrape", tags=["scrape"]) @router.get( - "", - summary="Scrape Media Item", - description="Scrape media item based on IMDb ID." + "", summary="Scrape Media Item", description="Scrape media item based on IMDb ID." ) -async def scrape(request: Request, imdb_id: str, season: int = None, episode: int = None): +async def scrape( + request: Request, imdb_id: str, season: int = None, episode: int = None +): """ Scrape media item based on IMDb ID. - **imdb_id**: IMDb ID of the media item. """ - if (services := request.app.program.services): + if services := request.app.program.services: scraping = services[Scraping] indexer = services[TraktIndexer] else: @@ -31,12 +31,16 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in try: with db.Session() as session: - media_item = session.execute( - select(MediaItem).where( - MediaItem.imdb_id == imdb_id, - MediaItem.type.in_(["movie", "show"]) + media_item = ( + session.execute( + select(MediaItem).where( + MediaItem.imdb_id == imdb_id, + MediaItem.type.in_(["movie", "show"]), + ) ) - ).unique().scalar_one_or_none() + .unique() + .scalar_one_or_none() + ) if not media_item: indexed_item = MediaItem({"imdb_id": imdb_id}) media_item = next(indexer.run(indexed_item)) @@ -48,7 +52,14 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in if media_item.type == "show": if season and episode: - media_item = next((ep for ep in media_item.seasons[season - 1].episodes if ep.number == episode), None) + media_item = next( + ( + ep + for ep in media_item.seasons[season - 1].episodes + if ep.number == episode + ), + None, + ) if not media_item: raise HTTPException(status_code=204, detail="Episode not found") elif season: @@ -56,7 +67,10 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in if not media_item: raise HTTPException(status_code=204, detail="Season not found") elif media_item.type == "movie" and (season or episode): - raise HTTPException(status_code=204, detail="Item type returned movie, cannot scrape season or episode") + raise HTTPException( + status_code=204, + detail="Item type returned movie, cannot scrape season or episode", + ) results = scraping.scrape(media_item, log=False) if not results: @@ -66,8 +80,9 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in { "raw_title": stream.raw_title, "infohash": stream.infohash, - "rank": stream.rank - } for stream in results.values() + "rank": stream.rank, + } + for stream in results.values() ] except StopIteration as e: @@ -76,3 +91,17 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in raise HTTPException(status_code=500, detail=str(e)) return {"success": True, "data": data} + + +@router.get( + "/rd", + summary="Get Real-Debrid Torrents", + description="Get torrents from Real-Debrid.", +) +async def get_rd_torrents(limit: int = 1000): + """ + Get torrents from Real-Debrid. + + - **limit**: Limit the number of torrents to get. + """ + return get_torrents(limit) diff --git a/src/controllers/settings.py b/src/controllers/settings.py index 671c9426..7cb13146 100644 --- a/src/controllers/settings.py +++ b/src/controllers/settings.py @@ -2,9 +2,8 @@ from typing import Any, Dict, List from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, ValidationError - from program.settings.manager import settings_manager +from pydantic import BaseModel, ValidationError class SetSettings(BaseModel): diff --git a/src/controllers/tmdb.py b/src/controllers/tmdb.py index e2db8ce7..46910f75 100644 --- a/src/controllers/tmdb.py +++ b/src/controllers/tmdb.py @@ -3,7 +3,6 @@ from urllib.parse import urlencode from fastapi import APIRouter, Depends - from program.indexers.tmdb import tmdb router = APIRouter( diff --git a/src/controllers/webhooks.py b/src/controllers/webhooks.py index 891bc688..d25a1c3b 100644 --- a/src/controllers/webhooks.py +++ b/src/controllers/webhooks.py @@ -2,12 +2,11 @@ import pydantic from fastapi import APIRouter, Request -from requests import RequestException - from program.content.overseerr import Overseerr from program.db.db_functions import _ensure_item_exists_in_db from program.indexers.trakt import get_imdbid_from_tmdb, get_imdbid_from_tvdb from program.media.item import MediaItem +from requests import RequestException from utils.logger import logger from .models.overseerr import OverseerrWebhook diff --git a/src/controllers/ws.py b/src/controllers/ws.py index 6693b7b7..5d622aeb 100644 --- a/src/controllers/ws.py +++ b/src/controllers/ws.py @@ -1,5 +1,4 @@ from fastapi import WebSocket - from utils.websockets import manager from .default import router From 499163642a14229696ce3d20039aa9d058b24222 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 18:03:22 +0200 Subject: [PATCH 04/20] feat: add response models to scrape.py --- src/controllers/scrape.py | 30 ++++++++++++++++++++------- src/program/downloaders/realdebrid.py | 29 +++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/controllers/scrape.py b/src/controllers/scrape.py index d0cdada8..442c7905 100644 --- a/src/controllers/scrape.py +++ b/src/controllers/scrape.py @@ -1,23 +1,37 @@ """Scrape controller.""" from fastapi import APIRouter, HTTPException, Request -from sqlalchemy import select -from program.scrapers import Scraping -from program.media.item import MediaItem from program.db.db import db -from program.downloaders.realdebrid import get_torrents +from program.downloaders.realdebrid import RDTorrent, get_torrents from program.indexers.trakt import TraktIndexer +from program.media.item import MediaItem +from program.scrapers import Scraping +from pydantic import BaseModel from sqlalchemy import select router = APIRouter(prefix="/scrape", tags=["scrape"]) +class ScrapedTorrent(BaseModel): + rank: int + raw_title: str + infohash: str + + +class ScrapeResponse(BaseModel): + success: bool + data: list[ScrapedTorrent] + + @router.get( - "", summary="Scrape Media Item", description="Scrape media item based on IMDb ID." + "", + summary="Scrape Media Item", + description="Scrape media item based on IMDb ID.", + operation_id="scrape", ) async def scrape( request: Request, imdb_id: str, season: int = None, episode: int = None -): +) -> ScrapeResponse: """ Scrape media item based on IMDb ID. @@ -92,13 +106,13 @@ async def scrape( return {"success": True, "data": data} - @router.get( "/rd", summary="Get Real-Debrid Torrents", description="Get torrents from Real-Debrid.", + operation_id="get_rd_torrents", ) -async def get_rd_torrents(limit: int = 1000): +async def get_rd_torrents(limit: int = 1000) -> list[RDTorrent]: """ Get torrents from Real-Debrid. diff --git a/src/program/downloaders/realdebrid.py b/src/program/downloaders/realdebrid.py index c7780fdd..4ed9a0b9 100644 --- a/src/program/downloaders/realdebrid.py +++ b/src/program/downloaders/realdebrid.py @@ -1,7 +1,10 @@ from datetime import datetime +from enum import Enum +from typing import Union from loguru import logger from program.settings.manager import settings_manager as settings +from pydantic import BaseModel from requests import ConnectTimeout from utils import request from utils.ratelimiter import RateLimiter @@ -13,6 +16,30 @@ torrent_limiter = RateLimiter(1, 1) overall_limiter = RateLimiter(60, 60) +class RDTorrentStatus(str, Enum): + magnet_error = "magnet_error" + magnet_conversion = "magnet_conversion" + waiting_files_selection = "waiting_files_selection" + downloading = "downloading" + downloaded = "downloaded" + error = "error" + seeding = "seeding" + dead = "dead" + uploading = "uploading" + compressing = "compressing" + +class RDTorrent(BaseModel): + id: str + hash: str + filename: str + bytes: int + status: RDTorrentStatus + added: datetime + links: list[str] + ended: Union[datetime, None] + speed: Union[int, None] + seeders: Union[int, None] + class RealDebridDownloader: def __init__(self): self.key = "realdebrid" @@ -174,7 +201,7 @@ def torrent_info(id: str) -> dict: info = {} return info -def get_torrents(limit: int) -> list[dict]: +def get_torrents(limit: int) -> list[RDTorrent]: try: torrents = get(f"torrents?limit={str(limit)}") except: From 382a2bf28cd6a441f5ee3b13dbb6f1239f25ce88 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 18:07:16 +0200 Subject: [PATCH 05/20] feat: add response models to settings.py --- src/controllers/settings.py | 47 +++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/controllers/settings.py b/src/controllers/settings.py index 7cb13146..ea07e44c 100644 --- a/src/controllers/settings.py +++ b/src/controllers/settings.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException from program.settings.manager import settings_manager +from program.settings.models import AppModel from pydantic import BaseModel, ValidationError @@ -18,41 +19,54 @@ class SetSettings(BaseModel): ) -@router.get("/schema") -async def get_settings_schema(): +@router.get("/schema", operation_id="get_settings_schema") +async def get_settings_schema() -> dict[str, Any]: """ Get the JSON schema for the settings. """ return settings_manager.settings.model_json_schema() -@router.get("/load") -async def load_settings(): +class LoadSettingsResponse(BaseModel): + success: bool + message: str + +@router.get("/load", operation_id="load_settings") +async def load_settings() -> LoadSettingsResponse: settings_manager.load() return { "success": True, "message": "Settings loaded!", } +class SaveSettingsResponse(BaseModel): + success: bool + message: str -@router.post("/save") -async def save_settings(): +@router.post("/save", operation_id="save_settings") +async def save_settings() -> SaveSettingsResponse: settings_manager.save() return { "success": True, "message": "Settings saved!", } +class GetAllSettingsResponse(BaseModel): + success: bool + data: AppModel -@router.get("/get/all") -async def get_all_settings(): +@router.get("/get/all", operation_id="get_all_settings") +async def get_all_settings() -> GetAllSettingsResponse: return { "success": True, "data": copy(settings_manager.settings), } +class GetSettingsResponse(BaseModel): + success: bool + data: dict[str, Any] -@router.get("/get/{paths}") -async def get_settings(paths: str): +@router.get("/get/{paths}", operation_id="get_settings") +async def get_settings(paths: str) -> GetSettingsResponse: current_settings = settings_manager.settings.dict() data = {} for path in paths.split(","): @@ -71,9 +85,12 @@ async def get_settings(paths: str): "data": data, } +class SetAllSettingsResponse(BaseModel): + success: bool + message: str -@router.post("/set/all") -async def set_all_settings(new_settings: Dict[str, Any]): +@router.post("/set/all", operation_id="set_all_settings") +async def set_all_settings(new_settings: Dict[str, Any]) -> SetAllSettingsResponse: current_settings = settings_manager.settings.model_dump() def update_settings(current_obj, new_obj): @@ -98,8 +115,12 @@ def update_settings(current_obj, new_obj): "message": "All settings updated successfully!", } +class SetSettingsResponse(BaseModel): + success: bool + message: str + @router.post("/set") -async def set_settings(settings: List[SetSettings]): +async def set_settings(settings: List[SetSettings]) -> SetSettingsResponse: current_settings = settings_manager.settings.model_dump() for setting in settings: From 81fdd23bdca625e2945d51bdfad88a2196d9938d Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 19:03:29 +0200 Subject: [PATCH 06/20] feat: add response models to tmdb.py --- src/controllers/tmdb.py | 15 ++- src/program/indexers/tmdb.py | 226 ++++++++++++++++++++++++++++++++--- 2 files changed, 219 insertions(+), 22 deletions(-) diff --git a/src/controllers/tmdb.py b/src/controllers/tmdb.py index 46910f75..501a27a6 100644 --- a/src/controllers/tmdb.py +++ b/src/controllers/tmdb.py @@ -1,9 +1,10 @@ from enum import Enum -from typing import Annotated +from typing import Annotated, Optional from urllib.parse import urlencode from fastapi import APIRouter, Depends -from program.indexers.tmdb import tmdb +from program.indexers.tmdb import TmdbItem, TmdbPagedResults, tmdb +from pydantic import BaseModel router = APIRouter( prefix="/tmdb", @@ -139,12 +140,18 @@ def __init__( self.year = year -@router.get("/trending/{type}/{window}") +class GetTrendingResponse(BaseModel): + success: bool + data: Optional[TmdbPagedResults[TmdbItem]] + message: Optional[str] + + +@router.get("/trending/{type}/{window}", operation_id="get_trending") async def get_trending( params: Annotated[TrendingParams, Depends()], type: TrendingType, window: TrendingWindow, -): +) -> GetTrendingResponse: trending = tmdb.getTrending( params=dict_to_query_string(params.__dict__), type=type.value, diff --git a/src/program/indexers/tmdb.py b/src/program/indexers/tmdb.py index 3868f606..f657b3d9 100644 --- a/src/program/indexers/tmdb.py +++ b/src/program/indexers/tmdb.py @@ -1,3 +1,8 @@ +from datetime import date +from enum import Enum +from typing import Generic, Literal, Optional, TypeVar + +from pydantic import BaseModel from utils.logger import logger from utils.request import get @@ -5,6 +10,191 @@ # TODO: Maybe remove the else condition ? It's not necessary since exception is raised 400-450, 500-511, 408, 460, 504, 520, 524, 522, 598 and 599 +ItemT = TypeVar("ItemT") + +class TmdbMediaType(str, Enum): + movie = "movie" + tv = "tv" + episode = "tv_episode" + season = "tv_season" + + +class TmdbItem(BaseModel): + adult: bool + backdrop_path: str + id: int + title: str + original_title: str + original_language: str + overview: str + poster_path: str + media_type: TmdbMediaType + genre_ids: list[int] + popularity: float + release_date: str + video: bool + vote_average: float + vote_count: int + +class TmdbEpisodeItem(BaseModel): + id: int + name: str + overview: str + media_type: Literal["tv_episode"] + vote_average: float + vote_count: int + air_date: date + episode_number: int + episode_type: str + production_code: str + runtime: int + season_number: int + show_id: int + still_path: str + +class TmdbSeasonItem(BaseModel): + id: int + name: str + overview: str + poster_path: str + media_type: Literal["tv_season"] + vote_average: float + air_date: date + season_number: int + show_id: int + episode_count: int + + +class TmdbPagedResults(BaseModel, Generic[ItemT]): + page: int + results: list[ItemT] + total_pages: int + total_results: int + +class TmdbPagedResultsWithDates(TmdbPagedResults[ItemT], Generic[ItemT]): + class Dates(BaseModel): + maximum: date + minimum: date + dates: Dates + +class TmdbFindResults(BaseModel): + movie_results: list[TmdbItem] + tv_results: list[TmdbItem] + tv_episode_results: list[TmdbEpisodeItem] + tv_season_results: list[TmdbSeasonItem] + +class Genre(BaseModel): + id: int + name: str + +class BelongsToCollection(BaseModel): + id: int + name: str + poster_path: Optional[str] + backdrop_path: Optional[str] + + +class ProductionCompany(BaseModel): + id: int + logo_path: Optional[str] + name: str + origin_country: str + + +class ProductionCountry(BaseModel): + iso_3166_1: str + name: str + + +class SpokenLanguage(BaseModel): + english_name: str + iso_639_1: str + name: str + +class Network(BaseModel): + id: int + logo_path: Optional[str] + name: str + origin_country: str + +class TmdbMovieDetails(BaseModel): + adult: bool + backdrop_path: Optional[str] + belongs_to_collection: Optional[BelongsToCollection] + budget: int + genres: list[Genre] + homepage: Optional[str] + id: int + imdb_id: Optional[str] + original_language: str + original_title: str + overview: Optional[str] + popularity: float + poster_path: Optional[str] + production_companies: list[ProductionCompany] + production_countries: list[ProductionCountry] + release_date: Optional[str] + revenue: int + runtime: Optional[int] + spoken_languages: list[SpokenLanguage] + status: Optional[str] + tagline: Optional[str] + title: str + video: bool + vote_average: float + vote_count: int + +class TmdbTVDetails(BaseModel): + adult: bool + backdrop_path: Optional[str] + episode_run_time: list[int] + first_air_date: str + genres: list[Genre] + homepage: Optional[str] + id: int + in_production: bool + languages: list[str] + last_air_date: Optional[str] + last_episode_to_air: Optional[TmdbEpisodeItem] + name: str + next_episode_to_air: Optional[str] + networks: list[Network] + number_of_episodes: int + number_of_seasons: int + origin_country: list[str] + original_language: str + original_name: str + overview: Optional[str] + popularity: float + poster_path: Optional[str] + production_companies: list[ProductionCompany] + production_countries: list[ProductionCountry] + seasons: list[TmdbSeasonItem] + spoken_languages: list[str] + status: Optional[str] + tagline: Optional[str] + type: Optional[str] + vote_average: float + vote_count: int + +class TmdbCollectionDetails(BaseModel): + adult: bool + backdrop_path: Optional[str] + id: int + name: str + overview: str + original_language: str + original_name: str + poster_path: Optional[str] + +class TmdbEpisodeDetails(TmdbEpisodeItem): + crew: list[dict] + guest_stars: list[dict] + +class TmdbSeasonDetails(BaseModel): + _id: str + air_date: str + episodes: list[TmdbEpisodeDetails] class TMDB: def __init__(self): @@ -13,7 +203,7 @@ def __init__(self): "Authorization": f"Bearer {TMDB_READ_ACCESS_TOKEN}", } - def getMoviesNowPlaying(self, params: str): + def getMoviesNowPlaying(self, params: str) -> Optional[TmdbPagedResultsWithDates]: url = f"{self.API_URL}/movie/now_playing?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -28,7 +218,7 @@ def getMoviesNowPlaying(self, params: str): ) return None - def getMoviesPopular(self, params: str): + def getMoviesPopular(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/movie/popular?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -41,7 +231,7 @@ def getMoviesPopular(self, params: str): logger.error(f"An error occurred while getting popular movies: {str(e)}") return None - def getMoviesTopRated(self, params: str): + def getMoviesTopRated(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/movie/top_rated?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -54,7 +244,7 @@ def getMoviesTopRated(self, params: str): logger.error(f"An error occurred while getting top rated movies: {str(e)}") return None - def getMoviesUpcoming(self, params: str): + def getMoviesUpcoming(self, params: str) -> Optional[TmdbPagedResultsWithDates[TmdbItem]]: url = f"{self.API_URL}/movie/upcoming?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -67,7 +257,7 @@ def getMoviesUpcoming(self, params: str): logger.error(f"An error occurred while getting upcoming movies: {str(e)}") return None - def getTrending(self, params: str, type: str, window: str): + def getTrending(self, params: str, type: str, window: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/trending/{type}/{window}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -80,7 +270,7 @@ def getTrending(self, params: str, type: str, window: str): logger.error(f"An error occurred while getting trending {type}: {str(e)}") return None - def getTVAiringToday(self, params: str): + def getTVAiringToday(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/airing_today?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -93,7 +283,7 @@ def getTVAiringToday(self, params: str): logger.error(f"An error occurred while getting TV airing today: {str(e)}") return None - def getTVOnTheAir(self, params: str): + def getTVOnTheAir(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/on_the_air?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -106,7 +296,7 @@ def getTVOnTheAir(self, params: str): logger.error(f"An error occurred while getting TV on the air: {str(e)}") return None - def getTVPopular(self, params: str): + def getTVPopular(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/popular?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -119,7 +309,7 @@ def getTVPopular(self, params: str): logger.error(f"An error occurred while getting popular TV shows: {str(e)}") return None - def getTVTopRated(self, params: str): + def getTVTopRated(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/top_rated?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -134,7 +324,7 @@ def getTVTopRated(self, params: str): ) return None - def getFromExternalID(self, params: str, external_id: str): + def getFromExternalID(self, params: str, external_id: str) -> Optional[TmdbFindResults]: url = f"{self.API_URL}/find/{external_id}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -147,7 +337,7 @@ def getFromExternalID(self, params: str, external_id: str): logger.error(f"An error occurred while getting from external ID: {str(e)}") return None - def getMovieDetails(self, params: str, movie_id: str): + def getMovieDetails(self, params: str, movie_id: str) -> Optional[TmdbMovieDetails]: url = f"{self.API_URL}/movie/{movie_id}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -160,7 +350,7 @@ def getMovieDetails(self, params: str, movie_id: str): logger.error(f"An error occurred while getting movie details: {str(e)}") return None - def getTVDetails(self, params: str, series_id: str): + def getTVDetails(self, params: str, series_id: str) -> Optional[TmdbTVDetails]: url = f"{self.API_URL}/tv/{series_id}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -173,7 +363,7 @@ def getTVDetails(self, params: str, series_id: str): logger.error(f"An error occurred while getting TV details: {str(e)}") return None - def getCollectionSearch(self, params: str): + def getCollectionSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbCollectionDetails]]: url = f"{self.API_URL}/search/collection?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -186,7 +376,7 @@ def getCollectionSearch(self, params: str): logger.error(f"An error occurred while searching collections: {str(e)}") return None - def getMovieSearch(self, params: str): + def getMovieSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/search/movie?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -199,7 +389,7 @@ def getMovieSearch(self, params: str): logger.error(f"An error occurred while searching movies: {str(e)}") return None - def getMultiSearch(self, params: str): + def getMultiSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/search/multi?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -212,7 +402,7 @@ def getMultiSearch(self, params: str): logger.error(f"An error occurred while searching multi: {str(e)}") return None - def getTVSearch(self, params: str): + def getTVSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/search/tv?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -225,7 +415,7 @@ def getTVSearch(self, params: str): logger.error(f"An error occurred while searching TV shows: {str(e)}") return None - def getTVSeasonDetails(self, params: str, series_id: int, season_number: int): + def getTVSeasonDetails(self, params: str, series_id: int, season_number: int) -> Optional[TmdbSeasonDetails]: url = f"{self.API_URL}/tv/{series_id}/season/{season_number}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -240,7 +430,7 @@ def getTVSeasonDetails(self, params: str, series_id: int, season_number: int): def getTVSeasonEpisodeDetails( self, params: str, series_id: int, season_number: int, episode_number: int - ): + ) -> Optional[TmdbEpisodeDetails]: url = f"{self.API_URL}/tv/{series_id}/season/{season_number}/episode/{episode_number}?{params}" try: response = get(url, additional_headers=self.HEADERS) From 04bd9a96bf9c4b5a58c8a54bd58a2d750c8dee67 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 19:16:34 +0200 Subject: [PATCH 07/20] feat: add missing type annotations to tmdb.py --- src/controllers/tmdb.py | 115 +++++++++++++++++++++++------------ src/program/indexers/tmdb.py | 2 +- 2 files changed, 77 insertions(+), 40 deletions(-) diff --git a/src/controllers/tmdb.py b/src/controllers/tmdb.py index 501a27a6..1dc039f6 100644 --- a/src/controllers/tmdb.py +++ b/src/controllers/tmdb.py @@ -1,9 +1,20 @@ from enum import Enum -from typing import Annotated, Optional +from typing import Annotated, Generic, Optional, TypeVar from urllib.parse import urlencode from fastapi import APIRouter, Depends -from program.indexers.tmdb import TmdbItem, TmdbPagedResults, tmdb +from program.indexers.tmdb import ( + TmdbCollectionDetails, + TmdbEpisodeDetails, + TmdbFindResults, + TmdbItem, + TmdbMovieDetails, + TmdbPagedResults, + TmdbPagedResultsWithDates, + TmdbSeasonDetails, + TmdbTVDetails, + tmdb, +) from pydantic import BaseModel router = APIRouter( @@ -140,9 +151,12 @@ def __init__( self.year = year -class GetTrendingResponse(BaseModel): +T = TypeVar("T") + + +class TmdbResponse(BaseModel, Generic[T]): success: bool - data: Optional[TmdbPagedResults[TmdbItem]] + data: Optional[T] message: Optional[str] @@ -151,7 +165,7 @@ async def get_trending( params: Annotated[TrendingParams, Depends()], type: TrendingType, window: TrendingWindow, -) -> GetTrendingResponse: +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: trending = tmdb.getTrending( params=dict_to_query_string(params.__dict__), type=type.value, @@ -169,8 +183,10 @@ async def get_trending( } -@router.get("/movie/now_playing") -async def get_movies_now_playing(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/now_playing", operation_id="get_movies_now_playing") +async def get_movies_now_playing( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResultsWithDates[TmdbItem]]: # noqa: F821 movies = tmdb.getMoviesNowPlaying(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -184,8 +200,10 @@ async def get_movies_now_playing(params: Annotated[CommonListParams, Depends()]) } -@router.get("/movie/popular") -async def get_movies_popular(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/popular", operation_id="get_movies_popular") +async def get_movies_popular( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: movies = tmdb.getMoviesPopular(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -199,8 +217,10 @@ async def get_movies_popular(params: Annotated[CommonListParams, Depends()]): } -@router.get("/movie/top_rated") -async def get_movies_top_rated(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/top_rated", operation_id="get_movies_top_rated") +async def get_movies_top_rated( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: movies = tmdb.getMoviesTopRated(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -214,8 +234,10 @@ async def get_movies_top_rated(params: Annotated[CommonListParams, Depends()]): } -@router.get("/movie/upcoming") -async def get_movies_upcoming(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/upcoming", operation_id="get_movies_upcoming") +async def get_movies_upcoming( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResultsWithDates[TmdbItem]]: movies = tmdb.getMoviesUpcoming(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -232,11 +254,11 @@ async def get_movies_upcoming(params: Annotated[CommonListParams, Depends()]): # FastAPI has router preference, so /movie/now_playing, /movie/popular, /movie/top_rated and /movie/upcoming will be matched first before /movie/{movie_id}, same for /tv/{tv_id} -@router.get("/movie/{movie_id}") +@router.get("/movie/{movie_id}", operation_id="get_movie_details") async def get_movie_details( movie_id: str, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbMovieDetails]: data = tmdb.getMovieDetails( params=dict_to_query_string(params.__dict__), movie_id=movie_id, @@ -253,8 +275,10 @@ async def get_movie_details( } -@router.get("/tv/airing_today") -async def get_tv_airing_today(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/airing_today", operation_id="get_tv_airing_today") +async def get_tv_airing_today( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVAiringToday(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -268,8 +292,10 @@ async def get_tv_airing_today(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/on_the_air") -async def get_tv_on_the_air(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/on_the_air", operation_id="get_tv_on_the_air") +async def get_tv_on_the_air( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVOnTheAir(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -283,8 +309,10 @@ async def get_tv_on_the_air(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/popular") -async def get_tv_popular(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/popular", operation_id="get_tv_popular") +async def get_tv_popular( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVPopular(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -298,8 +326,10 @@ async def get_tv_popular(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/top_rated") -async def get_tv_top_rated(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/top_rated", operation_id="get_tv_top_rated") +async def get_tv_top_rated( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVTopRated(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -313,11 +343,11 @@ async def get_tv_top_rated(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/{series_id}") +@router.get("/tv/{series_id}", operation_id="get_tv_details") async def get_tv_details( series_id: str, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbTVDetails]: data = tmdb.getTVDetails( params=dict_to_query_string(params.__dict__), series_id=series_id, @@ -334,12 +364,14 @@ async def get_tv_details( } -@router.get("/tv/{series_id}/season/{season_number}") +@router.get( + "/tv/{series_id}/season/{season_number}", operation_id="get_tv_season_details" +) async def get_tv_season_details( series_id: int, season_number: int, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbSeasonDetails]: data = tmdb.getTVSeasonDetails( params=dict_to_query_string(params.__dict__), series_id=series_id, @@ -357,13 +389,16 @@ async def get_tv_season_details( } -@router.get("/tv/{series_id}/season/{season_number}/episode/{episode_number}") +@router.get( + "/tv/{series_id}/season/{season_number}/episode/{episode_number}", + operation_id="get_tv_episode_details", +) async def get_tv_episode_details( series_id: int, season_number: int, episode_number: int, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbEpisodeDetails]: data = tmdb.getTVSeasonEpisodeDetails( params=dict_to_query_string(params.__dict__), series_id=series_id, @@ -382,8 +417,10 @@ async def get_tv_episode_details( } -@router.get("/search/collection") -async def search_collection(params: Annotated[CollectionSearchParams, Depends()]): +@router.get("/search/collection", operation_id="search_collection") +async def search_collection( + params: Annotated[CollectionSearchParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbCollectionDetails]]: data = tmdb.getCollectionSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -397,8 +434,8 @@ async def search_collection(params: Annotated[CollectionSearchParams, Depends()] } -@router.get("/search/movie") -async def search_movie(params: Annotated[MovieSearchParams, Depends()]): +@router.get("/search/movie", operation_id="search_movie") +async def search_movie(params: Annotated[MovieSearchParams, Depends()]) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: data = tmdb.getMovieSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -412,8 +449,8 @@ async def search_movie(params: Annotated[MovieSearchParams, Depends()]): } -@router.get("/search/multi") -async def search_multi(params: Annotated[MultiSearchParams, Depends()]): +@router.get("/search/multi", operation_id="search_multi") +async def search_multi(params: Annotated[MultiSearchParams, Depends()]) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: data = tmdb.getMultiSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -427,8 +464,8 @@ async def search_multi(params: Annotated[MultiSearchParams, Depends()]): } -@router.get("/search/tv") -async def search_tv(params: Annotated[TVSearchParams, Depends()]): +@router.get("/search/tv", operation_id="search_tv") +async def search_tv(params: Annotated[TVSearchParams, Depends()]) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: data = tmdb.getTVSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -442,11 +479,11 @@ async def search_tv(params: Annotated[TVSearchParams, Depends()]): } -@router.get("/external_id/{external_id}") +@router.get("/external_id/{external_id}", operation_id="get_from_external_id") async def get_from_external_id( external_id: str, params: Annotated[ExternalIDParams, Depends()], -): +) -> TmdbResponse[TmdbFindResults]: data = tmdb.getFromExternalID( params=dict_to_query_string(params.__dict__), external_id=external_id, diff --git a/src/program/indexers/tmdb.py b/src/program/indexers/tmdb.py index f657b3d9..8c572ff6 100644 --- a/src/program/indexers/tmdb.py +++ b/src/program/indexers/tmdb.py @@ -203,7 +203,7 @@ def __init__(self): "Authorization": f"Bearer {TMDB_READ_ACCESS_TOKEN}", } - def getMoviesNowPlaying(self, params: str) -> Optional[TmdbPagedResultsWithDates]: + def getMoviesNowPlaying(self, params: str) -> Optional[TmdbPagedResultsWithDates[TmdbItem]]: url = f"{self.API_URL}/movie/now_playing?{params}" try: response = get(url, additional_headers=self.HEADERS) From e027f643eb1e8a107e4e9d6bf4118f18b6ff5a34 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Mon, 30 Sep 2024 19:36:51 +0200 Subject: [PATCH 08/20] fix: add default values for some pydantic models --- src/controllers/tmdb.py | 4 ++-- src/program/indexers/tmdb.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/controllers/tmdb.py b/src/controllers/tmdb.py index 1dc039f6..a9b930f2 100644 --- a/src/controllers/tmdb.py +++ b/src/controllers/tmdb.py @@ -156,8 +156,8 @@ def __init__( class TmdbResponse(BaseModel, Generic[T]): success: bool - data: Optional[T] - message: Optional[str] + data: Optional[T] = None + message: Optional[str] = None @router.get("/trending/{type}/{window}", operation_id="get_trending") diff --git a/src/program/indexers/tmdb.py b/src/program/indexers/tmdb.py index 8c572ff6..2ff9d816 100644 --- a/src/program/indexers/tmdb.py +++ b/src/program/indexers/tmdb.py @@ -21,14 +21,14 @@ class TmdbMediaType(str, Enum): class TmdbItem(BaseModel): adult: bool - backdrop_path: str + backdrop_path: Optional[str] id: int title: str original_title: str original_language: str overview: str - poster_path: str - media_type: TmdbMediaType + poster_path: Optional[str] + media_type: Optional[TmdbMediaType] = None genre_ids: list[int] popularity: float release_date: str From cbf6f296add628fdc538e3f3c91ba742719d79b7 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Wed, 2 Oct 2024 16:45:01 +0200 Subject: [PATCH 09/20] feat: add types to default.py --- src/controllers/default.py | 53 +++++++++++++++++++++++++------- src/controllers/models/shared.py | 13 ++++++++ src/utils/event_manager.py | 11 ++++++- 3 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 src/controllers/models/shared.py diff --git a/src/controllers/default.py b/src/controllers/default.py index e185be43..5654703b 100644 --- a/src/controllers/default.py +++ b/src/controllers/default.py @@ -1,4 +1,7 @@ +from typing import Literal + import requests +from controllers.models.shared import DataAndSuccessResponse, MessageAndSuccessResponse from fastapi import APIRouter, HTTPException, Request from loguru import logger from program.content.trakt import TraktContent @@ -6,32 +9,44 @@ from program.media.item import Episode, MediaItem, Movie, Season, Show from program.media.state import States from program.settings.manager import settings_manager +from pydantic import BaseModel, Field from sqlalchemy import func, select +from utils.event_manager import EventUpdate router = APIRouter( responses={404: {"description": "Not found"}}, ) +class RootResponse(MessageAndSuccessResponse): + version: str @router.get("/", operation_id="root") -async def root(): +async def root() -> RootResponse: return { "success": True, "message": "Riven is running!", "version": settings_manager.settings.version, } - @router.get("/health", operation_id="health") -async def health(request: Request): +async def health(request: Request) -> MessageAndSuccessResponse: return { "success": True, "message": request.app.program.initialized, } +class RDUser(BaseModel): + id: int + username: str + email: str + points: int = Field(description="User's RD points") + locale: str + avatar: str = Field(description="URL to the user's avatar") + type: Literal["free", "premium"] + premium: int = Field(description="Premium subscription left in seconds") @router.get("/rd", operation_id="rd") -async def get_rd_user(): +async def get_rd_user() -> DataAndSuccessResponse[RDUser]: api_key = settings_manager.settings.downloaders.real_debrid.api_key headers = {"Authorization": f"Bearer {api_key}"} @@ -68,7 +83,7 @@ async def get_torbox_user(): @router.get("/services", operation_id="services") -async def get_services(request: Request): +async def get_services(request: Request) -> DataAndSuccessResponse[dict]: data = {} if hasattr(request.app.program, "services"): for service in request.app.program.all_services.values(): @@ -79,9 +94,11 @@ async def get_services(request: Request): data[sub_service.key] = sub_service.initialized return {"success": True, "data": data} +class TraktOAuthInitiateResponse(BaseModel): + auth_url: str @router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") -async def initiate_trakt_oauth(request: Request): +async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: trakt = request.app.program.services.get(TraktContent) if trakt is None: raise HTTPException(status_code=404, detail="Trakt service not found") @@ -90,7 +107,7 @@ async def initiate_trakt_oauth(request: Request): @router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") -async def trakt_oauth_callback(code: str, request: Request): +async def trakt_oauth_callback(code: str, request: Request) -> MessageAndSuccessResponse: trakt = request.app.program.services.get(TraktContent) if trakt is None: raise HTTPException(status_code=404, detail="Trakt service not found") @@ -100,9 +117,19 @@ async def trakt_oauth_callback(code: str, request: Request): else: raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") +class StatsResponse(BaseModel): + total_items: int + total_movies: int + total_shows: int + total_seasons: int + total_episodes: int + total_symlinks: int + incomplete_items: int + incomplete_retries: dict[str, int] = Field(description="Media item log string: number of retries") + states: dict[States, int] @router.get("/stats", operation_id="stats") -async def get_stats(_: Request): +async def get_stats(_: Request) -> DataAndSuccessResponse[StatsResponse]: payload = {} with db.Session() as session: movies_symlinks = session.execute( @@ -157,8 +184,12 @@ async def get_stats(_: Request): return {"success": True, "data": payload} +class LogsResponse(BaseModel): + success: bool + logs: str + @router.get("/logs", operation_id="logs") -async def get_logs(): +async def get_logs() -> LogsResponse: log_file_path = None for handler in logger._core.handlers.values(): if ".log" in handler._name: @@ -178,12 +209,12 @@ async def get_logs(): @router.get("/events", operation_id="events") -async def get_events(request: Request): +async def get_events(request: Request) -> DataAndSuccessResponse[dict[str, list[EventUpdate]]]: return {"success": True, "data": request.app.program.em.get_event_updates()} @router.get("/mount", operation_id="mount") -async def get_rclone_files(): +async def get_rclone_files() -> DataAndSuccessResponse[dict[str, str]]: """Get all files in the rclone mount.""" import os diff --git a/src/controllers/models/shared.py b/src/controllers/models/shared.py new file mode 100644 index 00000000..1550167b --- /dev/null +++ b/src/controllers/models/shared.py @@ -0,0 +1,13 @@ +from typing import Generic, TypeVar +from pydantic import BaseModel + + +class MessageAndSuccessResponse(BaseModel): + message: str + success: bool + +T = TypeVar('T', bound=BaseModel) + +class DataAndSuccessResponse(BaseModel, Generic[T]): + data: T + success: bool \ No newline at end of file diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py index 17e175b2..c9f2cb29 100644 --- a/src/utils/event_manager.py +++ b/src/utils/event_manager.py @@ -6,6 +6,7 @@ from threading import Lock from loguru import logger +from pydantic import BaseModel from sqlalchemy.orm.exc import StaleDataError from subliminal import Episode, Movie @@ -15,6 +16,14 @@ from program.media.item import Season, Show from program.types import Event +class EventUpdate(BaseModel): + item_id: str + imdb_id: str + title: str + type: str + emitted_by: str + run_at: str + last_state: str class EventManager: """ @@ -324,7 +333,7 @@ def add_item(self, item, service="Manual"): """ self.add_event(Event(service, item)) - def get_event_updates(self): + def get_event_updates(self) -> dict[str, list[EventUpdate]]: """ Returns a formatted list of event updates. From 84e920626a5ea8455ff3432ce0db95b83042fafc Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Thu, 3 Oct 2024 13:42:15 +0200 Subject: [PATCH 10/20] fix: bad pydantic types causing serialization error --- src/controllers/items.py | 2 +- src/utils/event_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/controllers/items.py b/src/controllers/items.py index 500abb50..4ef748fe 100644 --- a/src/controllers/items.py +++ b/src/controllers/items.py @@ -320,7 +320,7 @@ async def retry_items(request: Request, ids: str) -> RetryResponse: class RemoveResponse(BaseModel): success: bool message: str - ids: list[str] + ids: list[int] @router.delete( diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py index c9f2cb29..bc21172c 100644 --- a/src/utils/event_manager.py +++ b/src/utils/event_manager.py @@ -17,7 +17,7 @@ from program.types import Event class EventUpdate(BaseModel): - item_id: str + item_id: int imdb_id: str title: str type: str From fb9b21bc673b7e86db14e75235582897d5274b92 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Thu, 3 Oct 2024 13:59:07 +0200 Subject: [PATCH 11/20] fix: add some model validation where needed --- src/program/downloaders/realdebrid.py | 18 ++++++++++-------- src/utils/event_manager.py | 3 ++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/program/downloaders/realdebrid.py b/src/program/downloaders/realdebrid.py index 4ed9a0b9..d7a63e5a 100644 --- a/src/program/downloaders/realdebrid.py +++ b/src/program/downloaders/realdebrid.py @@ -1,10 +1,10 @@ from datetime import datetime from enum import Enum -from typing import Union +from typing import Optional, Union from loguru import logger from program.settings.manager import settings_manager as settings -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from requests import ConnectTimeout from utils import request from utils.ratelimiter import RateLimiter @@ -36,9 +36,11 @@ class RDTorrent(BaseModel): status: RDTorrentStatus added: datetime links: list[str] - ended: Union[datetime, None] - speed: Union[int, None] - seeders: Union[int, None] + ended: Optional[datetime] = None + speed: Optional[int] = None + seeders: Optional[int] = None + +rd_torrent_list = TypeAdapter(list[RDTorrent]) class RealDebridDownloader: def __init__(self): @@ -46,7 +48,7 @@ def __init__(self): self.settings = settings.settings.downloaders.real_debrid self.initialized = self.validate() if self.initialized: - self.existing_hashes = [torrent["hash"] for torrent in get_torrents(1000)] + self.existing_hashes = [torrent.hash for torrent in get_torrents(1000)] self.file_finder = FileFinder("filename", "filesize") def validate(self) -> bool: @@ -135,7 +137,7 @@ def get_torrent_names(self, id: str) -> dict: return (info["filename"], info["original_filename"]) def delete_torrent_with_infohash(self, infohash: str): - id = next(torrent["id"] for torrent in get_torrents(1000) if torrent["hash"] == infohash) + id = next(torrent.id for torrent in get_torrents(1000) if torrent.hash == infohash) if id: delete_torrent(id) @@ -203,7 +205,7 @@ def torrent_info(id: str) -> dict: def get_torrents(limit: int) -> list[RDTorrent]: try: - torrents = get(f"torrents?limit={str(limit)}") + torrents = rd_torrent_list.validate_python(get(f"torrents?limit={str(limit)}")) except: logger.warning("Failed to get torrents.") torrents = [] diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py index bc21172c..aaf4ed04 100644 --- a/src/utils/event_manager.py +++ b/src/utils/event_manager.py @@ -344,6 +344,7 @@ def get_event_updates(self) -> dict[str, list[EventUpdate]]: event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"] return { event_type.lower(): [ + EventUpdate.model_validate( { "item_id": event.item._id, "imdb_id": event.item.imdb_id, @@ -352,7 +353,7 @@ def get_event_updates(self) -> dict[str, list[EventUpdate]]: "emitted_by": event.emitted_by if isinstance(event.emitted_by, str) else event.emitted_by.__name__, "run_at": event.run_at.isoformat(), "last_state": event.item.last_state.name if event.item.last_state else "N/A" - } + }) for event in events if event.emitted_by == event_type ] for event_type in event_types From 9384c429cc4a756ecf83797867803c093a4aa99b Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Thu, 3 Oct 2024 19:22:37 +0200 Subject: [PATCH 12/20] feat: add mypy to dev dependencies for static type checking --- poetry.lock | 61 ++++++++++++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 771d9a28..54c7b75b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "alembic" @@ -1398,6 +1398,63 @@ docs = ["IPython", "bump2version", "furo", "sphinx", "sphinx-argparse", "towncri lint = ["black", "check-manifest", "flake8", "isort", "mypy"] test = ["Cython", "greenlet", "ipython", "packaging", "pytest", "pytest-cov", "pytest-textual-snapshot", "setuptools", "textual (>=0.43,!=0.65.2,!=0.66)"] +[[package]] +name = "mypy" +version = "1.11.2" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"}, + {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"}, + {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"}, + {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"}, + {file = "mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"}, + {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"}, + {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"}, + {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"}, + {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"}, + {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"}, + {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"}, + {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"}, + {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"}, + {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"}, + {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"}, + {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"}, + {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"}, + {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -3261,4 +3318,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "42f7ca2f6421e5c1b7e3d6727f6e595eefb0efc9f167072d793ec72d3dfb8c97" +content-hash = "917ecbb3c56b9b59b74603b76d4d9c61eefe69419966b21f985b0ba585a3c919" diff --git a/pyproject.toml b/pyproject.toml index 74cfc826..018c3070 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ codecov = "^2.1.13" httpx = "^0.27.0" memray = "^1.13.4" testcontainers = "^4.8.0" +mypy = "^1.11.2" [tool.poetry.group.test] optional = true From 837ed19fe0ecd2907a332a2b37e81e50743c9382 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Thu, 3 Oct 2024 20:07:43 +0200 Subject: [PATCH 13/20] fix: wrong type in realdebrid --- src/program/downloaders/realdebrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/program/downloaders/realdebrid.py b/src/program/downloaders/realdebrid.py index d7a63e5a..58df3ad6 100644 --- a/src/program/downloaders/realdebrid.py +++ b/src/program/downloaders/realdebrid.py @@ -132,7 +132,7 @@ def all_files_valid(file_dict: dict) -> bool: break return cached_containers - def get_torrent_names(self, id: str) -> dict: + def get_torrent_names(self, id: str) -> tuple[str,str]: info = torrent_info(id) return (info["filename"], info["original_filename"]) From 8a51e6f2d86cfd0a9e63c04537a2106f3f530b55 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Thu, 3 Oct 2024 21:07:15 +0200 Subject: [PATCH 14/20] feat: add some options for easier querying of items --- src/controllers/items.py | 71 ++++++++++++++++++++++++++------------- src/program/media/item.py | 7 ++-- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/src/controllers/items.py b/src/controllers/items.py index 4ef748fe..6db6f440 100644 --- a/src/controllers/items.py +++ b/src/controllers/items.py @@ -1,6 +1,6 @@ import asyncio from datetime import datetime -from typing import Optional +from typing import Literal, Optional import Levenshtein from fastapi import APIRouter, HTTPException, Request @@ -26,7 +26,7 @@ from program.types import Event from pydantic import BaseModel from RTN import Torrent -from sqlalchemy import func, select +from sqlalchemy import and_, func, or_, select from sqlalchemy.exc import NoResultFound from utils.logger import logger @@ -77,10 +77,13 @@ async def get_items( limit: Optional[int] = 50, page: Optional[int] = 1, type: Optional[str] = None, - state: Optional[str] = None, - sort: Optional[str] = "date_desc", + states: Optional[str] = None, + sort: Optional[ + Literal["date_desc", "date_asc", "title_asc", "title_desc"] + ] = "date_desc", search: Optional[str] = None, extended: Optional[bool] = False, + is_anime: Optional[bool] = False, ) -> ItemsResponse: if page < 1: raise HTTPException(status_code=400, detail="Page number must be 1 or greater.") @@ -100,34 +103,51 @@ async def get_items( | (func.lower(MediaItem.imdb_id).like(f"%{search_lower}%")) ) - if state: - filter_lower = state.lower() - filter_state = None - for state_enum in States: - if Levenshtein.ratio(filter_lower, state_enum.name.lower()) <= 0.82: - filter_state = state_enum - break - if filter_state: - query = query.where(MediaItem.last_state == filter_state) + if states: + states = states.split(",") + filter_states = [] + for state in states: + filter_lower = state.lower() + for state_enum in States: + if Levenshtein.ratio(filter_lower, state_enum.name.lower()) >= 0.82: + filter_states.append(state_enum) + break + if len(filter_states) == len(states): + query = query.where(MediaItem.last_state in filter_states) else: valid_states = [state_enum.name for state_enum in States] raise HTTPException( status_code=400, - detail=f"Invalid filter state: {state}. Valid states are: {valid_states}", + detail=f"Invalid filter states: {states}. Valid states are: {valid_states}", ) if type: if "," in type: types = type.split(",") for type in types: - if type not in ["movie", "show", "season", "episode"]: + if type not in ["movie", "show", "season", "episode", "anime"]: raise HTTPException( status_code=400, - detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode']", + detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode', 'anime']", ) else: types = [type] - query = query.where(MediaItem.type.in_(types)) + if "anime" in types: + types = [type for type in types if type != "anime"] + query = query.where( + or_( + and_( + MediaItem.type.in_(["movie", "show"]), + MediaItem.is_anime == True, + ), + MediaItem.type.in_(types), + ) + ) + else: + query = query.where(MediaItem.type.in_(types)) + + if is_anime: + query = query.where(MediaItem.is_anime is True) if sort and not search: sort_lower = sort.lower() @@ -219,17 +239,20 @@ class ItemResponse(BaseModel): description="Fetch a single media item by ID", operation_id="get_item", ) -async def get_item(request: Request, id: int) -> ItemResponse: +async def get_item( + _: Request, id: int, use_tmdb_id: Optional[bool] = False +) -> dict: with db.Session() as session: try: - item = ( - session.execute(select(MediaItem).where(MediaItem._id == id)) - .unique() - .scalar_one() - ) + query = select(MediaItem) + if use_tmdb_id: + query = query.where(MediaItem.tmdb_id == str(id)) + else: + query = query.where(MediaItem._id == id) + item = session.execute(query).unique().scalar_one() except NoResultFound: raise HTTPException(status_code=404, detail="Item not found") - return {"success": True, "item": item.to_extended_dict()} + return item.to_extended_dict(with_streams=False) class ItemsByImdbResponse(BaseModel): diff --git a/src/program/media/item.py b/src/program/media/item.py index 82dd42ea..a83d19e0 100644 --- a/src/program/media/item.py +++ b/src/program/media/item.py @@ -222,7 +222,7 @@ def to_dict(self): "scraped_times": self.scraped_times, } - def to_extended_dict(self, abbreviated_children=False): + def to_extended_dict(self, abbreviated_children=False, with_streams=True): """Convert item to extended dictionary (API response)""" dict = self.to_dict() match self: @@ -244,8 +244,9 @@ def to_extended_dict(self, abbreviated_children=False): dict["active_stream"] = ( self.active_stream if hasattr(self, "active_stream") else None ) - dict["streams"] = getattr(self, "streams", []) - dict["blacklisted_streams"] = getattr(self, "blacklisted_streams", []) + if with_streams: + dict["streams"] = getattr(self, "streams", []) + dict["blacklisted_streams"] = getattr(self, "blacklisted_streams", []) dict["number"] = self.number if hasattr(self, "number") else None dict["symlinked"] = self.symlinked if hasattr(self, "symlinked") else None dict["symlinked_at"] = ( From aba502dd57fda24f46c6cbd2211b4465d84f66ec Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Fri, 4 Oct 2024 11:03:46 +0200 Subject: [PATCH 15/20] fix: pass with_streams argument in to_extended_dict to chidren --- src/program/media/item.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/program/media/item.py b/src/program/media/item.py index a83d19e0..96766a2e 100644 --- a/src/program/media/item.py +++ b/src/program/media/item.py @@ -228,25 +228,25 @@ def to_extended_dict(self, abbreviated_children=False, with_streams=True): match self: case Show(): dict["seasons"] = ( - [season.to_extended_dict() for season in self.seasons] + [season.to_extended_dict(with_streams=with_streams) for season in self.seasons] if not abbreviated_children else self.represent_children ) case Season(): dict["episodes"] = ( - [episode.to_extended_dict() for episode in self.episodes] + [episode.to_extended_dict(with_streams=with_streams) for episode in self.episodes] if not abbreviated_children else self.represent_children ) dict["language"] = self.language if hasattr(self, "language") else None dict["country"] = self.country if hasattr(self, "country") else None dict["network"] = self.network if hasattr(self, "network") else None - dict["active_stream"] = ( - self.active_stream if hasattr(self, "active_stream") else None - ) if with_streams: dict["streams"] = getattr(self, "streams", []) dict["blacklisted_streams"] = getattr(self, "blacklisted_streams", []) + dict["active_stream"] = ( + self.active_stream if hasattr(self, "active_stream") else None + ) dict["number"] = self.number if hasattr(self, "number") else None dict["symlinked"] = self.symlinked if hasattr(self, "symlinked") else None dict["symlinked_at"] = ( From 87109f0fc8f51404fa437904d3620650925e3b68 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Fri, 4 Oct 2024 19:55:31 +0200 Subject: [PATCH 16/20] feat: remove the old json response format from services and stats endpoints --- src/controllers/default.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/controllers/default.py b/src/controllers/default.py index 5654703b..ff8d6ee4 100644 --- a/src/controllers/default.py +++ b/src/controllers/default.py @@ -17,9 +17,11 @@ responses={404: {"description": "Not found"}}, ) + class RootResponse(MessageAndSuccessResponse): version: str + @router.get("/", operation_id="root") async def root() -> RootResponse: return { @@ -28,6 +30,7 @@ async def root() -> RootResponse: "version": settings_manager.settings.version, } + @router.get("/health", operation_id="health") async def health(request: Request) -> MessageAndSuccessResponse: return { @@ -35,6 +38,7 @@ async def health(request: Request) -> MessageAndSuccessResponse: "message": request.app.program.initialized, } + class RDUser(BaseModel): id: int username: str @@ -45,6 +49,7 @@ class RDUser(BaseModel): type: Literal["free", "premium"] premium: int = Field(description="Premium subscription left in seconds") + @router.get("/rd", operation_id="rd") async def get_rd_user() -> DataAndSuccessResponse[RDUser]: api_key = settings_manager.settings.downloaders.real_debrid.api_key @@ -83,7 +88,7 @@ async def get_torbox_user(): @router.get("/services", operation_id="services") -async def get_services(request: Request) -> DataAndSuccessResponse[dict]: +async def get_services(request: Request) -> dict[str, bool]: data = {} if hasattr(request.app.program, "services"): for service in request.app.program.all_services.values(): @@ -92,11 +97,13 @@ async def get_services(request: Request) -> DataAndSuccessResponse[dict]: continue for sub_service in service.services.values(): data[sub_service.key] = sub_service.initialized - return {"success": True, "data": data} + return data + class TraktOAuthInitiateResponse(BaseModel): auth_url: str + @router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: trakt = request.app.program.services.get(TraktContent) @@ -107,7 +114,9 @@ async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: @router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") -async def trakt_oauth_callback(code: str, request: Request) -> MessageAndSuccessResponse: +async def trakt_oauth_callback( + code: str, request: Request +) -> MessageAndSuccessResponse: trakt = request.app.program.services.get(TraktContent) if trakt is None: raise HTTPException(status_code=404, detail="Trakt service not found") @@ -117,6 +126,7 @@ async def trakt_oauth_callback(code: str, request: Request) -> MessageAndSuccess else: raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") + class StatsResponse(BaseModel): total_items: int total_movies: int @@ -125,11 +135,14 @@ class StatsResponse(BaseModel): total_episodes: int total_symlinks: int incomplete_items: int - incomplete_retries: dict[str, int] = Field(description="Media item log string: number of retries") + incomplete_retries: dict[str, int] = Field( + description="Media item log string: number of retries" + ) states: dict[States, int] + @router.get("/stats", operation_id="stats") -async def get_stats(_: Request) -> DataAndSuccessResponse[StatsResponse]: +async def get_stats(_: Request) -> StatsResponse: payload = {} with db.Session() as session: movies_symlinks = session.execute( @@ -180,14 +193,14 @@ async def get_stats(_: Request) -> DataAndSuccessResponse[StatsResponse]: payload["incomplete_items"] = len(_incomplete_items) payload["incomplete_retries"] = incomplete_retries payload["states"] = states - - return {"success": True, "data": payload} + return payload class LogsResponse(BaseModel): success: bool logs: str + @router.get("/logs", operation_id="logs") async def get_logs() -> LogsResponse: log_file_path = None @@ -209,7 +222,9 @@ async def get_logs() -> LogsResponse: @router.get("/events", operation_id="events") -async def get_events(request: Request) -> DataAndSuccessResponse[dict[str, list[EventUpdate]]]: +async def get_events( + request: Request, +) -> DataAndSuccessResponse[dict[str, list[EventUpdate]]]: return {"success": True, "data": request.app.program.em.get_event_updates()} From bbaa73d49d7b43c8a63dfe2e5e8a970bf4fa9bb4 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Fri, 4 Oct 2024 20:23:50 +0200 Subject: [PATCH 17/20] feat: migrate the settings api to the new response types --- src/controllers/models/shared.py | 3 ++ src/controllers/settings.py | 55 +++++++------------------------- 2 files changed, 15 insertions(+), 43 deletions(-) diff --git a/src/controllers/models/shared.py b/src/controllers/models/shared.py index 1550167b..fb4ab778 100644 --- a/src/controllers/models/shared.py +++ b/src/controllers/models/shared.py @@ -6,6 +6,9 @@ class MessageAndSuccessResponse(BaseModel): message: str success: bool +class MessageResponse(BaseModel): + message: str + T = TypeVar('T', bound=BaseModel) class DataAndSuccessResponse(BaseModel, Generic[T]): diff --git a/src/controllers/settings.py b/src/controllers/settings.py index ea07e44c..a69e5f55 100644 --- a/src/controllers/settings.py +++ b/src/controllers/settings.py @@ -1,6 +1,7 @@ from copy import copy from typing import Any, Dict, List +from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException from program.settings.manager import settings_manager from program.settings.models import AppModel @@ -26,48 +27,29 @@ async def get_settings_schema() -> dict[str, Any]: """ return settings_manager.settings.model_json_schema() -class LoadSettingsResponse(BaseModel): - success: bool - message: str - @router.get("/load", operation_id="load_settings") -async def load_settings() -> LoadSettingsResponse: +async def load_settings() -> MessageResponse: settings_manager.load() return { - "success": True, "message": "Settings loaded!", } -class SaveSettingsResponse(BaseModel): - success: bool - message: str - @router.post("/save", operation_id="save_settings") -async def save_settings() -> SaveSettingsResponse: +async def save_settings() -> MessageResponse: settings_manager.save() return { - "success": True, "message": "Settings saved!", } -class GetAllSettingsResponse(BaseModel): - success: bool - data: AppModel @router.get("/get/all", operation_id="get_all_settings") -async def get_all_settings() -> GetAllSettingsResponse: - return { - "success": True, - "data": copy(settings_manager.settings), - } +async def get_all_settings() -> dict[str, Any]: + return copy(settings_manager.settings) -class GetSettingsResponse(BaseModel): - success: bool - data: dict[str, Any] @router.get("/get/{paths}", operation_id="get_settings") -async def get_settings(paths: str) -> GetSettingsResponse: - current_settings = settings_manager.settings.dict() +async def get_settings(paths: str) -> dict[str, Any]: + current_settings = settings_manager.settings.model_dump() data = {} for path in paths.split(","): keys = path.split(".") @@ -79,18 +61,11 @@ async def get_settings(paths: str) -> GetSettingsResponse: current_obj = current_obj[k] data[path] = current_obj + return data - return { - "success": True, - "data": data, - } - -class SetAllSettingsResponse(BaseModel): - success: bool - message: str @router.post("/set/all", operation_id="set_all_settings") -async def set_all_settings(new_settings: Dict[str, Any]) -> SetAllSettingsResponse: +async def set_all_settings(new_settings: Dict[str, Any]) -> MessageResponse: current_settings = settings_manager.settings.model_dump() def update_settings(current_obj, new_obj): @@ -111,16 +86,11 @@ def update_settings(current_obj, new_obj): raise HTTPException(status_code=400, detail=str(e)) return { - "success": True, "message": "All settings updated successfully!", } -class SetSettingsResponse(BaseModel): - success: bool - message: str - -@router.post("/set") -async def set_settings(settings: List[SetSettings]) -> SetSettingsResponse: +@router.post("/set", operation_id="set_settings") +async def set_settings(settings: List[SetSettings]) -> MessageResponse: current_settings = settings_manager.settings.model_dump() for setting in settings: @@ -156,5 +126,4 @@ async def set_settings(settings: List[SetSettings]) -> SetSettingsResponse: detail=f"Failed to update settings: {str(e)}", ) - return {"success": True, "message": "Settings updated successfully."} - + return {"message": "Settings updated successfully."} From 4765dd7b7a3faea88159330a355b7cd8778b1912 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Fri, 4 Oct 2024 20:42:01 +0200 Subject: [PATCH 18/20] feat: add type annotation to get_all_settings --- src/controllers/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/controllers/settings.py b/src/controllers/settings.py index a69e5f55..d02375c4 100644 --- a/src/controllers/settings.py +++ b/src/controllers/settings.py @@ -43,7 +43,7 @@ async def save_settings() -> MessageResponse: @router.get("/get/all", operation_id="get_all_settings") -async def get_all_settings() -> dict[str, Any]: +async def get_all_settings() -> AppModel: return copy(settings_manager.settings) From 96f5c674a541974ddff051aaf36bee2497443a43 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Fri, 4 Oct 2024 23:15:54 +0200 Subject: [PATCH 19/20] feat: migrate the rest of the APIs to the new response schema --- src/controllers/default.py | 31 +++++++++------------- src/controllers/items.py | 44 ++++++++------------------------ src/controllers/models/shared.py | 15 ++--------- src/controllers/scrape.py | 9 ++----- 4 files changed, 26 insertions(+), 73 deletions(-) diff --git a/src/controllers/default.py b/src/controllers/default.py index ff8d6ee4..7a49abbf 100644 --- a/src/controllers/default.py +++ b/src/controllers/default.py @@ -1,7 +1,7 @@ from typing import Literal import requests -from controllers.models.shared import DataAndSuccessResponse, MessageAndSuccessResponse +from controllers.models.shared import DataAndSuccessResponse, MessageAndSuccessResponse, MessageResponse from fastapi import APIRouter, HTTPException, Request from loguru import logger from program.content.trakt import TraktContent @@ -51,7 +51,7 @@ class RDUser(BaseModel): @router.get("/rd", operation_id="rd") -async def get_rd_user() -> DataAndSuccessResponse[RDUser]: +async def get_rd_user() -> RDUser: api_key = settings_manager.settings.downloaders.real_debrid.api_key headers = {"Authorization": f"Bearer {api_key}"} @@ -71,11 +71,7 @@ async def get_rd_user() -> DataAndSuccessResponse[RDUser]: if response.status_code != 200: return {"success": False, "message": response.json()} - return { - "success": True, - "data": response.json(), - } - + return response.json() @router.get("/torbox", operation_id="torbox") async def get_torbox_user(): @@ -114,15 +110,13 @@ async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: @router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") -async def trakt_oauth_callback( - code: str, request: Request -) -> MessageAndSuccessResponse: +async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: trakt = request.app.program.services.get(TraktContent) if trakt is None: raise HTTPException(status_code=404, detail="Trakt service not found") success = trakt.handle_oauth_callback(code) if success: - return {"success": True, "message": "OAuth token obtained successfully"} + return {"message": "OAuth token obtained successfully"} else: raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") @@ -197,12 +191,11 @@ async def get_stats(_: Request) -> StatsResponse: class LogsResponse(BaseModel): - success: bool logs: str @router.get("/logs", operation_id="logs") -async def get_logs() -> LogsResponse: +async def get_logs() -> str: log_file_path = None for handler in logger._core.handlers.values(): if ".log" in handler._name: @@ -215,21 +208,21 @@ async def get_logs() -> LogsResponse: try: with open(log_file_path, "r") as log_file: log_contents = log_file.read() - return {"success": True, "logs": log_contents} + return {"logs": log_contents} except Exception as e: logger.error(f"Failed to read log file: {e}") - return {"success": False, "message": "Failed to read log file"} + raise HTTPException(status_code=500, detail="Failed to read log file") @router.get("/events", operation_id="events") async def get_events( request: Request, -) -> DataAndSuccessResponse[dict[str, list[EventUpdate]]]: - return {"success": True, "data": request.app.program.em.get_event_updates()} +) -> dict[str, list[EventUpdate]]: + return request.app.program.em.get_event_updates() @router.get("/mount", operation_id="mount") -async def get_rclone_files() -> DataAndSuccessResponse[dict[str, str]]: +async def get_rclone_files() -> dict[str, str]: """Get all files in the rclone mount.""" import os @@ -245,4 +238,4 @@ def scan_dir(path): scan_dir(entry.path) scan_dir(rclone_dir) # dict of `filename: filepath`` - return {"success": True, "data": file_map} + return file_map diff --git a/src/controllers/items.py b/src/controllers/items.py index 6db6f440..b86cca6c 100644 --- a/src/controllers/items.py +++ b/src/controllers/items.py @@ -28,6 +28,7 @@ from RTN import Torrent from sqlalchemy import and_, func, or_, select from sqlalchemy.exc import NoResultFound +from src.controllers.models.shared import MessageResponse from utils.logger import logger router = APIRouter( @@ -191,18 +192,13 @@ async def get_items( } -class AddItemsResponse(BaseModel): - success: bool - message: str - - @router.post( "/add", summary="Add Media Items", description="Add media items with bases on imdb IDs", operation_id="add_items", ) -async def add_items(request: Request, imdb_ids: str = None) -> AddItemsResponse: +async def add_items(request: Request, imdb_ids: str = None) -> MessageResponse: if not imdb_ids: raise HTTPException(status_code=400, detail="No IMDb ID(s) provided") @@ -225,12 +221,7 @@ async def add_items(request: Request, imdb_ids: str = None) -> AddItemsResponse: ) request.app.program.em.add_item(item) - return {"success": True, "message": f"Added {len(valid_ids)} item(s) to the queue"} - - -class ItemResponse(BaseModel): - success: bool - item: dict + return {"message": f"Added {len(valid_ids)} item(s) to the queue"} @router.get( @@ -239,9 +230,7 @@ class ItemResponse(BaseModel): description="Fetch a single media item by ID", operation_id="get_item", ) -async def get_item( - _: Request, id: int, use_tmdb_id: Optional[bool] = False -) -> dict: +async def get_item(_: Request, id: int, use_tmdb_id: Optional[bool] = False) -> dict: with db.Session() as session: try: query = select(MediaItem) @@ -255,18 +244,13 @@ async def get_item( return item.to_extended_dict(with_streams=False) -class ItemsByImdbResponse(BaseModel): - success: bool - items: list[dict] - - @router.get( "/{imdb_ids}", summary="Retrieve Media Items By IMDb IDs", description="Fetch media items by IMDb IDs", operation_id="get_items_by_imdb_ids", ) -async def get_items_by_imdb_ids(request: Request, imdb_ids: str) -> ItemsByImdbResponse: +async def get_items_by_imdb_ids(request: Request, imdb_ids: str) -> list[dict]: ids = imdb_ids.split(",") with db.Session() as session: items = [] @@ -278,15 +262,13 @@ async def get_items_by_imdb_ids(request: Request, imdb_ids: str) -> ItemsByImdbR ) if item: items.append(item) - return {"success": True, "items": [item.to_extended_dict() for item in items]} + return [item.to_extended_dict() for item in items] class ResetResponse(BaseModel): - success: bool message: str ids: list[str] - @router.post( "/reset", summary="Reset Media Items", @@ -309,15 +291,13 @@ async def reset_items(request: Request, ids: str) -> ResetResponse: continue except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e - return {"success": True, "message": f"Reset items with id {ids}", "ids": ids} + return {"message": f"Reset items with id {ids}", "ids": ids} class RetryResponse(BaseModel): - success: bool message: str ids: list[str] - @router.post( "/retry", summary="Retry Media Items", @@ -337,15 +317,13 @@ async def retry_items(request: Request, ids: str) -> RetryResponse: except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Retried items with ids {ids}", "ids": ids} + return {"message": f"Retried items with ids {ids}", "ids": ids} class RemoveResponse(BaseModel): - success: bool message: str ids: list[int] - @router.delete( "/remove", summary="Remove Media Items", @@ -375,16 +353,14 @@ async def remove_item(request: Request, ids: str) -> RemoveResponse: except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Removed items with ids {ids}", "ids": ids} + return {"message": f"Removed items with ids {ids}", "ids": ids} class SetTorrentRDResponse(BaseModel): - success: bool message: str item_id: int torrent_id: str - @router.post( "/{id}/set_torrent_rd_magnet", name="Set torrent RD magnet", @@ -428,7 +404,7 @@ def create_stream(hash, torrent_info): "/{id}/set_torrent_rd", description="Set a torrent for a media item using RD torrent ID.", ) -def set_torrent_rd(request: Request, id: int, torrent_id: str): +def set_torrent_rd(request: Request, id: int, torrent_id: str) -> SetTorrentRDResponse: downloader: Downloader = request.app.program.services.get(Downloader) with db.Session() as session: item: MediaItem = ( diff --git a/src/controllers/models/shared.py b/src/controllers/models/shared.py index fb4ab778..53b5fefc 100644 --- a/src/controllers/models/shared.py +++ b/src/controllers/models/shared.py @@ -1,16 +1,5 @@ -from typing import Generic, TypeVar from pydantic import BaseModel -class MessageAndSuccessResponse(BaseModel): - message: str - success: bool - -class MessageResponse(BaseModel): - message: str - -T = TypeVar('T', bound=BaseModel) - -class DataAndSuccessResponse(BaseModel, Generic[T]): - data: T - success: bool \ No newline at end of file +class MessageResponse(BaseModel): + message: str \ No newline at end of file diff --git a/src/controllers/scrape.py b/src/controllers/scrape.py index 442c7905..e975705b 100644 --- a/src/controllers/scrape.py +++ b/src/controllers/scrape.py @@ -18,11 +18,6 @@ class ScrapedTorrent(BaseModel): infohash: str -class ScrapeResponse(BaseModel): - success: bool - data: list[ScrapedTorrent] - - @router.get( "", summary="Scrape Media Item", @@ -31,7 +26,7 @@ class ScrapeResponse(BaseModel): ) async def scrape( request: Request, imdb_id: str, season: int = None, episode: int = None -) -> ScrapeResponse: +) -> list[ScrapedTorrent]: """ Scrape media item based on IMDb ID. @@ -104,7 +99,7 @@ async def scrape( except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - return {"success": True, "data": data} + return data @router.get( "/rd", From f84af7f3f4f77373826a800871a9d8827d34f254 Mon Sep 17 00:00:00 2001 From: Filip Trplan Date: Fri, 4 Oct 2024 23:23:42 +0200 Subject: [PATCH 20/20] fix: remove old imports --- src/controllers/default.py | 9 ++++----- src/controllers/items.py | 6 +++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/controllers/default.py b/src/controllers/default.py index 7a49abbf..52051055 100644 --- a/src/controllers/default.py +++ b/src/controllers/default.py @@ -1,7 +1,7 @@ from typing import Literal import requests -from controllers.models.shared import DataAndSuccessResponse, MessageAndSuccessResponse, MessageResponse +from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException, Request from loguru import logger from program.content.trakt import TraktContent @@ -18,23 +18,21 @@ ) -class RootResponse(MessageAndSuccessResponse): +class RootResponse(MessageResponse): version: str @router.get("/", operation_id="root") async def root() -> RootResponse: return { - "success": True, "message": "Riven is running!", "version": settings_manager.settings.version, } @router.get("/health", operation_id="health") -async def health(request: Request) -> MessageAndSuccessResponse: +async def health(request: Request) -> MessageResponse: return { - "success": True, "message": request.app.program.initialized, } @@ -73,6 +71,7 @@ async def get_rd_user() -> RDUser: return response.json() + @router.get("/torbox", operation_id="torbox") async def get_torbox_user(): api_key = settings_manager.settings.downloaders.torbox.api_key diff --git a/src/controllers/items.py b/src/controllers/items.py index b86cca6c..f4442fd9 100644 --- a/src/controllers/items.py +++ b/src/controllers/items.py @@ -3,6 +3,7 @@ from typing import Literal, Optional import Levenshtein +from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException, Request from program.content import Overseerr from program.db.db import db @@ -28,7 +29,6 @@ from RTN import Torrent from sqlalchemy import and_, func, or_, select from sqlalchemy.exc import NoResultFound -from src.controllers.models.shared import MessageResponse from utils.logger import logger router = APIRouter( @@ -269,6 +269,7 @@ class ResetResponse(BaseModel): message: str ids: list[str] + @router.post( "/reset", summary="Reset Media Items", @@ -298,6 +299,7 @@ class RetryResponse(BaseModel): message: str ids: list[str] + @router.post( "/retry", summary="Retry Media Items", @@ -324,6 +326,7 @@ class RemoveResponse(BaseModel): message: str ids: list[int] + @router.delete( "/remove", summary="Remove Media Items", @@ -361,6 +364,7 @@ class SetTorrentRDResponse(BaseModel): item_id: int torrent_id: str + @router.post( "/{id}/set_torrent_rd_magnet", name="Set torrent RD magnet",