Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Types for the FastAPI API and API refactor #748

Merged
merged 20 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
09d6962
feat: add response models to items.py
filiptrplan Sep 30, 2024
d17be82
feat: ignore ruff linter raising exceptions in except block
filiptrplan Sep 30, 2024
cbdc662
refactor: ruff format all the api files
filiptrplan Sep 30, 2024
4991636
feat: add response models to scrape.py
filiptrplan Sep 30, 2024
382a2bf
feat: add response models to settings.py
filiptrplan Sep 30, 2024
81fdd23
feat: add response models to tmdb.py
filiptrplan Sep 30, 2024
04bd9a9
feat: add missing type annotations to tmdb.py
filiptrplan Sep 30, 2024
e027f64
fix: add default values for some pydantic models
filiptrplan Sep 30, 2024
cbf6f29
feat: add types to default.py
filiptrplan Oct 2, 2024
84e9206
fix: bad pydantic types causing serialization error
filiptrplan Oct 3, 2024
fb9b21b
fix: add some model validation where needed
filiptrplan Oct 3, 2024
9384c42
feat: add mypy to dev dependencies for static type checking
filiptrplan Oct 3, 2024
837ed19
fix: wrong type in realdebrid
filiptrplan Oct 3, 2024
8a51e6f
feat: add some options for easier querying of items
filiptrplan Oct 3, 2024
aba502d
fix: pass with_streams argument in to_extended_dict to chidren
filiptrplan Oct 4, 2024
87109f0
feat: remove the old json response format from services and stats end…
filiptrplan Oct 4, 2024
bbaa73d
feat: migrate the settings api to the new response types
filiptrplan Oct 4, 2024
4765dd7
feat: add type annotation to get_all_settings
filiptrplan Oct 4, 2024
96f5c67
feat: migrate the rest of the APIs to the new response schema
filiptrplan Oct 4, 2024
f84af7f
fix: remove old imports
filiptrplan Oct 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ codecov = "^2.1.13"
httpx = "^0.27.0"
memray = "^1.13.4"
testcontainers = "^4.8.0"
mypy = "^1.11.2"

