From ac2937a6c58b63c2723917a1cd111a8363a728a0 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 10 Sep 2020 13:29:57 -0700 Subject: [PATCH] fix: use nullpool in the celery workers (#10819) * Use nullpool in the celery workers * Address comments Co-authored-by: bogdan kyryliuk --- superset/cli.py | 12 +- superset/sql_lab.py | 47 +----- superset/tasks/alerts/observer.py | 12 +- superset/tasks/schedules.py | 247 ++++++++++++++++-------------- superset/utils/celery.py | 57 +++++++ tests/alerts_tests.py | 48 +++--- tests/schedules_test.py | 4 + 7 files changed, 234 insertions(+), 193 deletions(-) create mode 100644 superset/utils/celery.py diff --git a/superset/cli.py b/superset/cli.py index ef176822d9a91..f0f7f1e1e54bd 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -34,6 +34,7 @@ from superset.app import create_app from superset.extensions import celery_app, db from superset.utils import core as utils +from superset.utils.celery import session_scope from superset.utils.urls import get_url_path logger = logging.getLogger(__name__) @@ -619,6 +620,11 @@ def alert() -> None: from superset.tasks.schedules import schedule_window click.secho("Processing one alert loop", fg="green") - schedule_window( - ScheduleType.alert, datetime.now() - timedelta(1000), datetime.now(), 6000 - ) + with session_scope(nullpool=True) as session: + schedule_window( + report_type=ScheduleType.alert, + start_at=datetime.now() - timedelta(1000), + stop_at=datetime.now(), + resolution=6000, + session=session, + ) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 09be5e86fa0de..796ddba87bd6f 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -19,33 +19,25 @@ from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple, Union import backoff import msgpack import pyarrow as pa import simplejson as json -import 
sqlalchemy from celery.exceptions import SoftTimeLimitExceeded from celery.task.base import Task -from contextlib2 import contextmanager from flask_babel import lazy_gettext as _ -from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy.pool import NullPool - -from superset import ( - app, - db, - results_backend, - results_backend_use_msgpack, - security_manager, -) +from sqlalchemy.orm import Session + +from superset import app, results_backend, results_backend_use_msgpack, security_manager from superset.dataframe import df_to_records from superset.db_engine_specs import BaseEngineSpec from superset.extensions import celery_app from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet from superset.sql_parse import ParsedQuery +from superset.utils.celery import session_scope from superset.utils.core import ( json_iso_dttm_ser, QuerySource, @@ -121,35 +113,6 @@ def get_query(query_id: int, session: Session) -> Query: raise SqlLabException("Failed at getting query") -@contextmanager -def session_scope(nullpool: bool) -> Iterator[Session]: - """Provide a transactional scope around a series of operations.""" - database_uri = app.config["SQLALCHEMY_DATABASE_URI"] - if "sqlite" in database_uri: - logger.warning( - "SQLite Database support for metadata databases will be removed \ - in a future version of Superset." 
- ) - if nullpool: - engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool) - session_class = sessionmaker() - session_class.configure(bind=engine) - session = session_class() - else: - session = db.session() - session.commit() # HACK - - try: - yield session - session.commit() - except Exception as ex: - session.rollback() - logger.exception(ex) - raise - finally: - session.close() - - @celery_app.task( name="sql_lab.get_sql_results", bind=True, diff --git a/superset/tasks/alerts/observer.py b/superset/tasks/alerts/observer.py index f7c5373fd72da..34ff6689583dc 100644 --- a/superset/tasks/alerts/observer.py +++ b/superset/tasks/alerts/observer.py @@ -20,22 +20,24 @@ from typing import Optional import pandas as pd +from sqlalchemy.orm import Session -from superset import db from superset.models.alerts import Alert, SQLObservation from superset.sql_parse import ParsedQuery logger = logging.getLogger("tasks.email_reports") -def observe(alert_id: int) -> Optional[str]: +# Session needs to be passed along in the celery workers and db.session cannot be used. +# For more info see: https://github.com/apache/incubator-superset/issues/10530 +def observe(alert_id: int, session: Session) -> Optional[str]: """ Runs the SQL query in an alert's SQLObserver and then stores the result in a SQLObservation. 
Returns an error message if the observer value was not valid """ - alert = db.session.query(Alert).filter_by(id=alert_id).one() + alert = session.query(Alert).filter_by(id=alert_id).one() sql_observer = alert.sql_observer[0] value = None @@ -57,8 +59,8 @@ def observe(alert_id: int) -> Optional[str]: error_msg=error_msg, ) - db.session.add(observation) - db.session.commit() + session.add(observation) + session.commit() return error_msg diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 9643f094f6361..7c5ad4adbb25f 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -47,8 +47,9 @@ from selenium.webdriver import chrome, firefox from selenium.webdriver.remote.webdriver import WebDriver from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError +from sqlalchemy.orm import Session -from superset import app, db, security_manager, thumbnail_cache +from superset import app, security_manager, thumbnail_cache from superset.extensions import celery_app, machine_auth_provider_factory from superset.models.alerts import Alert, AlertLog from superset.models.dashboard import Dashboard @@ -62,6 +63,7 @@ from superset.tasks.alerts.observer import observe from superset.tasks.alerts.validator import get_validator_function from superset.tasks.slack_util import deliver_slack_msg +from superset.utils.celery import session_scope from superset.utils.core import get_email_address_list, send_email_smtp from superset.utils.screenshots import ChartScreenshot, WebDriverProxy from superset.utils.urls import get_url_path @@ -225,7 +227,7 @@ def destroy_webdriver( pass -def deliver_dashboard( +def deliver_dashboard( # pylint: disable=too-many-locals dashboard_id: int, recipients: Optional[str], slack_channel: Optional[str], @@ -236,69 +238,70 @@ def deliver_dashboard( """ Given a schedule, delivery the dashboard as an email report """ - dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one() + with 
session_scope(nullpool=True) as session: + dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one() - dashboard_url = _get_url_path( - "Superset.dashboard", dashboard_id_or_slug=dashboard.id - ) - dashboard_url_user_friendly = _get_url_path( - "Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id - ) - - # Create a driver, fetch the page, wait for the page to render - driver = create_webdriver() - window = config["WEBDRIVER_WINDOW"]["dashboard"] - driver.set_window_size(*window) - driver.get(dashboard_url) - time.sleep(EMAIL_PAGE_RENDER_WAIT) - - # Set up a function to retry once for the element. - # This is buggy in certain selenium versions with firefox driver - get_element = getattr(driver, "find_element_by_class_name") - element = retry_call( - get_element, fargs=["grid-container"], tries=2, delay=EMAIL_PAGE_RENDER_WAIT - ) - - try: - screenshot = element.screenshot_as_png - except WebDriverException: - # Some webdrivers do not support screenshots for elements. - # In such cases, take a screenshot of the entire page. 
- screenshot = driver.screenshot() # pylint: disable=no-member - finally: - destroy_webdriver(driver) - - # Generate the email body and attachments - report_content = _generate_report_content( - delivery_type, - screenshot, - dashboard.dashboard_title, - dashboard_url_user_friendly, - ) + dashboard_url = _get_url_path( + "Superset.dashboard", dashboard_id_or_slug=dashboard.id + ) + dashboard_url_user_friendly = _get_url_path( + "Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id + ) - subject = __( - "%(prefix)s %(title)s", - prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"], - title=dashboard.dashboard_title, - ) + # Create a driver, fetch the page, wait for the page to render + driver = create_webdriver() + window = config["WEBDRIVER_WINDOW"]["dashboard"] + driver.set_window_size(*window) + driver.get(dashboard_url) + time.sleep(EMAIL_PAGE_RENDER_WAIT) + + # Set up a function to retry once for the element. + # This is buggy in certain selenium versions with firefox driver + get_element = getattr(driver, "find_element_by_class_name") + element = retry_call( + get_element, fargs=["grid-container"], tries=2, delay=EMAIL_PAGE_RENDER_WAIT + ) - if recipients: - _deliver_email( - recipients, - deliver_as_group, - subject, - report_content.body, - report_content.data, - report_content.images, + try: + screenshot = element.screenshot_as_png + except WebDriverException: + # Some webdrivers do not support screenshots for elements. + # In such cases, take a screenshot of the entire page. 
+ screenshot = driver.screenshot() # pylint: disable=no-member + finally: + destroy_webdriver(driver) + + # Generate the email body and attachments + report_content = _generate_report_content( + delivery_type, + screenshot, + dashboard.dashboard_title, + dashboard_url_user_friendly, ) - if slack_channel: - deliver_slack_msg( - slack_channel, - subject, - report_content.slack_message, - report_content.slack_attachment, + + subject = __( + "%(prefix)s %(title)s", + prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"], + title=dashboard.dashboard_title, ) + if recipients: + _deliver_email( + recipients, + deliver_as_group, + subject, + report_content.body, + report_content.data, + report_content.images, + ) + if slack_channel: + deliver_slack_msg( + slack_channel, + subject, + report_content.slack_message, + report_content.slack_attachment, + ) + def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportContent: slice_url = _get_url_path( @@ -362,8 +365,8 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte return ReportContent(body, data, None, slack_message, content) -def _get_slice_screenshot(slice_id: int) -> ScreenshotData: - slice_obj = db.session.query(Slice).get(slice_id) +def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData: + slice_obj = session.query(Slice).get(slice_id) chart_url = get_url_path("Superset.slice", slice_id=slice_obj.id, standalone="true") screenshot = ChartScreenshot(chart_url, slice_obj.digest) @@ -376,7 +379,7 @@ def _get_slice_screenshot(slice_id: int) -> ScreenshotData: user=user, cache=thumbnail_cache, force=True, ) - db.session.commit() + session.commit() return ScreenshotData(image_url, image_data) @@ -427,11 +430,12 @@ def deliver_slice( # pylint: disable=too-many-arguments delivery_type: EmailDeliveryType, email_format: SliceEmailReportFormat, deliver_as_group: bool, + session: Session, ) -> None: """ Given a schedule, delivery the slice as an email report """ - slc = 
db.session.query(Slice).filter_by(id=slice_id).one() + slc = session.query(Slice).filter_by(id=slice_id).one() if email_format == SliceEmailReportFormat.data: report_content = _get_slice_data(slc, delivery_type) @@ -477,38 +481,42 @@ def schedule_email_report( # pylint: disable=unused-argument slack_channel: Optional[str] = None, ) -> None: model_cls = get_scheduler_model(report_type) - schedule = db.create_scoped_session().query(model_cls).get(schedule_id) + with session_scope(nullpool=True) as session: + schedule = session.query(model_cls).get(schedule_id) - # The user may have disabled the schedule. If so, ignore this - if not schedule or not schedule.active: - logger.info("Ignoring deactivated schedule") - return - - recipients = recipients or schedule.recipients - slack_channel = slack_channel or schedule.slack_channel - logger.info( - "Starting report for slack: %s and recipients: %s.", slack_channel, recipients - ) + # The user may have disabled the schedule. If so, ignore this + if not schedule or not schedule.active: + logger.info("Ignoring deactivated schedule") + return - if report_type == ScheduleType.dashboard: - deliver_dashboard( - schedule.dashboard_id, - recipients, + recipients = recipients or schedule.recipients + slack_channel = slack_channel or schedule.slack_channel + logger.info( + "Starting report for slack: %s and recipients: %s.", slack_channel, - schedule.delivery_type, - schedule.deliver_as_group, - ) - elif report_type == ScheduleType.slice: - deliver_slice( - schedule.slice_id, recipients, - slack_channel, - schedule.delivery_type, - schedule.email_format, - schedule.deliver_as_group, ) - else: - raise RuntimeError("Unknown report type") + + if report_type == ScheduleType.dashboard: + deliver_dashboard( + schedule.dashboard_id, + recipients, + slack_channel, + schedule.delivery_type, + schedule.deliver_as_group, + ) + elif report_type == ScheduleType.slice: + deliver_slice( + schedule.slice_id, + recipients, + slack_channel, + 
schedule.delivery_type, + schedule.email_format, + schedule.deliver_as_group, + session, + ) + else: + raise RuntimeError("Unknown report type") @celery_app.task( @@ -529,9 +537,8 @@ def schedule_alert_query( # pylint: disable=unused-argument slack_channel: Optional[str] = None, ) -> None: model_cls = get_scheduler_model(report_type) - - try: - schedule = db.session.query(model_cls).get(schedule_id) + with session_scope(nullpool=True) as session: + schedule = session.query(model_cls).get(schedule_id) # The user may have disabled the schedule. If so, ignore this if not schedule or not schedule.active: @@ -539,15 +546,11 @@ def schedule_alert_query( # pylint: disable=unused-argument return if report_type == ScheduleType.alert: - evaluate_alert(schedule.id, schedule.label, recipients, slack_channel) + evaluate_alert( + schedule.id, schedule.label, session, recipients, slack_channel + ) else: raise RuntimeError("Unknown report type") - except NoSuchColumnError as column_error: - stats_logger.incr("run_alert_task.error.nosuchcolumnerror") - raise column_error - except ResourceClosedError as resource_error: - stats_logger.incr("run_alert_task.error.resourceclosederror") - raise resource_error class AlertState: @@ -558,6 +561,7 @@ class AlertState: def deliver_alert( alert_id: int, + session: Session, recipients: Optional[str] = None, slack_channel: Optional[str] = None, ) -> None: @@ -566,7 +570,7 @@ def deliver_alert( to its respective email and slack recipients """ - alert = db.session.query(Alert).get(alert_id) + alert = session.query(Alert).get(alert_id) logging.info("Triggering alert: %s", alert) @@ -588,7 +592,7 @@ def deliver_alert( str(alert.observations[-1].value), validation_error_message, _get_url_path("AlertModelView.show", user_friendly=True, pk=alert_id), - _get_slice_screenshot(alert.slice.id), + _get_slice_screenshot(alert.slice.id, session), ) else: # TODO: dashboard delivery! 
@@ -668,6 +672,7 @@ def deliver_slack_alert(alert_content: AlertContent, slack_channel: str) -> None def evaluate_alert( alert_id: int, label: str, + session: Session, recipients: Optional[str] = None, slack_channel: Optional[str] = None, ) -> None: @@ -680,7 +685,7 @@ def evaluate_alert( try: logger.info("Querying observers for alert <%s:%s>", alert_id, label) - error_msg = observe(alert_id) + error_msg = observe(alert_id, session) if error_msg: state = AlertState.ERROR logging.error(error_msg) @@ -694,17 +699,17 @@ def evaluate_alert( if state != AlertState.ERROR: # Don't validate alert on test runs since it may not be triggered if recipients or slack_channel: - deliver_alert(alert_id, recipients, slack_channel) + deliver_alert(alert_id, session, recipients, slack_channel) state = AlertState.TRIGGER # Validate during regular workflow and deliver only if triggered - elif validate_observations(alert_id, label): - deliver_alert(alert_id, recipients, slack_channel) + elif validate_observations(alert_id, label, session): + deliver_alert(alert_id, session, recipients, slack_channel) state = AlertState.TRIGGER else: state = AlertState.PASS - db.session.commit() - alert = db.session.query(Alert).get(alert_id) + session.commit() + alert = session.query(Alert).get(alert_id) if state != AlertState.ERROR: alert.last_eval_dttm = dttm_end alert.last_state = state @@ -716,10 +721,10 @@ def evaluate_alert( state=state, ) ) - db.session.commit() + session.commit() -def validate_observations(alert_id: int, label: str) -> bool: +def validate_observations(alert_id: int, label: str, session: Session) -> bool: """ Runs an alert's validators to check if it should be triggered or not If so, return the name of the validator that returned true @@ -727,7 +732,7 @@ def validate_observations(alert_id: int, label: str) -> bool: logger.info("Validating observations for alert <%s:%s>", alert_id, label) - alert = db.session.query(Alert).get(alert_id) + alert = session.query(Alert).get(alert_id) 
if alert.validators: validator = alert.validators[0] validate = get_validator_function(validator.validator_type) @@ -760,7 +765,11 @@ def next_schedules( def schedule_window( - report_type: str, start_at: datetime, stop_at: datetime, resolution: int + report_type: str, + start_at: datetime, + stop_at: datetime, + resolution: int, + session: Session, ) -> None: """ Find all active schedules and schedule celery tasks for @@ -772,8 +781,7 @@ def schedule_window( if not model_cls: return None - dbsession = db.create_scoped_session() - schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True)) + schedules = session.query(model_cls).filter(model_cls.active.is_(True)) for schedule in schedules: logging.info("Processing schedule %s", schedule) @@ -810,7 +818,6 @@ def get_scheduler_action(report_type: str) -> Optional[Callable[..., Any]]: @celery_app.task(name="email_reports.schedule_hourly") def schedule_hourly() -> None: """ Celery beat job meant to be invoked hourly """ - if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]: logger.info("Scheduled email reports not enabled in config") return @@ -820,8 +827,10 @@ def schedule_hourly() -> None: # Get the top of the hour start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0) stop_at = start_at + timedelta(seconds=3600) - schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution) - schedule_window(ScheduleType.slice, start_at, stop_at, resolution) + + with session_scope(nullpool=True) as session: + schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution, session) + schedule_window(ScheduleType.slice, start_at, stop_at, resolution, session) @celery_app.task(name="alerts.schedule_check") @@ -833,5 +842,5 @@ def schedule_alerts() -> None: seconds=3600 ) # process any missed tasks in the past hour stop_at = now + timedelta(seconds=1) - - schedule_window(ScheduleType.alert, start_at, stop_at, resolution) + with session_scope(nullpool=True) as session: + 
schedule_window(ScheduleType.alert, start_at, stop_at, resolution, session) diff --git a/superset/utils/celery.py b/superset/utils/celery.py new file mode 100644 index 0000000000000..1692e5574c28d --- /dev/null +++ b/superset/utils/celery.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Iterator + +import sqlalchemy as sa +from contextlib2 import contextmanager +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import NullPool + +from superset import app, db + +logger = logging.getLogger(__name__) + +# Null pool is used for the celery workers due to process forking side effects. +# For more info see: https://github.com/apache/incubator-superset/issues/10530 +@contextmanager +def session_scope(nullpool: bool) -> Iterator[Session]: + """Provide a transactional scope around a series of operations.""" + database_uri = app.config["SQLALCHEMY_DATABASE_URI"] + if "sqlite" in database_uri: + logger.warning( + "SQLite Database support for metadata databases will be removed \ in a future version of Superset."
+ ) + if nullpool: + engine = sa.create_engine(database_uri, poolclass=NullPool) + session_class = sessionmaker() + session_class.configure(bind=engine) + session = session_class() + else: + session = db.session() + session.commit() # HACK + + try: + yield session + session.commit() + except Exception as ex: + session.rollback() + logger.exception(ex) + raise + finally: + session.close() diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py index a2226107a01d8..5f6464cc1a36b 100644 --- a/tests/alerts_tests.py +++ b/tests/alerts_tests.py @@ -112,37 +112,37 @@ def test_alert_observer(setup_database): # Test SQLObserver with int SQL return alert1 = create_alert(dbsession, "SELECT 55") - observe(alert1.id) + observe(alert1.id, dbsession) assert alert1.sql_observer[0].observations[-1].value == 55.0 assert alert1.sql_observer[0].observations[-1].error_msg is None # Test SQLObserver with double SQL return alert2 = create_alert(dbsession, "SELECT 30.0 as wage") - observe(alert2.id) + observe(alert2.id, dbsession) assert alert2.sql_observer[0].observations[-1].value == 30.0 assert alert2.sql_observer[0].observations[-1].error_msg is None # Test SQLObserver with NULL result alert3 = create_alert(dbsession, "SELECT null as null_result") - observe(alert3.id) + observe(alert3.id, dbsession) assert alert3.sql_observer[0].observations[-1].value is None assert alert3.sql_observer[0].observations[-1].error_msg is None # Test SQLObserver with empty SQL return alert4 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1") - observe(alert4.id) + observe(alert4.id, dbsession) assert alert4.sql_observer[0].observations[-1].value is None assert alert4.sql_observer[0].observations[-1].error_msg is not None # Test SQLObserver with str result alert5 = create_alert(dbsession, "SELECT 'test_string' as string_value") - observe(alert5.id) + observe(alert5.id, dbsession) assert alert5.sql_observer[0].observations[-1].value is None assert 
alert5.sql_observer[0].observations[-1].error_msg is not None # Test SQLObserver with two row result alert6 = create_alert(dbsession, "SELECT first FROM test_table") - observe(alert6.id) + observe(alert6.id, dbsession) assert alert6.sql_observer[0].observations[-1].value is None assert alert6.sql_observer[0].observations[-1].error_msg is not None @@ -150,7 +150,7 @@ def test_alert_observer(setup_database): alert7 = create_alert( dbsession, "SELECT first, second FROM test_table WHERE first = 1" ) - observe(alert7.id) + observe(alert7.id, dbsession) assert alert7.sql_observer[0].observations[-1].value is None assert alert7.sql_observer[0].observations[-1].error_msg is not None @@ -161,22 +161,22 @@ def test_evaluate_alert(mock_deliver_alert, setup_database): # Test error with Observer SQL statement alert1 = create_alert(dbsession, "$%^&") - evaluate_alert(alert1.id, alert1.label) + evaluate_alert(alert1.id, alert1.label, dbsession) assert alert1.logs[-1].state == AlertState.ERROR # Test error with alert lacking observer alert2 = dbsession.query(Alert).filter_by(label="No Observer").one() - evaluate_alert(alert2.id, alert2.label) + evaluate_alert(alert2.id, alert2.label, dbsession) assert alert2.logs[-1].state == AlertState.ERROR # Test pass on alert lacking validator alert3 = create_alert(dbsession, "SELECT 55") - evaluate_alert(alert3.id, alert3.label) + evaluate_alert(alert3.id, alert3.label, dbsession) assert alert3.logs[-1].state == AlertState.PASS # Test triggering successful alert alert4 = create_alert(dbsession, "SELECT 55", "not null", "{}") - evaluate_alert(alert4.id, alert4.label) + evaluate_alert(alert4.id, alert4.label, dbsession) assert mock_deliver_alert.call_count == 1 assert alert4.logs[-1].state == AlertState.TRIGGER @@ -214,17 +214,17 @@ def test_not_null_validator(setup_database): # Test passing SQLObserver with 'null' SQL result alert1 = create_alert(dbsession, "SELECT 0") - observe(alert1.id) + observe(alert1.id, dbsession) assert 
not_null_validator(alert1.sql_observer[0], "{}") is False # Test passing SQLObserver with empty SQL result alert2 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1") - observe(alert2.id) + observe(alert2.id, dbsession) assert not_null_validator(alert2.sql_observer[0], "{}") is False # Test triggering alert with non-null SQL result alert3 = create_alert(dbsession, "SELECT 55") - observe(alert3.id) + observe(alert3.id, dbsession) assert not_null_validator(alert3.sql_observer[0], "{}") is True @@ -233,7 +233,7 @@ def test_operator_validator(setup_database): # Test passing SQLObserver with empty SQL result alert1 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1") - observe(alert1.id) + observe(alert1.id, dbsession) assert ( operator_validator(alert1.sql_observer[0], '{"op": ">=", "threshold": 60}') is False @@ -241,7 +241,7 @@ def test_operator_validator(setup_database): # Test passing SQLObserver with result that doesn't pass a greater than threshold alert2 = create_alert(dbsession, "SELECT 55") - observe(alert2.id) + observe(alert2.id, dbsession) assert ( operator_validator(alert2.sql_observer[0], '{"op": ">=", "threshold": 60}') is False @@ -283,23 +283,23 @@ def test_validate_observations(setup_database): # Test False on alert with no validator alert1 = create_alert(dbsession, "SELECT 55") - assert validate_observations(alert1.id, alert1.label) is False + assert validate_observations(alert1.id, alert1.label, dbsession) is False # Test False on alert with no observations alert2 = create_alert(dbsession, "SELECT 55", "not null", "{}") - assert validate_observations(alert2.id, alert2.label) is False + assert validate_observations(alert2.id, alert2.label, dbsession) is False # Test False on alert that shouldnt be triggered alert3 = create_alert(dbsession, "SELECT 0", "not null", "{}") - observe(alert3.id) - assert validate_observations(alert3.id, alert3.label) is False + observe(alert3.id, dbsession) + assert 
validate_observations(alert3.id, alert3.label, dbsession) is False # Test True on alert that should be triggered alert4 = create_alert( dbsession, "SELECT 55", "operator", '{"op": "<=", "threshold": 60}' ) - observe(alert4.id) - assert validate_observations(alert4.id, alert4.label) is True + observe(alert4.id, dbsession) + assert validate_observations(alert4.id, alert4.label, dbsession) is True @patch("superset.tasks.slack_util.WebClient.files_upload") @@ -311,7 +311,7 @@ def test_deliver_alert_screenshot( ): dbsession = setup_database alert = create_alert(dbsession, "SELECT 55", "not null", "{}") - observe(alert.id) + observe(alert.id, dbsession) screenshot = read_fixture("sample.png") screenshot_mock.return_value = screenshot @@ -322,7 +322,7 @@ def test_deliver_alert_screenshot( f"http://0.0.0.0:8080/superset/slice/{alert.slice_id}/", ] - deliver_alert(alert_id=alert.id) + deliver_alert(alert.id, dbsession) assert email_mock.call_args[1]["images"]["screenshot"] == screenshot assert file_upload_mock.call_args[1] == { "channels": alert.slack_channel, diff --git a/tests/schedules_test.py b/tests/schedules_test.py index 77f70703c3a01..88b6d1f924d1f 100644 --- a/tests/schedules_test.py +++ b/tests/schedules_test.py @@ -366,6 +366,7 @@ def test_deliver_slice_inline_image( schedule.delivery_type, schedule.email_format, schedule.deliver_as_group, + db.session, ) mtime.sleep.assert_called_once() driver.screenshot.assert_not_called() @@ -418,6 +419,7 @@ def test_deliver_slice_attachment( schedule.delivery_type, schedule.email_format, schedule.deliver_as_group, + db.session, ) mtime.sleep.assert_called_once() @@ -466,6 +468,7 @@ def test_deliver_slice_csv_attachment( schedule.delivery_type, schedule.email_format, schedule.deliver_as_group, + db.session, ) send_email_smtp.assert_called_once() @@ -510,6 +513,7 @@ def test_deliver_slice_csv_inline( schedule.delivery_type, schedule.email_format, schedule.deliver_as_group, + db.session, ) send_email_smtp.assert_called_once()