diff --git a/poetry.lock b/poetry.lock index 771d9a28..54c7b75b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "alembic" @@ -1398,6 +1398,63 @@ docs = ["IPython", "bump2version", "furo", "sphinx", "sphinx-argparse", "towncri lint = ["black", "check-manifest", "flake8", "isort", "mypy"] test = ["Cython", "greenlet", "ipython", "packaging", "pytest", "pytest-cov", "pytest-textual-snapshot", "setuptools", "textual (>=0.43,!=0.65.2,!=0.66)"] +[[package]] +name = "mypy" +version = "1.11.2" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"}, + {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"}, + {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"}, + {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"}, + {file = "mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"}, + {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"}, + {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"}, + {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"}, + {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"}, + {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"}, + {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"}, + {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"}, + {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"}, + {file 
= "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"}, + {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"}, + {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"}, + {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"}, + {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"}, + {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." 
+optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -3261,4 +3318,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "42f7ca2f6421e5c1b7e3d6727f6e595eefb0efc9f167072d793ec72d3dfb8c97" +content-hash = "917ecbb3c56b9b59b74603b76d4d9c61eefe69419966b21f985b0ba585a3c919" diff --git a/pyproject.toml b/pyproject.toml index cc590509..018c3070 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ codecov = "^2.1.13" httpx = "^0.27.0" memray = "^1.13.4" testcontainers = "^4.8.0" +mypy = "^1.11.2" [tool.poetry.group.test] optional = true @@ -91,7 +92,8 @@ ignore = [ "S101", # ruff: Ignore assert warnings on tests "RET505", # "RET503", # ruff: Ignore required explicit returns (is this desired?) 
- "SLF001" # private member accessing from pickle + "SLF001", # private member accessing from pickle + "B904" # ruff: ignore raising exceptions from except for the API ] extend-select = [ "I", # isort diff --git a/src/controllers/default.py b/src/controllers/default.py index 94d96d8c..52051055 100644 --- a/src/controllers/default.py +++ b/src/controllers/default.py @@ -1,57 +1,75 @@ +from typing import Literal + import requests +from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException, Request from loguru import logger -from sqlalchemy import func, select - from program.content.trakt import TraktContent from program.db.db import db from program.media.item import Episode, MediaItem, Movie, Season, Show from program.media.state import States from program.settings.manager import settings_manager +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from utils.event_manager import EventUpdate router = APIRouter( responses={404: {"description": "Not found"}}, ) +class RootResponse(MessageResponse): + version: str + + @router.get("/", operation_id="root") -async def root(): +async def root() -> RootResponse: return { - "success": True, "message": "Riven is running!", "version": settings_manager.settings.version, } @router.get("/health", operation_id="health") -async def health(request: Request): +async def health(request: Request) -> MessageResponse: return { - "success": True, "message": request.app.program.initialized, } +class RDUser(BaseModel): + id: int + username: str + email: str + points: int = Field(description="User's RD points") + locale: str + avatar: str = Field(description="URL to the user's avatar") + type: Literal["free", "premium"] + premium: int = Field(description="Premium subscription left in seconds") + + @router.get("/rd", operation_id="rd") -async def get_rd_user(): +async def get_rd_user() -> RDUser: api_key = settings_manager.settings.downloaders.real_debrid.api_key headers = 
{"Authorization": f"Bearer {api_key}"} - proxy = settings_manager.settings.downloaders.real_debrid.proxy_url if settings_manager.settings.downloaders.real_debrid.proxy_enabled else None + proxy = ( + settings_manager.settings.downloaders.real_debrid.proxy_url + if settings_manager.settings.downloaders.real_debrid.proxy_enabled + else None + ) response = requests.get( "https://api.real-debrid.com/rest/1.0/user", headers=headers, proxies=proxy if proxy else None, - timeout=10 + timeout=10, ) if response.status_code != 200: return {"success": False, "message": response.json()} - return { - "success": True, - "data": response.json(), - } + return response.json() @router.get("/torbox", operation_id="torbox") @@ -65,7 +83,7 @@ async def get_torbox_user(): @router.get("/services", operation_id="services") -async def get_services(request: Request): +async def get_services(request: Request) -> dict[str, bool]: data = {} if hasattr(request.app.program, "services"): for service in request.app.program.all_services.values(): @@ -74,11 +92,15 @@ async def get_services(request: Request): continue for sub_service in service.services.values(): data[sub_service.key] = sub_service.initialized - return {"success": True, "data": data} + return data + + +class TraktOAuthInitiateResponse(BaseModel): + auth_url: str @router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") -async def initiate_trakt_oauth(request: Request): +async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: trakt = request.app.program.services.get(TraktContent) if trakt is None: raise HTTPException(status_code=404, detail="Trakt service not found") @@ -87,24 +109,41 @@ async def initiate_trakt_oauth(request: Request): @router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") -async def trakt_oauth_callback(code: str, request: Request): +async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: trakt = 
request.app.program.services.get(TraktContent) if trakt is None: raise HTTPException(status_code=404, detail="Trakt service not found") success = trakt.handle_oauth_callback(code) if success: - return {"success": True, "message": "OAuth token obtained successfully"} + return {"message": "OAuth token obtained successfully"} else: raise HTTPException(status_code=400, detail="Failed to obtain OAuth token") +class StatsResponse(BaseModel): + total_items: int + total_movies: int + total_shows: int + total_seasons: int + total_episodes: int + total_symlinks: int + incomplete_items: int + incomplete_retries: dict[str, int] = Field( + description="Media item log string: number of retries" + ) + states: dict[States, int] + + @router.get("/stats", operation_id="stats") -async def get_stats(_: Request): +async def get_stats(_: Request) -> StatsResponse: payload = {} with db.Session() as session: - - movies_symlinks = session.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one() - episodes_symlinks = session.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one() + movies_symlinks = session.execute( + select(func.count(Movie._id)).where(Movie.symlinked == True) + ).scalar_one() + episodes_symlinks = session.execute( + select(func.count(Episode._id)).where(Episode.symlinked == True) + ).scalar_one() total_symlinks = movies_symlinks + episodes_symlinks total_movies = session.execute(select(func.count(Movie._id))).scalar_one() @@ -113,21 +152,30 @@ async def get_stats(_: Request): total_episodes = session.execute(select(func.count(Episode._id))).scalar_one() total_items = session.execute(select(func.count(MediaItem._id))).scalar_one() - # Select only the IDs of incomplete items - _incomplete_items = session.execute( - select(MediaItem._id) - .where(MediaItem.last_state != States.Completed) - ).scalars().all() - + # Select only the IDs of incomplete items + _incomplete_items = ( + session.execute( + 
select(MediaItem._id).where(MediaItem.last_state != States.Completed) + ) + .scalars() + .all() + ) + incomplete_retries = {} if _incomplete_items: - media_items = session.query(MediaItem).filter(MediaItem._id.in_(_incomplete_items)).all() + media_items = ( + session.query(MediaItem) + .filter(MediaItem._id.in_(_incomplete_items)) + .all() + ) for media_item in media_items: incomplete_retries[media_item.log_string] = media_item.scraped_times states = {} for state in States: - states[state] = session.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one() + states[state] = session.execute( + select(func.count(MediaItem._id)).where(MediaItem.last_state == state) + ).scalar_one() payload["total_items"] = total_items payload["total_movies"] = total_movies @@ -138,11 +186,15 @@ async def get_stats(_: Request): payload["incomplete_items"] = len(_incomplete_items) payload["incomplete_retries"] = incomplete_retries payload["states"] = states + return payload + + +class LogsResponse(BaseModel): + logs: str - return {"success": True, "data": payload} @router.get("/logs", operation_id="logs") -async def get_logs(): +async def get_logs() -> str: log_file_path = None for handler in logger._core.handlers.values(): if ".log" in handler._name: @@ -153,24 +205,29 @@ async def get_logs(): return {"success": False, "message": "Log file handler not found"} try: - with open(log_file_path, 'r') as log_file: + with open(log_file_path, "r") as log_file: log_contents = log_file.read() - return {"success": True, "logs": log_contents} + return {"logs": log_contents} except Exception as e: logger.error(f"Failed to read log file: {e}") - return {"success": False, "message": "Failed to read log file"} - + raise HTTPException(status_code=500, detail="Failed to read log file") + + @router.get("/events", operation_id="events") -async def get_events(request: Request): - return {"success": True, "data": request.app.program.em.get_event_updates()} +async def 
get_events( + request: Request, +) -> dict[str, list[EventUpdate]]: + return request.app.program.em.get_event_updates() @router.get("/mount", operation_id="mount") -async def get_rclone_files(): +async def get_rclone_files() -> dict[str, str]: """Get all files in the rclone mount.""" import os + rclone_dir = settings_manager.settings.symlink.rclone_path file_map = {} + def scan_dir(path): with os.scandir(path) as entries: for entry in entries: @@ -179,6 +236,5 @@ def scan_dir(path): elif entry.is_dir(): scan_dir(entry.path) - scan_dir(rclone_dir) # dict of `filename: filepath`` - return {"success": True, "data": file_map} - + scan_dir(rclone_dir) # dict of `filename: filepath`` + return file_map diff --git a/src/controllers/items.py b/src/controllers/items.py index e213f562..f4442fd9 100644 --- a/src/controllers/items.py +++ b/src/controllers/items.py @@ -1,13 +1,10 @@ import asyncio from datetime import datetime -from typing import Optional +from typing import Literal, Optional import Levenshtein -from RTN import RTN, Torrent +from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException, Request -from sqlalchemy import func, select -from sqlalchemy.exc import NoResultFound - from program.content import Overseerr from program.db.db import db from program.db.db_functions import ( @@ -17,16 +14,21 @@ get_parent_items_by_ids, reset_media_item, ) +from program.downloaders import Downloader, get_needed_media +from program.downloaders.realdebrid import ( + add_torrent_magnet, + torrent_info, +) from program.media.item import MediaItem from program.media.state import States -from program.symlink import Symlinker -from program.downloaders import Downloader, get_needed_media -from program.downloaders.realdebrid import RealDebridDownloader, add_torrent_magnet, torrent_info -from program.settings.versions import models -from program.settings.manager import settings_manager from program.media.stream import Stream from 
program.scrapers.shared import rtn +from program.symlink import Symlinker from program.types import Event +from pydantic import BaseModel +from RTN import Torrent +from sqlalchemy import and_, func, or_, select +from sqlalchemy.exc import NoResultFound from utils.logger import logger router = APIRouter( @@ -35,34 +37,55 @@ responses={404: {"description": "Not found"}}, ) + def handle_ids(ids: str) -> list[int]: ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)] if not ids: raise HTTPException(status_code=400, detail="No item ID provided") return ids -@router.get("/states") -async def get_states(): + +class StateResponse(BaseModel): + success: bool + states: list[str] + + +@router.get("/states", operation_id="get_states") +async def get_states() -> StateResponse: return { "success": True, "states": [state for state in States], } + +class ItemsResponse(BaseModel): + success: bool + items: list[dict] + page: int + limit: int + total_items: int + total_pages: int + + @router.get( "", summary="Retrieve Media Items", description="Fetch media items with optional filters and pagination", + operation_id="get_items", ) async def get_items( _: Request, limit: Optional[int] = 50, page: Optional[int] = 1, type: Optional[str] = None, - state: Optional[str] = None, - sort: Optional[str] = "date_desc", + states: Optional[str] = None, + sort: Optional[ + Literal["date_desc", "date_asc", "title_asc", "title_desc"] + ] = "date_desc", search: Optional[str] = None, extended: Optional[bool] = False, -): + is_anime: Optional[bool] = False, +) -> ItemsResponse: if page < 1: raise HTTPException(status_code=400, detail="Page number must be 1 or greater.") @@ -77,37 +100,55 @@ async def get_items( query = query.where(MediaItem.imdb_id == search_lower) else: query = query.where( - (func.lower(MediaItem.title).like(f"%{search_lower}%")) | - (func.lower(MediaItem.imdb_id).like(f"%{search_lower}%")) + (func.lower(MediaItem.title).like(f"%{search_lower}%")) + | 
(func.lower(MediaItem.imdb_id).like(f"%{search_lower}%")) ) - if state: - filter_lower = state.lower() - filter_state = None - for state_enum in States: - if Levenshtein.ratio(filter_lower, state_enum.name.lower()) <= 0.82: - filter_state = state_enum - break - if filter_state: - query = query.where(MediaItem.last_state == filter_state) + if states: + states = states.split(",") + filter_states = [] + for state in states: + filter_lower = state.lower() + for state_enum in States: + if Levenshtein.ratio(filter_lower, state_enum.name.lower()) >= 0.82: + filter_states.append(state_enum) + break + if len(filter_states) == len(states): + query = query.where(MediaItem.last_state.in_(filter_states)) else: valid_states = [state_enum.name for state_enum in States] raise HTTPException( status_code=400, - detail=f"Invalid filter state: {state}. Valid states are: {valid_states}", + detail=f"Invalid filter states: {states}. Valid states are: {valid_states}", ) if type: if "," in type: types = type.split(",") for type in types: - if type not in ["movie", "show", "season", "episode"]: + if type not in ["movie", "show", "season", "episode", "anime"]: raise HTTPException( status_code=400, - detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode']") + detail=f"Invalid type: {type}. 
Valid types are: ['movie', 'show', 'season', 'episode', 'anime']", + ) else: - types=[type] - query = query.where(MediaItem.type.in_(types)) + types = [type] + if "anime" in types: + types = [type for type in types if type != "anime"] + query = query.where( + or_( + and_( + MediaItem.type.in_(["movie", "show"]), + MediaItem.is_anime == True, + ), + MediaItem.type.in_(types), + ) + ) + else: + query = query.where(MediaItem.type.in_(types)) + + if is_anime: + query = query.where(MediaItem.is_anime == True) if sort and not search: sort_lower = sort.lower() @@ -126,14 +167,24 @@ async def get_items( ) with db.Session() as session: - total_items = session.execute(select(func.count()).select_from(query.subquery())).scalar_one() - items = session.execute(query.offset((page - 1) * limit).limit(limit)).unique().scalars().all() + total_items = session.execute( + select(func.count()).select_from(query.subquery()) + ).scalar_one() + items = ( + session.execute(query.offset((page - 1) * limit).limit(limit)) + .unique() + .scalars() + .all() + ) total_pages = (total_items + limit - 1) // limit return { "success": True, - "items": [item.to_extended_dict() if extended else item.to_dict() for item in items], + "items": [ + item.to_extended_dict() if extended else item.to_dict() + for item in items + ], "page": page, "limit": limit, "total_items": total_items, @@ -142,14 +193,12 @@ async def get_items( @router.post( - "/add", - summary="Add Media Items", - description="Add media items with bases on imdb IDs", + "/add", + summary="Add Media Items", + description="Add media items with bases on imdb IDs", + operation_id="add_items", ) -async def add_items( - request: Request, imdb_ids: str = None -): - +async def add_items(request: Request, imdb_ids: str = None) -> MessageResponse: if not imdb_ids: raise HTTPException(status_code=400, detail="No IMDb ID(s) provided") @@ -167,47 +216,67 @@ async def add_items( with db.Session() as _: for id in valid_ids: - item = MediaItem({"imdb_id": 
id, "requested_by": "riven", "requested_at": datetime.now()}) + item = MediaItem( + {"imdb_id": id, "requested_by": "riven", "requested_at": datetime.now()} + ) request.app.program.em.add_item(item) - return {"success": True, "message": f"Added {len(valid_ids)} item(s) to the queue"} + return {"message": f"Added {len(valid_ids)} item(s) to the queue"} + @router.get( "/{id}", summary="Retrieve Media Item", description="Fetch a single media item by ID", + operation_id="get_item", ) -async def get_item(request: Request, id: int): +async def get_item(_: Request, id: int, use_tmdb_id: Optional[bool] = False) -> dict: with db.Session() as session: try: - item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one() + query = select(MediaItem) + if use_tmdb_id: + query = query.where(MediaItem.tmdb_id == str(id)) + else: + query = query.where(MediaItem._id == id) + item = session.execute(query).unique().scalar_one() except NoResultFound: raise HTTPException(status_code=404, detail="Item not found") - return {"success": True, "item": item.to_extended_dict()} + return item.to_extended_dict(with_streams=False) + @router.get( "/{imdb_ids}", summary="Retrieve Media Items By IMDb IDs", description="Fetch media items by IMDb IDs", + operation_id="get_items_by_imdb_ids", ) -async def get_items_by_imdb_ids(request: Request, imdb_ids: str): +async def get_items_by_imdb_ids(request: Request, imdb_ids: str) -> list[dict]: ids = imdb_ids.split(",") with db.Session() as session: items = [] for id in ids: - item = session.execute(select(MediaItem).where(MediaItem.imdb_id == id)).unique().scalar_one() + item = ( + session.execute(select(MediaItem).where(MediaItem.imdb_id == id)) + .unique() + .scalar_one() + ) if item: items.append(item) - return {"success": True, "items": [item.to_extended_dict() for item in items]} + return [item.to_extended_dict() for item in items] + + +class ResetResponse(BaseModel): + message: str + ids: list[str] + @router.post( - 
"/reset", - summary="Reset Media Items", - description="Reset media items with bases on item IDs", + "/reset", + summary="Reset Media Items", + description="Reset media items with bases on item IDs", + operation_id="reset_items", ) -async def reset_items( - request: Request, ids: str -): +async def reset_items(request: Request, ids: str) -> ResetResponse: ids = handle_ids(ids) try: media_items = get_media_items_by_ids(ids) @@ -222,15 +291,22 @@ async def reset_items( logger.error(f"Failed to reset item with id {media_item._id}: {str(e)}") continue except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Reset items with id {ids}"} + raise HTTPException(status_code=400, detail=str(e)) from e + return {"message": f"Reset items with id {ids}", "ids": ids} + + +class RetryResponse(BaseModel): + message: str + ids: list[str] + @router.post( - "/retry", - summary="Retry Media Items", - description="Retry media items with bases on item IDs", + "/retry", + summary="Retry Media Items", + description="Retry media items with bases on item IDs", + operation_id="retry_items", ) -async def retry_items(request: Request, ids: str): +async def retry_items(request: Request, ids: str) -> RetryResponse: ids = handle_ids(ids) try: media_items = get_media_items_by_ids(ids) @@ -238,19 +314,26 @@ async def retry_items(request: Request, ids: str): raise ValueError("Invalid item ID(s) provided. 
Some items may not exist.") for media_item in media_items: request.app.program.em.cancel_job(media_item) - await asyncio.sleep(0.1) # Ensure cancellation is processed + await asyncio.sleep(0.1) # Ensure cancellation is processed request.app.program.em.add_item(media_item) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Retried items with ids {ids}"} + return {"message": f"Retried items with ids {ids}", "ids": ids} + + +class RemoveResponse(BaseModel): + message: str + ids: list[int] + @router.delete( "/remove", summary="Remove Media Items", description="Remove media items based on item IDs", + operation_id="remove_item", ) -async def remove_item(request: Request, ids: str): +async def remove_item(request: Request, ids: str) -> RemoveResponse: ids = handle_ids(ids) try: media_items = get_parent_items_by_ids(ids) @@ -259,30 +342,45 @@ async def remove_item(request: Request, ids: str): for media_item in media_items: logger.debug(f"Removing item {media_item.title} with ID {media_item._id}") request.app.program.em.cancel_job(media_item) - await asyncio.sleep(0.1) # Ensure cancellation is processed + await asyncio.sleep(0.1) # Ensure cancellation is processed clear_streams(media_item) symlink_service = request.app.program.services.get(Symlinker) if symlink_service: symlink_service.delete_item_symlinks(media_item) if media_item.requested_by == "overseerr" and media_item.requested_id: - logger.debug(f"Item was originally requested by Overseerr, deleting request within Overseerr...") + logger.debug( + f"Item was originally requested by Overseerr, deleting request within Overseerr..." 
+ ) Overseerr.delete_request(media_item.requested_id) delete_media_item(media_item) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - return {"success": True, "message": f"Removed items with ids {ids}"} + return {"message": f"Removed items with ids {ids}", "ids": ids} + + +class SetTorrentRDResponse(BaseModel): + message: str + item_id: int + torrent_id: str + -@router.post("/{id}/set_torrent_rd_magnet", description="Set a torrent for a media item using a magnet link.") -def add_torrent(request: Request, id: int, magnet: str): +@router.post( + "/{id}/set_torrent_rd_magnet", + name="Set torrent RD magnet", + description="Set a torrent for a media item using a magnet link.", + operation_id="set_torrent_rd_magnet", +) +def add_torrent(request: Request, id: int, magnet: str) -> SetTorrentRDResponse: torrent_id = "" try: torrent_id = add_torrent_magnet(magnet) except Exception: raise HTTPException(status_code=500, detail="Failed to add torrent.") from None - + return set_torrent_rd(request, id, torrent_id) + def reset_item_to_scraped(item: MediaItem): item.last_state = States.Scraped item.symlinked = False @@ -293,30 +391,48 @@ def reset_item_to_scraped(item: MediaItem): item.file = None item.folder = None + def create_stream(hash, torrent_info): try: torrent: Torrent = rtn.rank( - raw_title=torrent_info["filename"], - infohash=hash, - remove_trash=False + raw_title=torrent_info["filename"], infohash=hash, remove_trash=False ) return Stream(torrent) except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to rank torrent: {e}") from e + raise HTTPException( + status_code=500, detail=f"Failed to rank torrent: {e}" + ) from e -@router.post("/{id}/set_torrent_rd", description="Set a torrent for a media item using RD torrent ID.") -def set_torrent_rd(request: Request, id: int, torrent_id: str): +@router.post( + "/{id}/set_torrent_rd", + description="Set a torrent for a media item using RD torrent ID.", +) +def 
set_torrent_rd(request: Request, id: int, torrent_id: str) -> SetTorrentRDResponse: downloader: Downloader = request.app.program.services.get(Downloader) with db.Session() as session: - item: MediaItem = session.execute(select(MediaItem).where(MediaItem._id == id).outerjoin(MediaItem.streams)).unique().scalar_one_or_none() + item: MediaItem = ( + session.execute( + select(MediaItem) + .where(MediaItem._id == id) + .outerjoin(MediaItem.streams) + ) + .unique() + .scalar_one_or_none() + ) if item is None: raise HTTPException(status_code=404, detail="Item not found") fetched_torrent_info = torrent_info(torrent_id) - stream = session.execute(select(Stream).where(Stream.infohash == fetched_torrent_info["hash"])).scalars().first() + stream = ( + session.execute( + select(Stream).where(Stream.infohash == fetched_torrent_info["hash"]) + ) + .scalars() + .first() + ) hash = fetched_torrent_info["hash"] # Create stream if it doesn't exist @@ -325,10 +441,12 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): item.streams.append(stream) # check if the stream exists in the item - stream_exists_in_item = next((stream for stream in item.streams if stream.infohash == hash), None) + stream_exists_in_item = next( + (stream for stream in item.streams if stream.infohash == hash), None + ) if stream_exists_in_item is None: item.streams.append(stream) - + reset_item_to_scraped(item) # reset episodes if it's a season @@ -342,7 +460,10 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): if len(cached_streams) == 0: session.rollback() - raise HTTPException(status_code=400, detail=f"No cached torrents found for {item.log_string}") + raise HTTPException( + status_code=400, + detail=f"No cached torrents found for {item.log_string}", + ) item.active_stream = cached_streams[0] try: @@ -352,13 +473,21 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): if item.active_stream.get("infohash", None): downloader._delete_and_reset_active_stream(item) 
session.rollback() - raise HTTPException(status_code=500, detail=f"Failed to download {item.log_string}: {e}") from e + raise HTTPException( + status_code=500, detail=f"Failed to download {item.log_string}: {e}" + ) from e session.commit() request.app.program.em.add_event(Event("Symlinker", item)) - return {"success": True, "message": f"Set torrent for {item.title} to {torrent_id}"} + return { + "success": True, + "message": f"Set torrent for {item.title} to {torrent_id}", + "item_id": item._id, + "torrent_id": torrent_id, + } + # These require downloaders to be refactored @@ -397,4 +526,4 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str): # item.reset(True) # downloader.download_cached(item, hash) # request.app.program.add_to_queue(item) -# return {"success": True, "message": f"Downloading {item.title} with hash {hash}"} \ No newline at end of file +# return {"success": True, "message": f"Downloading {item.title} with hash {hash}"} diff --git a/src/controllers/models/shared.py b/src/controllers/models/shared.py new file mode 100644 index 00000000..53b5fefc --- /dev/null +++ b/src/controllers/models/shared.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class MessageResponse(BaseModel): + message: str \ No newline at end of file diff --git a/src/controllers/scrape.py b/src/controllers/scrape.py index deb833da..e975705b 100644 --- a/src/controllers/scrape.py +++ b/src/controllers/scrape.py @@ -1,29 +1,38 @@ """Scrape controller.""" + from fastapi import APIRouter, HTTPException, Request -from sqlalchemy import select -from program.scrapers import Scraping +from program.db.db import db +from program.downloaders.realdebrid import RDTorrent, get_torrents from program.indexers.trakt import TraktIndexer from program.media.item import MediaItem -from program.db.db import db +from program.scrapers import Scraping +from pydantic import BaseModel +from sqlalchemy import select + +router = APIRouter(prefix="/scrape", tags=["scrape"]) -router = 
APIRouter( - prefix="/scrape", - tags=["scrape"] -) + +class ScrapedTorrent(BaseModel): + rank: int + raw_title: str + infohash: str @router.get( "", summary="Scrape Media Item", - description="Scrape media item based on IMDb ID." + description="Scrape media item based on IMDb ID.", + operation_id="scrape", ) -async def scrape(request: Request, imdb_id: str, season: int = None, episode: int = None): +async def scrape( + request: Request, imdb_id: str, season: int = None, episode: int = None +) -> list[ScrapedTorrent]: """ Scrape media item based on IMDb ID. - **imdb_id**: IMDb ID of the media item. """ - if (services := request.app.program.services): + if services := request.app.program.services: scraping = services[Scraping] indexer = services[TraktIndexer] else: @@ -31,12 +40,16 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in try: with db.Session() as session: - media_item = session.execute( - select(MediaItem).where( - MediaItem.imdb_id == imdb_id, - MediaItem.type.in_(["movie", "show"]) + media_item = ( + session.execute( + select(MediaItem).where( + MediaItem.imdb_id == imdb_id, + MediaItem.type.in_(["movie", "show"]), + ) ) - ).unique().scalar_one_or_none() + .unique() + .scalar_one_or_none() + ) if not media_item: indexed_item = MediaItem({"imdb_id": imdb_id}) media_item = next(indexer.run(indexed_item)) @@ -48,7 +61,14 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in if media_item.type == "show": if season and episode: - media_item = next((ep for ep in media_item.seasons[season - 1].episodes if ep.number == episode), None) + media_item = next( + ( + ep + for ep in media_item.seasons[season - 1].episodes + if ep.number == episode + ), + None, + ) if not media_item: raise HTTPException(status_code=204, detail="Episode not found") elif season: @@ -56,7 +76,10 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in if not media_item: raise 
HTTPException(status_code=204, detail="Season not found") elif media_item.type == "movie" and (season or episode): - raise HTTPException(status_code=204, detail="Item type returned movie, cannot scrape season or episode") + raise HTTPException( + status_code=204, + detail="Item type returned movie, cannot scrape season or episode", + ) results = scraping.scrape(media_item, log=False) if not results: @@ -66,8 +89,9 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in { "raw_title": stream.raw_title, "infohash": stream.infohash, - "rank": stream.rank - } for stream in results.values() + "rank": stream.rank, + } + for stream in results.values() ] except StopIteration as e: @@ -75,4 +99,18 @@ async def scrape(request: Request, imdb_id: str, season: int = None, episode: in except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - return {"success": True, "data": data} + return data + +@router.get( + "/rd", + summary="Get Real-Debrid Torrents", + description="Get torrents from Real-Debrid.", + operation_id="get_rd_torrents", +) +async def get_rd_torrents(limit: int = 1000) -> list[RDTorrent]: + """ + Get torrents from Real-Debrid. + + - **limit**: Limit the number of torrents to get. 
+ """ + return get_torrents(limit) diff --git a/src/controllers/settings.py b/src/controllers/settings.py index 671c9426..d02375c4 100644 --- a/src/controllers/settings.py +++ b/src/controllers/settings.py @@ -1,10 +1,11 @@ from copy import copy from typing import Any, Dict, List +from controllers.models.shared import MessageResponse from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, ValidationError - from program.settings.manager import settings_manager +from program.settings.models import AppModel +from pydantic import BaseModel, ValidationError class SetSettings(BaseModel): @@ -19,42 +20,36 @@ class SetSettings(BaseModel): ) -@router.get("/schema") -async def get_settings_schema(): +@router.get("/schema", operation_id="get_settings_schema") +async def get_settings_schema() -> dict[str, Any]: """ Get the JSON schema for the settings. """ return settings_manager.settings.model_json_schema() -@router.get("/load") -async def load_settings(): +@router.get("/load", operation_id="load_settings") +async def load_settings() -> MessageResponse: settings_manager.load() return { - "success": True, "message": "Settings loaded!", } - -@router.post("/save") -async def save_settings(): +@router.post("/save", operation_id="save_settings") +async def save_settings() -> MessageResponse: settings_manager.save() return { - "success": True, "message": "Settings saved!", } -@router.get("/get/all") -async def get_all_settings(): - return { - "success": True, - "data": copy(settings_manager.settings), - } +@router.get("/get/all", operation_id="get_all_settings") +async def get_all_settings() -> AppModel: + return copy(settings_manager.settings) -@router.get("/get/{paths}") -async def get_settings(paths: str): - current_settings = settings_manager.settings.dict() +@router.get("/get/{paths}", operation_id="get_settings") +async def get_settings(paths: str) -> dict[str, Any]: + current_settings = settings_manager.settings.model_dump() data = {} for path in 
paths.split(","): keys = path.split(".") @@ -66,15 +61,11 @@ async def get_settings(paths: str): current_obj = current_obj[k] data[path] = current_obj + return data - return { - "success": True, - "data": data, - } - -@router.post("/set/all") -async def set_all_settings(new_settings: Dict[str, Any]): +@router.post("/set/all", operation_id="set_all_settings") +async def set_all_settings(new_settings: Dict[str, Any]) -> MessageResponse: current_settings = settings_manager.settings.model_dump() def update_settings(current_obj, new_obj): @@ -95,12 +86,11 @@ def update_settings(current_obj, new_obj): raise HTTPException(status_code=400, detail=str(e)) return { - "success": True, "message": "All settings updated successfully!", } -@router.post("/set") -async def set_settings(settings: List[SetSettings]): +@router.post("/set", operation_id="set_settings") +async def set_settings(settings: List[SetSettings]) -> MessageResponse: current_settings = settings_manager.settings.model_dump() for setting in settings: @@ -136,5 +126,4 @@ async def set_settings(settings: List[SetSettings]): detail=f"Failed to update settings: {str(e)}", ) - return {"success": True, "message": "Settings updated successfully."} - + return {"message": "Settings updated successfully."} diff --git a/src/controllers/tmdb.py b/src/controllers/tmdb.py index e2db8ce7..a9b930f2 100644 --- a/src/controllers/tmdb.py +++ b/src/controllers/tmdb.py @@ -1,10 +1,21 @@ from enum import Enum -from typing import Annotated +from typing import Annotated, Generic, Optional, TypeVar from urllib.parse import urlencode from fastapi import APIRouter, Depends - -from program.indexers.tmdb import tmdb +from program.indexers.tmdb import ( + TmdbCollectionDetails, + TmdbEpisodeDetails, + TmdbFindResults, + TmdbItem, + TmdbMovieDetails, + TmdbPagedResults, + TmdbPagedResultsWithDates, + TmdbSeasonDetails, + TmdbTVDetails, + tmdb, +) +from pydantic import BaseModel router = APIRouter( prefix="/tmdb", @@ -140,12 +151,21 @@ def 
__init__( self.year = year -@router.get("/trending/{type}/{window}") +T = TypeVar("T") + + +class TmdbResponse(BaseModel, Generic[T]): + success: bool + data: Optional[T] = None + message: Optional[str] = None + + +@router.get("/trending/{type}/{window}", operation_id="get_trending") async def get_trending( params: Annotated[TrendingParams, Depends()], type: TrendingType, window: TrendingWindow, -): +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: trending = tmdb.getTrending( params=dict_to_query_string(params.__dict__), type=type.value, @@ -163,8 +183,10 @@ async def get_trending( } -@router.get("/movie/now_playing") -async def get_movies_now_playing(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/now_playing", operation_id="get_movies_now_playing") +async def get_movies_now_playing( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResultsWithDates[TmdbItem]]: # noqa: F821 movies = tmdb.getMoviesNowPlaying(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -178,8 +200,10 @@ async def get_movies_now_playing(params: Annotated[CommonListParams, Depends()]) } -@router.get("/movie/popular") -async def get_movies_popular(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/popular", operation_id="get_movies_popular") +async def get_movies_popular( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: movies = tmdb.getMoviesPopular(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -193,8 +217,10 @@ async def get_movies_popular(params: Annotated[CommonListParams, Depends()]): } -@router.get("/movie/top_rated") -async def get_movies_top_rated(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/top_rated", operation_id="get_movies_top_rated") +async def get_movies_top_rated( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: movies = 
tmdb.getMoviesTopRated(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -208,8 +234,10 @@ async def get_movies_top_rated(params: Annotated[CommonListParams, Depends()]): } -@router.get("/movie/upcoming") -async def get_movies_upcoming(params: Annotated[CommonListParams, Depends()]): +@router.get("/movie/upcoming", operation_id="get_movies_upcoming") +async def get_movies_upcoming( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResultsWithDates[TmdbItem]]: movies = tmdb.getMoviesUpcoming(params=dict_to_query_string(params.__dict__)) if movies: return { @@ -226,11 +254,11 @@ async def get_movies_upcoming(params: Annotated[CommonListParams, Depends()]): # FastAPI has router preference, so /movie/now_playing, /movie/popular, /movie/top_rated and /movie/upcoming will be matched first before /movie/{movie_id}, same for /tv/{tv_id} -@router.get("/movie/{movie_id}") +@router.get("/movie/{movie_id}", operation_id="get_movie_details") async def get_movie_details( movie_id: str, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbMovieDetails]: data = tmdb.getMovieDetails( params=dict_to_query_string(params.__dict__), movie_id=movie_id, @@ -247,8 +275,10 @@ async def get_movie_details( } -@router.get("/tv/airing_today") -async def get_tv_airing_today(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/airing_today", operation_id="get_tv_airing_today") +async def get_tv_airing_today( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVAiringToday(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -262,8 +292,10 @@ async def get_tv_airing_today(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/on_the_air") -async def get_tv_on_the_air(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/on_the_air", operation_id="get_tv_on_the_air") +async def get_tv_on_the_air( + params: 
Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVOnTheAir(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -277,8 +309,10 @@ async def get_tv_on_the_air(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/popular") -async def get_tv_popular(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/popular", operation_id="get_tv_popular") +async def get_tv_popular( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVPopular(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -292,8 +326,10 @@ async def get_tv_popular(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/top_rated") -async def get_tv_top_rated(params: Annotated[CommonListParams, Depends()]): +@router.get("/tv/top_rated", operation_id="get_tv_top_rated") +async def get_tv_top_rated( + params: Annotated[CommonListParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: tv = tmdb.getTVTopRated(params=dict_to_query_string(params.__dict__)) if tv: return { @@ -307,11 +343,11 @@ async def get_tv_top_rated(params: Annotated[CommonListParams, Depends()]): } -@router.get("/tv/{series_id}") +@router.get("/tv/{series_id}", operation_id="get_tv_details") async def get_tv_details( series_id: str, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbTVDetails]: data = tmdb.getTVDetails( params=dict_to_query_string(params.__dict__), series_id=series_id, @@ -328,12 +364,14 @@ async def get_tv_details( } -@router.get("/tv/{series_id}/season/{season_number}") +@router.get( + "/tv/{series_id}/season/{season_number}", operation_id="get_tv_season_details" +) async def get_tv_season_details( series_id: int, season_number: int, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbSeasonDetails]: data = tmdb.getTVSeasonDetails( params=dict_to_query_string(params.__dict__), series_id=series_id, 
@@ -351,13 +389,16 @@ async def get_tv_season_details( } -@router.get("/tv/{series_id}/season/{season_number}/episode/{episode_number}") +@router.get( + "/tv/{series_id}/season/{season_number}/episode/{episode_number}", + operation_id="get_tv_episode_details", +) async def get_tv_episode_details( series_id: int, season_number: int, episode_number: int, params: Annotated[DetailsParams, Depends()], -): +) -> TmdbResponse[TmdbEpisodeDetails]: data = tmdb.getTVSeasonEpisodeDetails( params=dict_to_query_string(params.__dict__), series_id=series_id, @@ -376,8 +417,10 @@ async def get_tv_episode_details( } -@router.get("/search/collection") -async def search_collection(params: Annotated[CollectionSearchParams, Depends()]): +@router.get("/search/collection", operation_id="search_collection") +async def search_collection( + params: Annotated[CollectionSearchParams, Depends()], +) -> TmdbResponse[TmdbPagedResults[TmdbCollectionDetails]]: data = tmdb.getCollectionSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -391,8 +434,8 @@ async def search_collection(params: Annotated[CollectionSearchParams, Depends()] } -@router.get("/search/movie") -async def search_movie(params: Annotated[MovieSearchParams, Depends()]): +@router.get("/search/movie", operation_id="search_movie") +async def search_movie(params: Annotated[MovieSearchParams, Depends()]) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: data = tmdb.getMovieSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -406,8 +449,8 @@ async def search_movie(params: Annotated[MovieSearchParams, Depends()]): } -@router.get("/search/multi") -async def search_multi(params: Annotated[MultiSearchParams, Depends()]): +@router.get("/search/multi", operation_id="search_multi") +async def search_multi(params: Annotated[MultiSearchParams, Depends()]) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: data = tmdb.getMultiSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -421,8 +464,8 @@ 
async def search_multi(params: Annotated[MultiSearchParams, Depends()]): } -@router.get("/search/tv") -async def search_tv(params: Annotated[TVSearchParams, Depends()]): +@router.get("/search/tv", operation_id="search_tv") +async def search_tv(params: Annotated[TVSearchParams, Depends()]) -> TmdbResponse[TmdbPagedResults[TmdbItem]]: data = tmdb.getTVSearch(params=dict_to_query_string(params.__dict__)) if data: return { @@ -436,11 +479,11 @@ async def search_tv(params: Annotated[TVSearchParams, Depends()]): } -@router.get("/external_id/{external_id}") +@router.get("/external_id/{external_id}", operation_id="get_from_external_id") async def get_from_external_id( external_id: str, params: Annotated[ExternalIDParams, Depends()], -): +) -> TmdbResponse[TmdbFindResults]: data = tmdb.getFromExternalID( params=dict_to_query_string(params.__dict__), external_id=external_id, diff --git a/src/controllers/webhooks.py b/src/controllers/webhooks.py index 891bc688..d25a1c3b 100644 --- a/src/controllers/webhooks.py +++ b/src/controllers/webhooks.py @@ -2,12 +2,11 @@ import pydantic from fastapi import APIRouter, Request -from requests import RequestException - from program.content.overseerr import Overseerr from program.db.db_functions import _ensure_item_exists_in_db from program.indexers.trakt import get_imdbid_from_tmdb, get_imdbid_from_tvdb from program.media.item import MediaItem +from requests import RequestException from utils.logger import logger from .models.overseerr import OverseerrWebhook diff --git a/src/controllers/ws.py b/src/controllers/ws.py index 6693b7b7..5d622aeb 100644 --- a/src/controllers/ws.py +++ b/src/controllers/ws.py @@ -1,5 +1,4 @@ from fastapi import WebSocket - from utils.websockets import manager from .default import router diff --git a/src/program/downloaders/realdebrid.py b/src/program/downloaders/realdebrid.py index c7780fdd..58df3ad6 100644 --- a/src/program/downloaders/realdebrid.py +++ b/src/program/downloaders/realdebrid.py @@ -1,7 +1,10 @@ 
from datetime import datetime +from enum import Enum +from typing import Optional, Union from loguru import logger from program.settings.manager import settings_manager as settings +from pydantic import BaseModel, TypeAdapter from requests import ConnectTimeout from utils import request from utils.ratelimiter import RateLimiter @@ -13,13 +16,39 @@ torrent_limiter = RateLimiter(1, 1) overall_limiter = RateLimiter(60, 60) +class RDTorrentStatus(str, Enum): + magnet_error = "magnet_error" + magnet_conversion = "magnet_conversion" + waiting_files_selection = "waiting_files_selection" + downloading = "downloading" + downloaded = "downloaded" + error = "error" + seeding = "seeding" + dead = "dead" + uploading = "uploading" + compressing = "compressing" + +class RDTorrent(BaseModel): + id: str + hash: str + filename: str + bytes: int + status: RDTorrentStatus + added: datetime + links: list[str] + ended: Optional[datetime] = None + speed: Optional[int] = None + seeders: Optional[int] = None + +rd_torrent_list = TypeAdapter(list[RDTorrent]) + class RealDebridDownloader: def __init__(self): self.key = "realdebrid" self.settings = settings.settings.downloaders.real_debrid self.initialized = self.validate() if self.initialized: - self.existing_hashes = [torrent["hash"] for torrent in get_torrents(1000)] + self.existing_hashes = [torrent.hash for torrent in get_torrents(1000)] self.file_finder = FileFinder("filename", "filesize") def validate(self) -> bool: @@ -103,12 +132,12 @@ def all_files_valid(file_dict: dict) -> bool: break return cached_containers - def get_torrent_names(self, id: str) -> dict: + def get_torrent_names(self, id: str) -> tuple[str,str]: info = torrent_info(id) return (info["filename"], info["original_filename"]) def delete_torrent_with_infohash(self, infohash: str): - id = next(torrent["id"] for torrent in get_torrents(1000) if torrent["hash"] == infohash) + id = next(torrent.id for torrent in get_torrents(1000) if torrent.hash == infohash) if id: 
delete_torrent(id) @@ -174,9 +203,9 @@ def torrent_info(id: str) -> dict: info = {} return info -def get_torrents(limit: int) -> list[dict]: +def get_torrents(limit: int) -> list[RDTorrent]: try: - torrents = get(f"torrents?limit={str(limit)}") + torrents = rd_torrent_list.validate_python(get(f"torrents?limit={str(limit)}")) except: logger.warning("Failed to get torrents.") torrents = [] diff --git a/src/program/indexers/tmdb.py b/src/program/indexers/tmdb.py index 3868f606..2ff9d816 100644 --- a/src/program/indexers/tmdb.py +++ b/src/program/indexers/tmdb.py @@ -1,3 +1,8 @@ +from datetime import date +from enum import Enum +from typing import Generic, Literal, Optional, TypeVar + +from pydantic import BaseModel from utils.logger import logger from utils.request import get @@ -5,6 +10,191 @@ # TODO: Maybe remove the else condition ? It's not necessary since exception is raised 400-450, 500-511, 408, 460, 504, 520, 524, 522, 598 and 599 +ItemT = TypeVar("ItemT") + +class TmdbMediaType(str, Enum): + movie = "movie" + tv = "tv" + episode = "tv_episode" + season = "tv_season" + + +class TmdbItem(BaseModel): + adult: bool + backdrop_path: Optional[str] + id: int + title: str + original_title: str + original_language: str + overview: str + poster_path: Optional[str] + media_type: Optional[TmdbMediaType] = None + genre_ids: list[int] + popularity: float + release_date: str + video: bool + vote_average: float + vote_count: int + +class TmdbEpisodeItem(BaseModel): + id: int + name: str + overview: str + media_type: Literal["tv_episode"] + vote_average: float + vote_count: int + air_date: date + episode_number: int + episode_type: str + production_code: str + runtime: int + season_number: int + show_id: int + still_path: str + +class TmdbSeasonItem(BaseModel): + id: int + name: str + overview: str + poster_path: str + media_type: Literal["tv_season"] + vote_average: float + air_date: date + season_number: int + show_id: int + episode_count: int + + +class 
TmdbPagedResults(BaseModel, Generic[ItemT]): + page: int + results: list[ItemT] + total_pages: int + total_results: int + +class TmdbPagedResultsWithDates(TmdbPagedResults[ItemT], Generic[ItemT]): + class Dates(BaseModel): + maximum: date + minimum: date + dates: Dates + +class TmdbFindResults(BaseModel): + movie_results: list[TmdbItem] + tv_results: list[TmdbItem] + tv_episode_results: list[TmdbEpisodeItem] + tv_season_results: list[TmdbSeasonItem] + +class Genre(BaseModel): + id: int + name: str + +class BelongsToCollection(BaseModel): + id: int + name: str + poster_path: Optional[str] + backdrop_path: Optional[str] + + +class ProductionCompany(BaseModel): + id: int + logo_path: Optional[str] + name: str + origin_country: str + + +class ProductionCountry(BaseModel): + iso_3166_1: str + name: str + + +class SpokenLanguage(BaseModel): + english_name: str + iso_639_1: str + name: str + +class Network(BaseModel): + id: int + logo_path: Optional[str] + name: str + origin_country: str + +class TmdbMovieDetails(BaseModel): + adult: bool + backdrop_path: Optional[str] + belongs_to_collection: Optional[BelongsToCollection] + budget: int + genres: list[Genre] + homepage: Optional[str] + id: int + imdb_id: Optional[str] + original_language: str + original_title: str + overview: Optional[str] + popularity: float + poster_path: Optional[str] + production_companies: list[ProductionCompany] + production_countries: list[ProductionCountry] + release_date: Optional[str] + revenue: int + runtime: Optional[int] + spoken_languages: list[SpokenLanguage] + status: Optional[str] + tagline: Optional[str] + title: str + video: bool + vote_average: float + vote_count: int + +class TmdbTVDetails(BaseModel): + adult: bool + backdrop_path: Optional[str] + episode_run_time: list[int] + first_air_date: str + genres: list[Genre] + homepage: Optional[str] + id: int + in_production: bool + languages: list[str] + last_air_date: Optional[str] + last_episode_to_air: Optional[TmdbEpisodeItem] + name: 
str + next_episode_to_air: Optional[str] + networks: list[Network] + number_of_episodes: int + number_of_seasons: int + origin_country: list[str] + original_language: str + original_name: str + overview: Optional[str] + popularity: float + poster_path: Optional[str] + production_companies: list[ProductionCompany] + production_countries: list[ProductionCountry] + seasons: list[TmdbSeasonItem] + spoken_languages: list[str] + status: Optional[str] + tagline: Optional[str] + type: Optional[str] + vote_average: float + vote_count: int + +class TmdbCollectionDetails(BaseModel): + adult: bool + backdrop_path: Optional[str] + id: int + name: str + overview: str + original_language: str + original_name: str + poster_path: Optional[str] + +class TmdbEpisodeDetails(TmdbEpisodeItem): + crew: list[dict] + guest_stars: list[dict] + +class TmdbSeasonDetails(BaseModel): + _id: str + air_date: str + episodes: list[TmdbEpisodeDetails] class TMDB: def __init__(self): @@ -13,7 +203,7 @@ def __init__(self): "Authorization": f"Bearer {TMDB_READ_ACCESS_TOKEN}", } - def getMoviesNowPlaying(self, params: str): + def getMoviesNowPlaying(self, params: str) -> Optional[TmdbPagedResultsWithDates[TmdbItem]]: url = f"{self.API_URL}/movie/now_playing?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -28,7 +218,7 @@ def getMoviesNowPlaying(self, params: str): ) return None - def getMoviesPopular(self, params: str): + def getMoviesPopular(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/movie/popular?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -41,7 +231,7 @@ def getMoviesPopular(self, params: str): logger.error(f"An error occurred while getting popular movies: {str(e)}") return None - def getMoviesTopRated(self, params: str): + def getMoviesTopRated(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/movie/top_rated?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ 
-54,7 +244,7 @@ def getMoviesTopRated(self, params: str): logger.error(f"An error occurred while getting top rated movies: {str(e)}") return None - def getMoviesUpcoming(self, params: str): + def getMoviesUpcoming(self, params: str) -> Optional[TmdbPagedResultsWithDates[TmdbItem]]: url = f"{self.API_URL}/movie/upcoming?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -67,7 +257,7 @@ def getMoviesUpcoming(self, params: str): logger.error(f"An error occurred while getting upcoming movies: {str(e)}") return None - def getTrending(self, params: str, type: str, window: str): + def getTrending(self, params: str, type: str, window: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/trending/{type}/{window}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -80,7 +270,7 @@ def getTrending(self, params: str, type: str, window: str): logger.error(f"An error occurred while getting trending {type}: {str(e)}") return None - def getTVAiringToday(self, params: str): + def getTVAiringToday(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/airing_today?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -93,7 +283,7 @@ def getTVAiringToday(self, params: str): logger.error(f"An error occurred while getting TV airing today: {str(e)}") return None - def getTVOnTheAir(self, params: str): + def getTVOnTheAir(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/on_the_air?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -106,7 +296,7 @@ def getTVOnTheAir(self, params: str): logger.error(f"An error occurred while getting TV on the air: {str(e)}") return None - def getTVPopular(self, params: str): + def getTVPopular(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/popular?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -119,7 +309,7 @@ def getTVPopular(self, 
params: str): logger.error(f"An error occurred while getting popular TV shows: {str(e)}") return None - def getTVTopRated(self, params: str): + def getTVTopRated(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/tv/top_rated?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -134,7 +324,7 @@ def getTVTopRated(self, params: str): ) return None - def getFromExternalID(self, params: str, external_id: str): + def getFromExternalID(self, params: str, external_id: str) -> Optional[TmdbFindResults]: url = f"{self.API_URL}/find/{external_id}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -147,7 +337,7 @@ def getFromExternalID(self, params: str, external_id: str): logger.error(f"An error occurred while getting from external ID: {str(e)}") return None - def getMovieDetails(self, params: str, movie_id: str): + def getMovieDetails(self, params: str, movie_id: str) -> Optional[TmdbMovieDetails]: url = f"{self.API_URL}/movie/{movie_id}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -160,7 +350,7 @@ def getMovieDetails(self, params: str, movie_id: str): logger.error(f"An error occurred while getting movie details: {str(e)}") return None - def getTVDetails(self, params: str, series_id: str): + def getTVDetails(self, params: str, series_id: str) -> Optional[TmdbTVDetails]: url = f"{self.API_URL}/tv/{series_id}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -173,7 +363,7 @@ def getTVDetails(self, params: str, series_id: str): logger.error(f"An error occurred while getting TV details: {str(e)}") return None - def getCollectionSearch(self, params: str): + def getCollectionSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbCollectionDetails]]: url = f"{self.API_URL}/search/collection?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -186,7 +376,7 @@ def getCollectionSearch(self, params: str): logger.error(f"An error occurred while 
searching collections: {str(e)}") return None - def getMovieSearch(self, params: str): + def getMovieSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/search/movie?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -199,7 +389,7 @@ def getMovieSearch(self, params: str): logger.error(f"An error occurred while searching movies: {str(e)}") return None - def getMultiSearch(self, params: str): + def getMultiSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/search/multi?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -212,7 +402,7 @@ def getMultiSearch(self, params: str): logger.error(f"An error occurred while searching multi: {str(e)}") return None - def getTVSearch(self, params: str): + def getTVSearch(self, params: str) -> Optional[TmdbPagedResults[TmdbItem]]: url = f"{self.API_URL}/search/tv?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -225,7 +415,7 @@ def getTVSearch(self, params: str): logger.error(f"An error occurred while searching TV shows: {str(e)}") return None - def getTVSeasonDetails(self, params: str, series_id: int, season_number: int): + def getTVSeasonDetails(self, params: str, series_id: int, season_number: int) -> Optional[TmdbSeasonDetails]: url = f"{self.API_URL}/tv/{series_id}/season/{season_number}?{params}" try: response = get(url, additional_headers=self.HEADERS) @@ -240,7 +430,7 @@ def getTVSeasonDetails(self, params: str, series_id: int, season_number: int): def getTVSeasonEpisodeDetails( self, params: str, series_id: int, season_number: int, episode_number: int - ): + ) -> Optional[TmdbEpisodeDetails]: url = f"{self.API_URL}/tv/{series_id}/season/{season_number}/episode/{episode_number}?{params}" try: response = get(url, additional_headers=self.HEADERS) diff --git a/src/program/media/item.py b/src/program/media/item.py index 82dd42ea..96766a2e 100644 --- a/src/program/media/item.py +++ 
b/src/program/media/item.py @@ -222,30 +222,31 @@ def to_dict(self): "scraped_times": self.scraped_times, } - def to_extended_dict(self, abbreviated_children=False): + def to_extended_dict(self, abbreviated_children=False, with_streams=True): """Convert item to extended dictionary (API response)""" dict = self.to_dict() match self: case Show(): dict["seasons"] = ( - [season.to_extended_dict() for season in self.seasons] + [season.to_extended_dict(with_streams=with_streams) for season in self.seasons] if not abbreviated_children else self.represent_children ) case Season(): dict["episodes"] = ( - [episode.to_extended_dict() for episode in self.episodes] + [episode.to_extended_dict(with_streams=with_streams) for episode in self.episodes] if not abbreviated_children else self.represent_children ) dict["language"] = self.language if hasattr(self, "language") else None dict["country"] = self.country if hasattr(self, "country") else None dict["network"] = self.network if hasattr(self, "network") else None - dict["active_stream"] = ( - self.active_stream if hasattr(self, "active_stream") else None - ) - dict["streams"] = getattr(self, "streams", []) - dict["blacklisted_streams"] = getattr(self, "blacklisted_streams", []) + if with_streams: + dict["streams"] = getattr(self, "streams", []) + dict["blacklisted_streams"] = getattr(self, "blacklisted_streams", []) + dict["active_stream"] = ( + self.active_stream if hasattr(self, "active_stream") else None + ) dict["number"] = self.number if hasattr(self, "number") else None dict["symlinked"] = self.symlinked if hasattr(self, "symlinked") else None dict["symlinked_at"] = ( diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py index 17e175b2..aaf4ed04 100644 --- a/src/utils/event_manager.py +++ b/src/utils/event_manager.py @@ -6,6 +6,7 @@ from threading import Lock from loguru import logger +from pydantic import BaseModel from sqlalchemy.orm.exc import StaleDataError from subliminal import Episode, Movie @@ -15,6 
+16,14 @@ from program.media.item import Season, Show from program.types import Event +class EventUpdate(BaseModel): + item_id: int + imdb_id: str + title: str + type: str + emitted_by: str + run_at: str + last_state: str class EventManager: """ @@ -324,7 +333,7 @@ def add_item(self, item, service="Manual"): """ self.add_event(Event(service, item)) - def get_event_updates(self): + def get_event_updates(self) -> dict[str, list[EventUpdate]]: """ Returns a formatted list of event updates. @@ -335,6 +344,7 @@ def get_event_updates(self): event_types = ["Scraping", "Downloader", "Symlinker", "Updater", "PostProcessing"] return { event_type.lower(): [ + EventUpdate.model_validate( { "item_id": event.item._id, "imdb_id": event.item.imdb_id, @@ -343,7 +353,7 @@ def get_event_updates(self): "emitted_by": event.emitted_by if isinstance(event.emitted_by, str) else event.emitted_by.__name__, "run_at": event.run_at.isoformat(), "last_state": event.item.last_state.name if event.item.last_state else "N/A" - } + }) for event in events if event.emitted_by == event_type ] for event_type in event_types