[tool.poetry.group.test]
optional = true
Expand Down Expand Up @@ -91,7 +92,8 @@ ignore = [
"S101", # ruff: Ignore assert warnings on tests
"RET505", #
"RET503", # ruff: Ignore required explicit returns (is this desired?)
"SLF001" # private member accessing from pickle
"SLF001", # private member accessing from pickle
"B904" # ruff: ignore raising exceptions from except for the API
]
extend-select = [
"I", # isort
Expand Down
140 changes: 98 additions & 42 deletions src/controllers/default.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,75 @@
from typing import Literal

import requests
from controllers.models.shared import MessageResponse
from fastapi import APIRouter, HTTPException, Request
from loguru import logger
from sqlalchemy import func, select

from program.content.trakt import TraktContent
from program.db.db import db
from program.media.item import Episode, MediaItem, Movie, Season, Show
from program.media.state import States
from program.settings.manager import settings_manager
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from utils.event_manager import EventUpdate

router = APIRouter(
responses={404: {"description": "Not found"}},
)


class RootResponse(MessageResponse):
version: str


@router.get("/", operation_id="root")
async def root():
async def root() -> RootResponse:
return {
"success": True,
"message": "Riven is running!",
"version": settings_manager.settings.version,
}


@router.get("/health", operation_id="health")
async def health(request: Request):
async def health(request: Request) -> MessageResponse:
return {
"success": True,
"message": request.app.program.initialized,
}


class RDUser(BaseModel):
id: int
username: str
email: str
points: int = Field(description="User's RD points")
locale: str
avatar: str = Field(description="URL to the user's avatar")
type: Literal["free", "premium"]
premium: int = Field(description="Premium subscription left in seconds")


@router.get("/rd", operation_id="rd")
async def get_rd_user():
async def get_rd_user() -> RDUser:
api_key = settings_manager.settings.downloaders.real_debrid.api_key
headers = {"Authorization": f"Bearer {api_key}"}

proxy = settings_manager.settings.downloaders.real_debrid.proxy_url if settings_manager.settings.downloaders.real_debrid.proxy_enabled else None
proxy = (
settings_manager.settings.downloaders.real_debrid.proxy_url
if settings_manager.settings.downloaders.real_debrid.proxy_enabled
else None
)

response = requests.get(
"https://api.real-debrid.com/rest/1.0/user",
headers=headers,
proxies=proxy if proxy else None,
timeout=10
timeout=10,
)

if response.status_code != 200:
return {"success": False, "message": response.json()}

return {
"success": True,
"data": response.json(),
}
return response.json()


@router.get("/torbox", operation_id="torbox")
Expand All @@ -65,7 +83,7 @@ async def get_torbox_user():


@router.get("/services", operation_id="services")
async def get_services(request: Request):
async def get_services(request: Request) -> dict[str, bool]:
data = {}
if hasattr(request.app.program, "services"):
for service in request.app.program.all_services.values():
Expand All @@ -74,11 +92,15 @@ async def get_services(request: Request):
continue
for sub_service in service.services.values():
data[sub_service.key] = sub_service.initialized
return {"success": True, "data": data}
return data


class TraktOAuthInitiateResponse(BaseModel):
auth_url: str


@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate")
async def initiate_trakt_oauth(request: Request):
async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse:
trakt = request.app.program.services.get(TraktContent)
if trakt is None:
raise HTTPException(status_code=404, detail="Trakt service not found")
Expand All @@ -87,24 +109,41 @@ async def initiate_trakt_oauth(request: Request):


@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback")
async def trakt_oauth_callback(code: str, request: Request):
async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse:
trakt = request.app.program.services.get(TraktContent)
if trakt is None:
raise HTTPException(status_code=404, detail="Trakt service not found")
success = trakt.handle_oauth_callback(code)
if success:
return {"success": True, "message": "OAuth token obtained successfully"}
return {"message": "OAuth token obtained successfully"}
else:
raise HTTPException(status_code=400, detail="Failed to obtain OAuth token")


class StatsResponse(BaseModel):
total_items: int
total_movies: int
total_shows: int
total_seasons: int
total_episodes: int
total_symlinks: int
incomplete_items: int
incomplete_retries: dict[str, int] = Field(
description="Media item log string: number of retries"
)
states: dict[States, int]


@router.get("/stats", operation_id="stats")
async def get_stats(_: Request):
async def get_stats(_: Request) -> StatsResponse:
payload = {}
with db.Session() as session:

movies_symlinks = session.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one()
episodes_symlinks = session.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one()
movies_symlinks = session.execute(
select(func.count(Movie._id)).where(Movie.symlinked == True)
).scalar_one()
episodes_symlinks = session.execute(
select(func.count(Episode._id)).where(Episode.symlinked == True)
).scalar_one()
total_symlinks = movies_symlinks + episodes_symlinks

total_movies = session.execute(select(func.count(Movie._id))).scalar_one()
Expand All @@ -113,21 +152,30 @@ async def get_stats(_: Request):
total_episodes = session.execute(select(func.count(Episode._id))).scalar_one()
total_items = session.execute(select(func.count(MediaItem._id))).scalar_one()

# Select only the IDs of incomplete items
_incomplete_items = session.execute(
select(MediaItem._id)
.where(MediaItem.last_state != States.Completed)
).scalars().all()

# Select only the IDs of incomplete items
_incomplete_items = (
session.execute(
select(MediaItem._id).where(MediaItem.last_state != States.Completed)
)
.scalars()
.all()
)

incomplete_retries = {}
if _incomplete_items:
media_items = session.query(MediaItem).filter(MediaItem._id.in_(_incomplete_items)).all()
media_items = (
session.query(MediaItem)
.filter(MediaItem._id.in_(_incomplete_items))
.all()
)
for media_item in media_items:
incomplete_retries[media_item.log_string] = media_item.scraped_times

states = {}
for state in States:
states[state] = session.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one()
states[state] = session.execute(
select(func.count(MediaItem._id)).where(MediaItem.last_state == state)
).scalar_one()

payload["total_items"] = total_items
payload["total_movies"] = total_movies
Expand All @@ -138,11 +186,15 @@ async def get_stats(_: Request):
payload["incomplete_items"] = len(_incomplete_items)
payload["incomplete_retries"] = incomplete_retries
payload["states"] = states
return payload


class LogsResponse(BaseModel):
logs: str

return {"success": True, "data": payload}

@router.get("/logs", operation_id="logs")
async def get_logs():
async def get_logs() -> str:
log_file_path = None
for handler in logger._core.handlers.values():
if ".log" in handler._name:
Expand All @@ -153,24 +205,29 @@ async def get_logs():
return {"success": False, "message": "Log file handler not found"}

try:
with open(log_file_path, 'r') as log_file:
with open(log_file_path, "r") as log_file:
log_contents = log_file.read()
return {"success": True, "logs": log_contents}
return {"logs": log_contents}
except Exception as e:
logger.error(f"Failed to read log file: {e}")
return {"success": False, "message": "Failed to read log file"}

raise HTTPException(status_code=500, detail="Failed to read log file")


@router.get("/events", operation_id="events")
async def get_events(request: Request):
return {"success": True, "data": request.app.program.em.get_event_updates()}
async def get_events(
request: Request,
) -> dict[str, list[EventUpdate]]:
return request.app.program.em.get_event_updates()


@router.get("/mount", operation_id="mount")
async def get_rclone_files():
async def get_rclone_files() -> dict[str, str]:
"""Get all files in the rclone mount."""
import os

rclone_dir = settings_manager.settings.symlink.rclone_path
file_map = {}

def scan_dir(path):
with os.scandir(path) as entries:
for entry in entries:
Expand All @@ -179,6 +236,5 @@ def scan_dir(path):
elif entry.is_dir():
scan_dir(entry.path)

scan_dir(rclone_dir) # dict of `filename: filepath``
return {"success": True, "data": file_map}

scan_dir(rclone_dir) # dict of `filename: filepath``
return file_map
Loading
Loading