diff --git a/src/main.py b/src/main.py index 77a3a1f7..6625098e 100644 --- a/src/main.py +++ b/src/main.py @@ -20,6 +20,7 @@ from program import Program from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request +from program.db.db_functions import hard_reset_database from utils.logger import logger @@ -46,9 +47,18 @@ async def dispatch(self, request: Request, call_next): action="store_true", help="Ignore the cached metadata, create new data from scratch.", ) +parser.add_argument( + "--hard_reset_db", + action="store_true", + help="Hard reset the database, including deleting the Alembic directory.", +) args = parser.parse_args() +if args.hard_reset_db: + hard_reset_database() + exit(0) + app = FastAPI( title="Riven", summary="A media management system.", diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py index 6e28d20c..b5ddf33f 100644 --- a/src/program/db/db_functions.py +++ b/src/program/db/db_functions.py @@ -1,12 +1,16 @@ import os +import shutil + +import alembic from program.media.item import Episode, MediaItem, Movie, Season, Show from program.media.stream import Stream from program.types import Event -from sqlalchemy import func, select +from sqlalchemy import func, select, text from sqlalchemy.orm import joinedload from sqlalchemy.exc import NoResultFound, IntegrityError, InvalidRequestError from utils.logger import logger +from utils import alembic_dir from .db import db @@ -19,39 +23,17 @@ def _ensure_item_exists_in_db(item: MediaItem) -> bool: def _get_item_type_from_db(item: MediaItem) -> str: with db.Session() as session: - try: - if item._id is None: - return session.execute(select(MediaItem.type).where((MediaItem.imdb_id == item.imdb_id) & ((MediaItem.type == "show") | (MediaItem.type == "movie")))).scalar_one() - return session.execute(select(MediaItem.type).where(MediaItem._id == item._id)).scalar_one() - except NoResultFound as e: - logger.error(f"No Result Found in db for {item.log_string} with id {item._id}: {e}") - except Exception as e: - logger.exception(f"Failed to get item type from db for item: {item.log_string} with id {item._id} - {e}") + if item._id is None: + return session.execute(select(MediaItem.type).where( (MediaItem.imdb_id==item.imdb_id ) & ( (MediaItem.type == "show") | (MediaItem.type == "movie") ) )).scalar_one() + return session.execute(select(MediaItem.type).where(MediaItem._id==item._id)).scalar_one() def _store_item(item: MediaItem): if isinstance(item, (Movie, Show, Season, Episode)) and item._id is not None: with db.Session() as session: - try: - session.merge(item) - session.commit() - except IntegrityError as e: - logger.error(f"IntegrityError: {e}. Attempting to update existing item.") - logger.warning(f"Attempting rollback of session for item: {item.log_string}") - session.rollback() - existing_item = session.query(MediaItem).filter_by(_id=item._id).one() - for key, value in item.__dict__.items(): - if key != '_sa_instance_state': - if getattr(existing_item, key) != value: - setattr(existing_item, key, value) - logger.warning(f"Committing changes to existing item: {item.log_string}") - session.commit() - except InvalidRequestError as e: - logger.error(f"InvalidRequestError: {e}. Could not update existing item.") - session.rollback() - except Exception as e: - logger.exception(f"Failed to update existing item: {item.log_string} - {e}") - else: - with db.Session() as session: + session.merge(item) + session.commit() + else: + with db.Session() as session: _check_for_and_run_insertion_required(session, item) def _get_item_from_db(session, item: MediaItem): @@ -97,7 +79,7 @@ def _run_thread_with_db_item(fn, service, program, input_item: MediaItem | None) if not _check_for_and_run_insertion_required(session, item): pass item = _get_item_from_db(session, item) - + # session.merge(item) for res in fn(item): if isinstance(res, list): @@ -117,7 +99,7 @@ def _run_thread_with_db_item(fn, service, program, input_item: MediaItem | None) program._remove_from_running_items(item, service.__name__) if res is not None and isinstance(res, MediaItem): program._push_event_queue(Event(emitted_by=service, item=res)) - # self._check_for_and_run_insertion_required(item) + # self._check_for_and_run_insertion_required(item) item.store_state() session.commit() @@ -141,19 +123,30 @@ def _run_thread_with_db_item(fn, service, program, input_item: MediaItem | None) program._push_event_queue(Event(emitted_by=service, item=i)) return +def hard_reset_database(): + logger.debug("Resetting Database") + + # Drop all tables + db.Model.metadata.drop_all(db.engine) + logger.debug("All MediaItem tables dropped") + + # Drop the alembic_version table + with db.engine.connect() as connection: + connection.execute(text("DROP TABLE IF EXISTS alembic_version")) + logger.debug("Alembic table dropped") + + # Recreate all tables + db.Model.metadata.create_all(db.engine) + logger.debug("All tables recreated") + + # Reinitialize Alembic + logger.debug("Removing Alembic Directory") + shutil.rmtree(alembic_dir, ignore_errors=True) + os.makedirs(alembic_dir, exist_ok=True) + + logger.debug("Hard Reset Complete") + reset = os.getenv("HARD_RESET", None) if reset is not None and reset.lower() in ["true","1"]: - def run_delete(_type): - with db.Session() as session: - all = session.execute(select(_type).options(joinedload("*"))).unique().scalars().all() - for i in all: - session.delete(i) - session.commit() - run_delete(Episode) - run_delete(Season) - run_delete(Show) - run_delete(Movie) - run_delete(MediaItem) - - logger.log("PROGRAM", "Database reset. Turning off HARD_RESET Env Var.") - os.environ.pop('HARD_RESET', None) \ No newline at end of file + hard_reset_database() + del os.environ["HARD_RESET"] \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 5c5ef6d6..e2f85f0e 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -3,3 +3,4 @@ root_dir = Path(__file__).resolve().parents[2] data_dir_path = root_dir / "data" +alembic_dir = data_dir_path / "alembic" \ No newline at end of file