Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Ensure Celery leverages the Flask-SQLAlchemy session #26186

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
688 changes: 688 additions & 0 deletions 1

Large diffs are not rendered by default.

70 changes: 30 additions & 40 deletions superset/commands/report/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@

import pandas as pd
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session

from superset import app, security_manager
from superset import app, db, security_manager
from superset.commands.base import BaseCommand
from superset.commands.dashboard.permalink.create import CreateDashboardPermalinkCommand
from superset.commands.exceptions import CommandException
Expand Down Expand Up @@ -68,7 +67,6 @@
from superset.reports.notifications.base import NotificationContent
from superset.reports.notifications.exceptions import NotificationError
from superset.tasks.utils import get_executor
from superset.utils.celery import session_scope
from superset.utils.core import HeaderDataType, override_user
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.decorators import logs_context
Expand All @@ -85,12 +83,10 @@
@logs_context()
def __init__(
self,
session: Session,
report_schedule: ReportSchedule,
scheduled_dttm: datetime,
execution_id: UUID,
) -> None:
self._session = session
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
self._start_dttm = datetime.utcnow()
Expand Down Expand Up @@ -123,7 +119,7 @@

self._report_schedule.last_state = state
self._report_schedule.last_eval_dttm = datetime.utcnow()
self._session.commit()
db.session.commit()

