Skip to content

Commit

Permalink
fix: add hard reset to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
dreulavelle committed Jul 27, 2024
1 parent 16c1ceb commit e3366a6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 46 deletions.
10 changes: 10 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from program import Program
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from program.db.db_functions import hard_reset_database
from utils.logger import logger


Expand All @@ -46,9 +47,18 @@ async def dispatch(self, request: Request, call_next):
action="store_true",
help="Ignore the cached metadata, create new data from scratch.",
)
parser.add_argument(
"--hard_reset_db",
action="store_true",
help="Hard reset the database, including deleting the Alembic directory.",
)

args = parser.parse_args()

if args.hard_reset_db:
hard_reset_database()
exit(0)

app = FastAPI(
title="Riven",
summary="A media management system.",
Expand Down
85 changes: 39 additions & 46 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import shutil

import alembic

from program.media.item import Episode, MediaItem, Movie, Season, Show
from program.media.stream import Stream
from program.types import Event
from sqlalchemy import func, select
from sqlalchemy import func, select, text
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound, IntegrityError, InvalidRequestError
from utils.logger import logger
from utils import alembic_dir

from .db import db

Expand All @@ -19,39 +23,17 @@ def _ensure_item_exists_in_db(item: MediaItem) -> bool:

def _get_item_type_from_db(item: MediaItem) -> str:
with db.Session() as session:
try:
if item._id is None:
return session.execute(select(MediaItem.type).where((MediaItem.imdb_id == item.imdb_id) & ((MediaItem.type == "show") | (MediaItem.type == "movie")))).scalar_one()
return session.execute(select(MediaItem.type).where(MediaItem._id == item._id)).scalar_one()
except NoResultFound as e:
logger.error(f"No Result Found in db for {item.log_string} with id {item._id}: {e}")
except Exception as e:
logger.exception(f"Failed to get item type from db for item: {item.log_string} with id {item._id} - {e}")
if item._id is None:
return session.execute(select(MediaItem.type).where( (MediaItem.imdb_id==item.imdb_id ) & ( (MediaItem.type == "show") | (MediaItem.type == "movie") ) )).scalar_one()
return session.execute(select(MediaItem.type).where(MediaItem._id==item._id)).scalar_one()

def _store_item(item: MediaItem):
if isinstance(item, (Movie, Show, Season, Episode)) and item._id is not None:
with db.Session() as session:
try:
session.merge(item)
session.commit()
except IntegrityError as e:
logger.error(f"IntegrityError: {e}. Attempting to update existing item.")
logger.warning(f"Attempting rollback of session for item: {item.log_string}")
session.rollback()
existing_item = session.query(MediaItem).filter_by(_id=item._id).one()
for key, value in item.__dict__.items():
if key != '_sa_instance_state':
if getattr(existing_item, key) != value:
setattr(existing_item, key, value)
logger.warning(f"Committing changes to existing item: {item.log_string}")
session.commit()
except InvalidRequestError as e:
logger.error(f"InvalidRequestError: {e}. Could not update existing item.")
session.rollback()
except Exception as e:
logger.exception(f"Failed to update existing item: {item.log_string} - {e}")
else:
with db.Session() as session:
session.merge(item)
session.commit()
else:
with db.Session() as session:
_check_for_and_run_insertion_required(session, item)

def _get_item_from_db(session, item: MediaItem):
Expand Down Expand Up @@ -97,7 +79,7 @@ def _run_thread_with_db_item(fn, service, program, input_item: MediaItem | None)
if not _check_for_and_run_insertion_required(session, item):
pass
item = _get_item_from_db(session, item)

# session.merge(item)
for res in fn(item):
if isinstance(res, list):
Expand All @@ -117,7 +99,7 @@ def _run_thread_with_db_item(fn, service, program, input_item: MediaItem | None)
program._remove_from_running_items(item, service.__name__)
if res is not None and isinstance(res, MediaItem):
program._push_event_queue(Event(emitted_by=service, item=res))
# self._check_for_and_run_insertion_required(item)
# self._check_for_and_run_insertion_required(item)

item.store_state()
session.commit()
Expand All @@ -141,19 +123,30 @@ def _run_thread_with_db_item(fn, service, program, input_item: MediaItem | None)
program._push_event_queue(Event(emitted_by=service, item=i))
return

def hard_reset_database():
logger.debug("Resetting Database")

# Drop all tables
db.Model.metadata.drop_all(db.engine)
logger.debug("All MediaItem tables dropped")

# Drop the alembic_version table
with db.engine.connect() as connection:
connection.execute(text("DROP TABLE IF EXISTS alembic_version"))
logger.debug("Alembic table dropped")

# Recreate all tables
db.Model.metadata.create_all(db.engine)
logger.debug("All tables recreated")

# Reinitialize Alembic
logger.debug("Removing Alembic Directory")
shutil.rmtree(alembic_dir, ignore_errors=True)
os.makedirs(alembic_dir, exist_ok=True)

logger.debug("Hard Reset Complete")

reset = os.getenv("HARD_RESET", None)
if reset is not None and reset.lower() in ["true","1"]:
def run_delete(_type):
with db.Session() as session:
all = session.execute(select(_type).options(joinedload("*"))).unique().scalars().all()
for i in all:
session.delete(i)
session.commit()
run_delete(Episode)
run_delete(Season)
run_delete(Show)
run_delete(Movie)
run_delete(MediaItem)

logger.log("PROGRAM", "Database reset. Turning off HARD_RESET Env Var.")
os.environ.pop('HARD_RESET', None)
hard_reset_database()
del os.environ["HARD_RESET"]
1 change: 1 addition & 0 deletions src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
root_dir = Path(__file__).resolve().parents[2]

data_dir_path = root_dir / "data"
alembic_dir = data_dir_path / "alembic"

0 comments on commit e3366a6

Please sign in to comment.