def create_log(self, error_message: Optional[str] = None) -> None:
"""
Expand All @@ -140,8 +136,8 @@
report_schedule=self._report_schedule,
uuid=self._execution_id,
)
self._session.add(log)
self._session.commit()
db.session.add(log)
db.session.commit()

def _get_url(
self,
Expand Down Expand Up @@ -485,9 +481,7 @@
"""
Checks if an alert is in its grace period
"""
last_success = ReportScheduleDAO.find_last_success_log(
self._report_schedule, session=self._session
)
last_success = ReportScheduleDAO.find_last_success_log(self._report_schedule)
return (
last_success is not None
and self._report_schedule.grace_period
Expand All @@ -501,7 +495,7 @@
Checks if an alert/report on error is in its notification grace period
"""
last_success = ReportScheduleDAO.find_last_error_notification(
self._report_schedule, session=self._session
self._report_schedule
)
if not last_success:
return False
Expand All @@ -518,7 +512,7 @@
Checks if an alert is in a working timeout
"""
last_working = ReportScheduleDAO.find_last_entered_working_log(
self._report_schedule, session=self._session
self._report_schedule
)
if not last_working:
return False
Expand Down Expand Up @@ -668,12 +662,10 @@

def __init__(
self,
session: Session,
task_uuid: UUID,
report_schedule: ReportSchedule,
scheduled_dttm: datetime,
):
self._session = session
self._execution_id = task_uuid
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
Expand All @@ -684,7 +676,6 @@
self._report_schedule.last_state in state_cls.current_states
):
state_cls(
self._session,
self._report_schedule,
self._scheduled_dttm,
self._execution_id,
Expand All @@ -708,39 +699,38 @@
self._execution_id = UUID(task_id)

def run(self) -> None:
with session_scope(nullpool=True) as session:
try:
self.validate(session=session)
if not self._model:
raise ReportScheduleExecuteUnexpectedError()
_, username = get_executor(
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._model,
try:
self.validate()
if not self._model:
raise ReportScheduleExecuteUnexpectedError()

Check warning on line 705 in superset/commands/report/execute.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/report/execute.py#L705

Added line #L705 was not covered by tests
_, username = get_executor(
executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"],
model=self._model,
)
user = security_manager.find_user(username)
with override_user(user):
logger.info(
"Running report schedule %s as user %s",
self._execution_id,
username,
)
user = security_manager.find_user(username)
with override_user(user):
logger.info(
"Running report schedule %s as user %s",
self._execution_id,
username,
)
ReportScheduleStateMachine(
session, self._execution_id, self._model, self._scheduled_dttm
).run()
except CommandException as ex:
raise ex
except Exception as ex:
raise ReportScheduleUnexpectedError(str(ex)) from ex
ReportScheduleStateMachine(
self._execution_id, self._model, self._scheduled_dttm
).run()
except CommandException as ex:
raise ex
except Exception as ex:
raise ReportScheduleUnexpectedError(str(ex)) from ex

Check warning on line 723 in superset/commands/report/execute.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/report/execute.py#L722-L723

Added lines #L722 - L723 were not covered by tests

def validate(self, session: Session = None) -> None:
def validate(self) -> None:
# Validate/populate model exists
logger.info(
"session is validated: id %s, executionid: %s",
self._model_id,
self._execution_id,
)
self._model = (
session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
db.session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none()
)
if not self._model:
raise ReportScheduleNotFoundError()
43 changes: 21 additions & 22 deletions superset/commands/report/log_prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import logging
from datetime import datetime, timedelta

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.report.exceptions import ReportSchedulePruneLogError
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ReportSchedule
from superset.utils.celery import session_scope

logger = logging.getLogger(__name__)

Expand All @@ -36,28 +36,27 @@
self._worker_context = worker_context

def run(self) -> None:
with session_scope(nullpool=True) as session:
self.validate()
prune_errors = []

for report_schedule in session.query(ReportSchedule).all():
if report_schedule.log_retention is not None:
from_date = datetime.utcnow() - timedelta(
days=report_schedule.log_retention
self.validate()
prune_errors = []

for report_schedule in db.session.query(ReportSchedule).all():
if report_schedule.log_retention is not None:
from_date = datetime.utcnow() - timedelta(
days=report_schedule.log_retention
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, commit=False
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, session=session, commit=False
)
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))
logger.info(

Check warning on line 51 in superset/commands/report/log_prune.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/report/log_prune.py#L51

Added line #L51 was not covered by tests
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))

Check warning on line 59 in superset/commands/report/log_prune.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/report/log_prune.py#L57-L59

Added lines #L57 - L59 were not covered by tests

def validate(self) -> None:
pass
33 changes: 12 additions & 21 deletions superset/daos/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Any

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
Expand Down Expand Up @@ -204,27 +203,25 @@
return super().update(item, attributes, commit)

@staticmethod
def find_active(session: Session | None = None) -> list[ReportSchedule]:
def find_active() -> list[ReportSchedule]:
"""
Find all active reports. If session is passed it will be used instead of the
default `db.session`, this is useful when on a celery worker session context
Find all active reports.
"""
session = session or db.session
return (
session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
db.session.query(ReportSchedule)
.filter(ReportSchedule.active.is_(True))
.all()
)

@staticmethod
def find_last_success_log(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last success execution log for a given report
"""
session = session or db.session
return (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state == ReportState.SUCCESS,
ReportExecutionLog.report_schedule == report_schedule,
Expand All @@ -236,14 +233,12 @@
@staticmethod
def find_last_entered_working_log(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds the last execution log entry where the report entered the working state
"""
session = session or db.session
return (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state == ReportState.WORKING,
ReportExecutionLog.report_schedule == report_schedule,
Expand All @@ -256,14 +251,12 @@
@staticmethod
def find_last_error_notification(
report_schedule: ReportSchedule,
session: Session | None = None,
) -> ReportExecutionLog | None:
"""
Finds last error email sent
"""
session = session or db.session
last_error_email_log = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.error_message
== REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER,
Expand All @@ -276,7 +269,7 @@
return None
# Checks that only errors have occurred since the last email
report_from_last_email = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.state.notin_(
[ReportState.ERROR, ReportState.WORKING]
Expand All @@ -293,22 +286,20 @@
def bulk_delete_logs(
model: ReportSchedule,
from_date: datetime,
session: Session | None = None,
commit: bool = True,
) -> int | None:
session = session or db.session
try:
row_count = (
session.query(ReportExecutionLog)
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.report_schedule == model,
ReportExecutionLog.end_dttm < from_date,
)
.delete(synchronize_session="fetch")
)
if commit:
session.commit()
db.session.commit()

Check warning on line 301 in superset/daos/report.py

View check run for this annotation

Codecov / codecov/patch

superset/daos/report.py#L301

Added line #L301 was not covered by tests
return row_count
except SQLAlchemyError as ex:
session.rollback()
db.session.rollback()

Check warning on line 304 in superset/daos/report.py

View check run for this annotation

Codecov / codecov/patch

superset/daos/report.py#L304

Added line #L304 was not covered by tests
raise DAODeleteFailedError(str(ex)) from ex
11 changes: 4 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
Expand Down Expand Up @@ -1071,7 +1070,7 @@ def convert_dttm( # pylint: disable=unused-argument
return None

@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Any, query: Query) -> None:
"""Handle a live cursor between the execute and fetchall calls

The flow works without this method doing anything, but it allows
Expand All @@ -1080,9 +1079,7 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
# TODO: Fix circular import error caused by importing sql_lab.Query

@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None:
"""
Trigger execution of a query and handle the resulting cursor.

Expand All @@ -1095,7 +1092,7 @@ def execute_with_cursor(
logger.debug("Query %d: Running query: %s", query.id, sql)
cls.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)
cls.handle_cursor(cursor, query)

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
Expand Down Expand Up @@ -1841,7 +1838,7 @@ def get_sqla_column_type(

# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
def prepare_cancel_query(cls, query: Query) -> None:
"""
Some databases may acquire the query cancelation id after the query
cancelation request has been received. For those cases, the db engine spec
Expand Down
Loading
Loading