diff --git a/superset/dao/base.py b/superset/dao/base.py index 6b33c4e638d4f..abfa4ac61ccac 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -20,6 +20,7 @@ from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla.interface import SQLAInterface from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from superset.dao.exceptions import ( DAOConfigError, @@ -46,13 +47,14 @@ class BaseDAO: """ @classmethod - def find_by_id(cls, model_id: int) -> Model: + def find_by_id(cls, model_id: int, session: Session = None) -> Model: """ Find a model by id, if defined applies `base_filter` """ - query = db.session.query(cls.model_cls) + session = session or db.session + query = session.query(cls.model_cls) if cls.base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) + data_model = SQLAInterface(cls.model_cls, session) query = cls.base_filter( # pylint: disable=not-callable "id", data_model ).apply(query, None) diff --git a/superset/exceptions.py b/superset/exceptions.py index c0d55f8924426..fd95a59e2cfa7 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -25,7 +25,9 @@ class SupersetException(Exception): status = 500 message = "" - def __init__(self, message: str = "", exception: Optional[Exception] = None): + def __init__( + self, message: str = "", exception: Optional[Exception] = None, + ) -> None: if message: self.message = message self._exception = exception diff --git a/superset/models/reports.py b/superset/models/reports.py index 731d1f9629166..7b1f1831399d9 100644 --- a/superset/models/reports.py +++ b/superset/models/reports.py @@ -60,7 +60,9 @@ class ReportRecipientType(str, enum.Enum): class ReportLogState(str, enum.Enum): SUCCESS = "Success" + WORKING = "Working" ERROR = "Error" + NOOP = "Not triggered" class ReportEmailFormat(str, enum.Enum): @@ -175,6 +177,6 @@ class ReportExecutionLog(Model): # pylint: disable=too-few-public-methods ) report_schedule = relationship( ReportSchedule, - backref=backref("logs", cascade="all,delete"), + backref=backref("logs", cascade="all,delete,delete-orphan"), foreign_keys=[report_schedule_id], ) diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py new file mode 100644 index 0000000000000..cab294ce5c29b --- /dev/null +++ b/superset/reports/commands/alert.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import json
+import logging
+from operator import eq, ge, gt, le, lt, ne
+from typing import Optional
+
+import numpy as np
+from flask_babel import lazy_gettext as _
+
+from superset import jinja_context
+from superset.commands.base import BaseCommand
+from superset.models.reports import ReportSchedule, ReportScheduleValidatorType
+from superset.reports.commands.exceptions import (
+    AlertQueryInvalidTypeError,
+    AlertQueryMultipleColumnsError,
+    AlertQueryMultipleRowsError,
+)
+
+logger = logging.getLogger(__name__)
+
+
+OPERATOR_FUNCTIONS = {">=": ge, ">": gt, "<=": le, "<": lt, "==": eq, "!=": ne}
+
+
+class AlertCommand(BaseCommand):
+    def __init__(self, report_schedule: ReportSchedule):
+        self._report_schedule = report_schedule
+        self._result: Optional[float] = None
+
+    def run(self) -> bool:
+        self.validate()
+
+        if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL:
+            self._report_schedule.last_value_row_json = self._result
+            return self._result not in (0, None, np.nan)
+        self._report_schedule.last_value = self._result
+        operator = json.loads(self._report_schedule.validator_config_json)["op"]
+        threshold = json.loads(self._report_schedule.validator_config_json)["threshold"]
+        return OPERATOR_FUNCTIONS[operator](self._result, threshold)
+
+    def _validate_not_null(self, rows: np.recarray) -> None:
+        self._result = rows[0][1]
+
+    def _validate_operator(self, rows: np.recarray) -> None:
+        # check if the query returned more than one row
+        if len(rows) > 1:
+            raise AlertQueryMultipleRowsError(
+                message=_(
+                    "Alert query returned more than one row. %s rows returned"
+                    % len(rows),
+                )
+            )
+        # check if the query returned more than one column
+        if len(rows[0]) > 2:
+            raise AlertQueryMultipleColumnsError(
+                _(
+                    "Alert query returned more than one column. %s columns returned"
+                    % len(rows[0])
+                )
+            )
+        if rows[0][1] is None:
+            return
+        try:
+            # Check if it's a float or if we can convert it
+            self._result = float(rows[0][1])
+            return
+        except (AssertionError, TypeError, ValueError) as ex:
+            raise AlertQueryInvalidTypeError() from ex
+
+    def validate(self) -> None:
+        """
+        Validate the query result as a Pandas DataFrame
+        """
+        sql_template = jinja_context.get_template_processor(
+            database=self._report_schedule.database
+        )
+        rendered_sql = sql_template.process_template(self._report_schedule.sql)
+        df = self._report_schedule.database.get_df(rendered_sql)
+
+        if df.empty:
+            return
+        rows = df.to_records()
+        if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL:
+            self._validate_not_null(rows)
+            return
+        self._validate_operator(rows)
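Note on the validator dispatch above: OPERATOR_FUNCTIONS maps the `op` string from
`validator_config_json` directly onto the `operator` builtins, so the threshold check is a
single lookup-and-call. A minimal standalone sketch of the same pattern (illustrative only,
not part of this diff):

    import json
    from operator import eq, ge, gt, le, lt, ne

    OPERATOR_FUNCTIONS = {">=": ge, ">": gt, "<=": le, "<": lt, "==": eq, "!=": ne}

    def threshold_triggered(validator_config_json: str, value: float) -> bool:
        # validator_config_json mirrors ReportSchedule.validator_config_json,
        # e.g. '{"op": ">", "threshold": 9}'
        config = json.loads(validator_config_json)
        return OPERATOR_FUNCTIONS[config["op"]](value, config["threshold"])

    assert threshold_triggered('{"op": ">", "threshold": 9}', 10.0)
    assert not threshold_triggered('{"op": "!=", "threshold": 10}', 10.0)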
%s columns returned" + % len(rows[0]) + ) + ) + if rows[0][1] is None: + return + try: + # Check if it's float or if we can convert it + self._result = float(rows[0][1]) + return + except (AssertionError, TypeError, ValueError): + raise AlertQueryInvalidTypeError() + + def validate(self) -> None: + """ + Validate the query result as a Pandas DataFrame + """ + sql_template = jinja_context.get_template_processor( + database=self._report_schedule.database + ) + rendered_sql = sql_template.process_template(self._report_schedule.sql) + df = self._report_schedule.database.get_df(rendered_sql) + + if df.empty: + return + rows = df.to_records() + if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL: + self._validate_not_null(rows) + return + self._validate_operator(rows) diff --git a/superset/reports/commands/exceptions.py b/superset/reports/commands/exceptions.py index 23a21425bd92b..3a56a49b9d7f2 100644 --- a/superset/reports/commands/exceptions.py +++ b/superset/reports/commands/exceptions.py @@ -103,6 +103,22 @@ class ReportScheduleDeleteFailedError(CommandException): message = _("Report Schedule delete failed.") +class PruneReportScheduleLogFailedError(CommandException): + message = _("Report Schedule log prune failed.") + + +class ReportScheduleScreenshotFailedError(CommandException): + message = _("Report Schedule execution failed when generating a screenshot.") + + +class ReportScheduleExecuteUnexpectedError(CommandException): + message = _("Report Schedule execution got an unexpected error.") + + +class ReportSchedulePreviousWorkingError(CommandException): + message = _("Report Schedule is still working, refusing to re-compute.") + + class ReportScheduleNameUniquenessValidationError(ValidationError): """ Marshmallow validation error for Report Schedule name already exists @@ -110,3 +126,24 @@ class ReportScheduleNameUniquenessValidationError(ValidationError): def __init__(self) -> None: super().__init__([_("Name must be unique")], field_name="name") + + +class AlertQueryMultipleRowsError(CommandException): + + message = _("Alert query returned more then one row.") + + +class AlertQueryMultipleColumnsError(CommandException): + message = _("Alert query returned more then one column.") + + +class AlertQueryInvalidTypeError(CommandException): + message = _("Alert query returned a non-number value.") + + +class ReportScheduleAlertGracePeriodError(CommandException): + message = _("Alert fired during grace period.") + + +class ReportScheduleNotificationError(CommandException): + message = _("Alert on grace period") diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py new file mode 100644 index 0000000000000..bb3384702d1f3 --- /dev/null +++ b/superset/reports/commands/execute.py @@ -0,0 +1,256 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py
new file mode 100644
index 0000000000000..bb3384702d1f3
--- /dev/null
+++ b/superset/reports/commands/execute.py
@@ -0,0 +1,256 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from datetime import datetime, timedelta
+from typing import Optional
+
+from sqlalchemy.orm import Session
+
+from superset import app, thumbnail_cache
+from superset.commands.base import BaseCommand
+from superset.commands.exceptions import CommandException
+from superset.extensions import security_manager
+from superset.models.reports import (
+    ReportExecutionLog,
+    ReportLogState,
+    ReportSchedule,
+    ReportScheduleType,
+)
+from superset.reports.commands.alert import AlertCommand
+from superset.reports.commands.exceptions import (
+    ReportScheduleAlertGracePeriodError,
+    ReportScheduleExecuteUnexpectedError,
+    ReportScheduleNotFoundError,
+    ReportScheduleNotificationError,
+    ReportSchedulePreviousWorkingError,
+    ReportScheduleScreenshotFailedError,
+)
+from superset.reports.dao import ReportScheduleDAO
+from superset.reports.notifications import create_notification
+from superset.reports.notifications.base import NotificationContent, ScreenshotData
+from superset.reports.notifications.exceptions import NotificationError
+from superset.utils.celery import session_scope
+from superset.utils.screenshots import (
+    BaseScreenshot,
+    ChartScreenshot,
+    DashboardScreenshot,
+)
+from superset.utils.urls import get_url_path
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncExecuteReportScheduleCommand(BaseCommand):
+    """
+    Execute all types of report schedules.
+    - For reports: takes a chart or dashboard screenshot and sends the configured notifications
+    - For alerts: runs the related AlertCommand and sends the configured notifications if triggered
+    """
+
+    def __init__(self, model_id: int, scheduled_dttm: datetime):
+        self._model_id = model_id
+        self._model: Optional[ReportSchedule] = None
+        self._scheduled_dttm = scheduled_dttm
+
+    def set_state_and_log(
+        self,
+        session: Session,
+        start_dttm: datetime,
+        state: ReportLogState,
+        error_message: Optional[str] = None,
+    ) -> None:
+        """
+        Updates the current ReportSchedule state and timestamp; if the state is
+        final, also writes the execution log for this run
+        """
+        now_dttm = datetime.utcnow()
+        if state == ReportLogState.WORKING:
+            self.set_state(session, state, now_dttm)
+            return
+        self.set_state(session, state, now_dttm)
+        self.create_log(
+            session, start_dttm, now_dttm, state, error_message=error_message,
+        )
+    def set_state(
+        self, session: Session, state: ReportLogState, dttm: datetime
+    ) -> None:
+        """
+        Sets the current report schedule state; in this case we want to
+        commit immediately
+        """
+        if self._model:
+            self._model.last_state = state
+            self._model.last_eval_dttm = dttm
+            session.commit()
+
+    def create_log(  # pylint: disable=too-many-arguments
+        self,
+        session: Session,
+        start_dttm: datetime,
+        end_dttm: datetime,
+        state: ReportLogState,
+        error_message: Optional[str] = None,
+    ) -> None:
+        """
+        Creates a report execution log, using the currently computed last_value for alerts
+        """
+        if self._model:
+            log = ReportExecutionLog(
+                scheduled_dttm=self._scheduled_dttm,
+                start_dttm=start_dttm,
+                end_dttm=end_dttm,
+                value=self._model.last_value,
+                value_row_json=self._model.last_value_row_json,
+                state=state,
+                error_message=error_message,
+                report_schedule=self._model,
+            )
+            session.add(log)
+
+    @staticmethod
+    def _get_url(report_schedule: ReportSchedule, user_friendly: bool = False) -> str:
+        """
+        Get the URL for this report schedule's chart or dashboard
+        """
+        if report_schedule.chart:
+            return get_url_path(
+                "Superset.slice",
+                user_friendly=user_friendly,
+                slice_id=report_schedule.chart_id,
+                standalone="true",
+            )
+        return get_url_path(
+            "Superset.dashboard",
+            user_friendly=user_friendly,
+            dashboard_id_or_slug=report_schedule.dashboard_id,
+        )
+
+    def _get_screenshot(self, report_schedule: ReportSchedule) -> ScreenshotData:
+        """
+        Get a chart or dashboard screenshot
+        :raises: ReportScheduleScreenshotFailedError
+        """
+        url = self._get_url(report_schedule)
+        screenshot: Optional[BaseScreenshot] = None
+        if report_schedule.chart:
+            screenshot = ChartScreenshot(url, report_schedule.chart.digest)
+        else:
+            screenshot = DashboardScreenshot(url, report_schedule.dashboard.digest)
+        image_url = self._get_url(report_schedule, user_friendly=True)
+        user = security_manager.find_user(app.config["THUMBNAIL_SELENIUM_USER"])
+        image_data = screenshot.compute_and_cache(
+            user=user, cache=thumbnail_cache, force=True,
+        )
+        if not image_data:
+            raise ReportScheduleScreenshotFailedError()
+        return ScreenshotData(url=image_url, image=image_data)
+
+    def _get_notification_content(
+        self, report_schedule: ReportSchedule
+    ) -> NotificationContent:
+        """
+        Builds the notification content, composed of a name and a screenshot
+        :raises: ReportScheduleScreenshotFailedError
+        """
+        screenshot_data = self._get_screenshot(report_schedule)
+        if report_schedule.chart:
+            name = report_schedule.chart.slice_name
+        else:
+            name = report_schedule.dashboard.dashboard_title
+        return NotificationContent(name=name, screenshot=screenshot_data)
+
+    def _send(self, report_schedule: ReportSchedule) -> None:
+        """
+        Creates the notification content and sends it to all recipients
+
+        :raises: ReportScheduleNotificationError
+        """
+        notification_errors = []
+        notification_content = self._get_notification_content(report_schedule)
+        for recipient in report_schedule.recipients:
+            notification = create_notification(recipient, notification_content)
+            try:
+                notification.send()
+            except NotificationError as ex:
+                # collect notification errors, but keep processing the remaining recipients
+                notification_errors.append(str(ex))
+        if notification_errors:
+            raise ReportScheduleNotificationError(";".join(notification_errors))
+
+    def run(self) -> None:
+        with session_scope(nullpool=True) as session:
+            try:
+                start_dttm = datetime.utcnow()
+                self.validate(session=session)
+                if not self._model:
+                    raise ReportScheduleExecuteUnexpectedError()
+                self.set_state_and_log(session, start_dttm, ReportLogState.WORKING)
+                # If it's an alert, check whether the alert is triggered
+                if self._model.type == ReportScheduleType.ALERT:
+                    if not AlertCommand(self._model).run():
+                        self.set_state_and_log(session, start_dttm, ReportLogState.NOOP)
+                        return
+
+                self._send(self._model)
+
+                # Log, state and timestamp
+                self.set_state_and_log(session, start_dttm, ReportLogState.SUCCESS)
+            except ReportScheduleAlertGracePeriodError as ex:
+                self.set_state_and_log(
+                    session, start_dttm, ReportLogState.NOOP, error_message=str(ex)
+                )
+            except ReportSchedulePreviousWorkingError as ex:
+                self.create_log(
+                    session,
+                    start_dttm,
+                    datetime.utcnow(),
+                    state=ReportLogState.ERROR,
+                    error_message=str(ex),
+                )
+                session.commit()
+                raise
+            except CommandException as ex:
+                self.set_state_and_log(
+                    session, start_dttm, ReportLogState.ERROR, error_message=str(ex)
+                )
+                # We want to actually commit the state and log inside the scope
+                session.commit()
+                raise
+
+    def validate(  # pylint: disable=arguments-differ
+        self, session: Optional[Session] = None
+    ) -> None:
+        # Validate/populate the model exists
+        self._model = ReportScheduleDAO.find_by_id(self._model_id, session=session)
+        if not self._model:
+            raise ReportScheduleNotFoundError()
+        # Avoid overlapping processing
+        if self._model.last_state == ReportLogState.WORKING:
+            raise ReportSchedulePreviousWorkingError()
+        # Check grace period
+        if self._model.type == ReportScheduleType.ALERT:
+            last_success = ReportScheduleDAO.find_last_success_log(session)
+            if (
+                last_success
+                and self._model.last_state
+                in (ReportLogState.SUCCESS, ReportLogState.NOOP)
+                and self._model.grace_period
+                and datetime.utcnow() - timedelta(seconds=self._model.grace_period)
+                < last_success.end_dttm
+            ):
+                raise ReportScheduleAlertGracePeriodError()
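In normal operation this command is driven by the `reports.execute` Celery task added in
superset/tasks/scheduler.py further down; a direct invocation for local testing would look
like the sketch below (the id is a placeholder for an existing ReportSchedule row):

    from datetime import datetime
    from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand

    report_schedule_id = 1  # hypothetical primary key of a ReportSchedule
    AsyncExecuteReportScheduleCommand(report_schedule_id, datetime.utcnow()).run()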
diff --git a/superset/reports/commands/log_prune.py b/superset/reports/commands/log_prune.py
new file mode 100644
index 0000000000000..9825a35eef2a8
--- /dev/null
+++ b/superset/reports/commands/log_prune.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from datetime import datetime, timedelta
+
+from superset.commands.base import BaseCommand
+from superset.models.reports import ReportSchedule
+from superset.reports.dao import ReportScheduleDAO
+from superset.utils.celery import session_scope
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncPruneReportScheduleLogCommand(BaseCommand):
+    """
+    Prunes logs from all report schedules
+    """
+
+    def __init__(self, worker_context: bool = True):
+        self._worker_context = worker_context
+
+    def run(self) -> None:
+        with session_scope(nullpool=True) as session:
+            self.validate()
+            for report_schedule in session.query(ReportSchedule).all():
+                from_date = datetime.utcnow() - timedelta(
+                    days=report_schedule.log_retention
+                )
+                ReportScheduleDAO.bulk_delete_logs(
+                    report_schedule, from_date, session=session, commit=False
+                )
+
+    def validate(self) -> None:
+        pass
diff --git a/superset/reports/dao.py b/superset/reports/dao.py
index e02770af90044..6081fc8efa67e 100644
--- a/superset/reports/dao.py
+++ b/superset/reports/dao.py
@@ -15,15 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from datetime import datetime
 from typing import Any, Dict, List, Optional
 
 from flask_appbuilder import Model
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
 
 from superset.dao.base import BaseDAO
 from superset.dao.exceptions import DAOCreateFailedError, DAODeleteFailedError
 from superset.extensions import db
-from superset.models.reports import ReportRecipients, ReportSchedule
+from superset.models.reports import (
+    ReportExecutionLog,
+    ReportLogState,
+    ReportRecipients,
+    ReportSchedule,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -135,3 +142,49 @@ def update(
         except SQLAlchemyError:
             db.session.rollback()
             raise DAOCreateFailedError
+
+    @staticmethod
+    def find_active(session: Optional[Session] = None) -> List[ReportSchedule]:
+        """
+        Find all active reports. If a session is passed it is used instead of the
+        default `db.session`; this is useful in a Celery worker session context
+        """
+        session = session or db.session
+        return (
+            session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
+        )
+
+    @staticmethod
+    def find_last_success_log(
+        session: Optional[Session] = None,
+    ) -> Optional[ReportExecutionLog]:
+        """
+        Finds the last successful execution log
+        """
+        session = session or db.session
+        return (
+            session.query(ReportExecutionLog)
+            .filter(ReportExecutionLog.state == ReportLogState.SUCCESS)
+            .order_by(ReportExecutionLog.end_dttm.desc())
+            .first()
+        )
+
+    @staticmethod
+    def bulk_delete_logs(
+        model: ReportSchedule,
+        from_date: datetime,
+        session: Optional[Session] = None,
+        commit: bool = True,
+    ) -> None:
+        session = session or db.session
+        try:
+            session.query(ReportExecutionLog).filter(
+                ReportExecutionLog.report_schedule == model,
+                ReportExecutionLog.end_dttm < from_date,
+            ).delete(synchronize_session="fetch")
+            if commit:
+                session.commit()
+        except SQLAlchemyError as ex:
+            if commit:
+                session.rollback()
+            raise ex
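A note on the `session` keyword threaded through the DAO methods above: inside a Celery
worker the commands run under `session_scope(nullpool=True)` rather than the request-scoped
`db.session`, so every DAO call accepts an explicit session. Illustrative worker-side usage,
built only from APIs introduced in this diff:

    import logging

    from superset.reports.dao import ReportScheduleDAO
    from superset.utils.celery import session_scope

    logger = logging.getLogger(__name__)

    with session_scope(nullpool=True) as session:
        # enumerate active schedules with the worker session, not db.session
        for schedule in ReportScheduleDAO.find_active(session):
            logger.info("active schedule %s (%s)", schedule.id, schedule.crontab)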
diff --git a/superset/reports/notifications/__init__.py b/superset/reports/notifications/__init__.py
new file mode 100644
index 0000000000000..2553053131690
--- /dev/null
+++ b/superset/reports/notifications/__init__.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.models.reports import ReportRecipients
+from superset.reports.notifications.base import BaseNotification, NotificationContent
+from superset.reports.notifications.email import EmailNotification
+from superset.reports.notifications.slack import SlackNotification
+
+
+def create_notification(
+    recipient: ReportRecipients, screenshot_data: NotificationContent
+) -> BaseNotification:
+    """
+    Polymorphic notification factory
+    Returns the notification implementation registered for the recipient type
+    """
+    for plugin in BaseNotification.plugins:
+        if plugin.type == recipient.type:
+            return plugin(recipient, screenshot_data)
+    raise Exception("Recipient type not supported")
diff --git a/superset/reports/notifications/base.py b/superset/reports/notifications/base.py
new file mode 100644
index 0000000000000..f55154c1e7430
--- /dev/null
+++ b/superset/reports/notifications/base.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from dataclasses import dataclass
+from typing import Any, List, Optional, Type
+
+from superset.models.reports import ReportRecipients, ReportRecipientType
+
+
+@dataclass
+class ScreenshotData:
+    url: str  # url to chart/dashboard for this screenshot
+    image: bytes  # bytes for the screenshot
+
+
+@dataclass
+class NotificationContent:
+    name: str
+    screenshot: ScreenshotData
+
+
+class BaseNotification:  # pylint: disable=too-few-public-methods
+    """
+    Serves as the base for all notifications and provides a simple plugin system
+    for extending future implementations.
+
+    Child implementations are registered automatically and should declare their
+    notification type
+    """
+
+    plugins: List[Type["BaseNotification"]] = []
+    type: Optional[ReportRecipientType] = None
+    """
+    Child classes set their notification type, e.g. `type = "email"`; this string is
+    matched against ReportRecipients.type to pick the correct implementation
+    """
+
+    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
+        super().__init_subclass__(*args, **kwargs)  # type: ignore
+        cls.plugins.append(cls)
+
+    def __init__(
+        self, recipient: ReportRecipients, content: NotificationContent
+    ) -> None:
+        self._recipient = recipient
+        self._content = content
+
+    def send(self) -> None:
+        raise NotImplementedError()
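The `__init_subclass__` hook above is what lets `create_notification` find implementations
without any explicit registration step: defining a subclass is enough. A self-contained
sketch of the mechanism with toy class names (not part of this diff):

    from typing import List, Optional, Type

    class Base:
        plugins: List[Type["Base"]] = []
        type: Optional[str] = None

        def __init_subclass__(cls, **kwargs) -> None:
            super().__init_subclass__(**kwargs)
            cls.plugins.append(cls)  # every subclass self-registers on definition

    class EmailLike(Base):
        type = "Email"

    class SlackLike(Base):
        type = "Slack"

    # factory-style dispatch, mirroring create_notification
    assert [plugin.type for plugin in Base.plugins] == ["Email", "Slack"]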
diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py
new file mode 100644
index 0000000000000..e99a7f43e37da
--- /dev/null
+++ b/superset/reports/notifications/email.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import logging
+from dataclasses import dataclass
+from email.utils import make_msgid, parseaddr
+from typing import Dict
+
+from flask_babel import gettext as __
+
+from superset import app
+from superset.models.reports import ReportRecipientType
+from superset.reports.notifications.base import BaseNotification
+from superset.reports.notifications.exceptions import NotificationError
+from superset.utils.core import send_email_smtp
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EmailContent:
+    body: str
+    images: Dict[str, bytes]
+
+
+class EmailNotification(BaseNotification):  # pylint: disable=too-few-public-methods
+    """
+    Sends an email notification for a report recipient
+    """
+
+    type = ReportRecipientType.EMAIL
+
+    @staticmethod
+    def _get_smtp_domain() -> str:
+        return parseaddr(app.config["SMTP_MAIL_FROM"])[1].split("@")[1]
+
+    def _get_content(self) -> EmailContent:
+        # Get the domain from the 'From' address and make a message id,
+        # stripping the < > from the ends
+        domain = self._get_smtp_domain()
+        msgid = make_msgid(domain)[1:-1]
+
+        image = {msgid: self._content.screenshot.image}
+        body = __(
+            """
+            <b><a href="%(url)s">Explore in Superset</a></b><p></p>
+            <img src="cid:%(msgid)s">
+            """,
+            url=self._content.screenshot.url,
+            msgid=msgid,
+        )
+        return EmailContent(body=body, images=image)
+
+    def _get_subject(self) -> str:
+        return __(
+            "%(prefix)s %(title)s",
+            prefix=app.config["EMAIL_REPORTS_SUBJECT_PREFIX"],
+            title=self._content.name,
+        )
+
+    def _get_to(self) -> str:
+        return json.loads(self._recipient.recipient_config_json)["target"]
+
+    def send(self) -> None:
+        subject = self._get_subject()
+        content = self._get_content()
+        to = self._get_to()
+        try:
+            send_email_smtp(
+                to,
+                subject,
+                content.body,
+                app.config,
+                files=[],
+                data=None,
+                images=content.images,
+                bcc="",
+                mime_subtype="related",
+                dryrun=False,
+            )
+            logger.info("Report sent to email")
+        except Exception as ex:
+            raise NotificationError(ex) from ex
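The email body above embeds the screenshot by Content-ID: `make_msgid` yields an id wrapped
in angle brackets, the brackets are stripped, and `send_email_smtp` attaches the image bytes
keyed by that id so the `cid:` reference in the HTML resolves to the inline attachment
(hence `mime_subtype="related"`). The stdlib part in isolation:

    from email.utils import make_msgid

    msgid = make_msgid(domain="example.com")[1:-1]  # strip the "<" and ">"
    html = '<img src="cid:%(msgid)s">' % {"msgid": msgid}
    images = {msgid: b"...png bytes..."}  # passed as send_email_smtp(images=...)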
+ + """, + url=self._content.screenshot.url, + msgid=msgid, + ) + return EmailContent(body=body, images=image) + + def _get_subject(self) -> str: + return __( + "%(prefix)s %(title)s", + prefix=app.config["EMAIL_REPORTS_SUBJECT_PREFIX"], + title=self._content.name, + ) + + def _get_to(self) -> str: + return json.loads(self._recipient.recipient_config_json)["target"] + + def send(self) -> None: + subject = self._get_subject() + content = self._get_content() + to = self._get_to() + try: + send_email_smtp( + to, + subject, + content.body, + app.config, + files=[], + data=None, + images=content.images, + bcc="", + mime_subtype="related", + dryrun=False, + ) + logger.info("Report sent to email") + except Exception as ex: + raise NotificationError(ex) diff --git a/superset/reports/notifications/exceptions.py b/superset/reports/notifications/exceptions.py new file mode 100644 index 0000000000000..749a91fd955b0 --- /dev/null +++ b/superset/reports/notifications/exceptions.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +class NotificationError(Exception): + pass diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py new file mode 100644 index 0000000000000..8e859ffc894fc --- /dev/null +++ b/superset/reports/notifications/slack.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import json +import logging +from io import IOBase +from typing import cast, Optional, Union + +from flask_babel import gettext as __ +from retry.api import retry +from slack import WebClient +from slack.errors import SlackApiError, SlackClientError +from slack.web.slack_response import SlackResponse + +from superset import app +from superset.models.reports import ReportRecipientType +from superset.reports.notifications.base import BaseNotification +from superset.reports.notifications.exceptions import NotificationError + +logger = logging.getLogger(__name__) + + +class SlackNotification(BaseNotification): # pylint: disable=too-few-public-methods + """ + Sends a slack notification for a report recipient + """ + + type = ReportRecipientType.SLACK + + def _get_channel(self) -> str: + return json.loads(self._recipient.recipient_config_json)["target"] + + def _get_body(self) -> str: + return __( + """ + *%(name)s*\n + <%(url)s|Explore in Superset> + """, + name=self._content.name, + url=self._content.screenshot.url, + ) + + def _get_inline_screenshot(self) -> Optional[Union[str, IOBase, bytes]]: + return self._content.screenshot.image + + @retry(SlackApiError, delay=10, backoff=2, tries=5) + def send(self) -> None: + file = self._get_inline_screenshot() + channel = self._get_channel() + body = self._get_body() + + try: + client = WebClient( + token=app.config["SLACK_API_TOKEN"], proxy=app.config["SLACK_PROXY"] + ) + # files_upload returns SlackResponse as we run it in sync mode. + if file: + response = cast( + SlackResponse, + client.files_upload( + channels=channel, + file=file, + initial_comment=body, + title="subject", + ), + ) + assert response["file"], str(response) # the uploaded file + else: + response = cast( + SlackResponse, client.chat_postMessage(channel=channel, text=body), + ) + assert response["message"]["text"], str(response) + logger.info("Report sent to slack") + except SlackClientError as ex: + raise NotificationError(ex) diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 0f3cd0ef558b2..d84273f4ee710 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -29,7 +29,7 @@ # Need to import late, as the celery_app will have been setup by "create_app()" # pylint: disable=wrong-import-position, unused-import -from . import cache, schedules # isort:skip +from . import cache, schedules, scheduler # isort:skip # Export the celery app globally for Celery (as run on the cmd line) to find app = celery_app diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py new file mode 100644 index 0000000000000..62398f08df9ff --- /dev/null +++ b/superset/tasks/scheduler.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import logging +from datetime import datetime, timedelta +from typing import Iterator + +import croniter + +from superset.commands.exceptions import CommandException +from superset.extensions import celery_app +from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand +from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand +from superset.reports.dao import ReportScheduleDAO +from superset.utils.celery import session_scope + +logger = logging.getLogger(__name__) + + +def cron_schedule_window(cron: str, window_size: int = 10) -> Iterator[datetime]: + utc_now = datetime.utcnow() + start_at = utc_now - timedelta(seconds=1) + stop_at = utc_now + timedelta(seconds=window_size) + crons = croniter.croniter(cron, start_at) + for schedule in crons.all_next(datetime): + if schedule >= stop_at: + break + yield schedule + + +@celery_app.task(name="reports.scheduler") +def scheduler() -> None: + """ + Celery beat main scheduler for reports + """ + with session_scope(nullpool=True) as session: + active_schedules = ReportScheduleDAO.find_active(session) + for active_schedule in active_schedules: + for schedule in cron_schedule_window(active_schedule.crontab): + execute.apply_async((active_schedule.id, schedule,), eta=schedule) + + +@celery_app.task(name="reports.execute") +def execute(report_schedule_id: int, scheduled_dttm: datetime) -> None: + try: + AsyncExecuteReportScheduleCommand(report_schedule_id, scheduled_dttm).run() + except CommandException as ex: + logger.error("An exception occurred while executing the report: %s", ex) + + +@celery_app.task(name="reports.prune_log") +def prune_log() -> None: + try: + AsyncPruneReportScheduleLogCommand().run() + except CommandException as ex: + logger.error("An exception occurred while pruning report schedule logs: %s", ex) diff --git a/superset/utils/urls.py b/superset/utils/urls.py index 905376991dc48..fe9455d27e3c3 100644 --- a/superset/utils/urls.py +++ b/superset/utils/urls.py @@ -20,11 +20,15 @@ from flask import current_app, url_for -def headless_url(path: str) -> str: - base_url = current_app.config.get("WEBDRIVER_BASEURL", "") +def headless_url(path: str, user_friendly: bool = False) -> str: + base_url = ( + current_app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"] + if user_friendly + else current_app.config["WEBDRIVER_BASEURL"] + ) return urllib.parse.urljoin(base_url, path) -def get_url_path(view: str, **kwargs: Any) -> str: +def get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str: with current_app.test_request_context(): - return headless_url(url_for(view, **kwargs)) + return headless_url(url_for(view, **kwargs), user_friendly=user_friendly) diff --git a/tests/reports/api_tests.py b/tests/reports/api_tests.py index eb5425dade989..26dbe165768a8 100644 --- a/tests/reports/api_tests.py +++ b/tests/reports/api_tests.py @@ -40,6 +40,7 @@ ) from tests.base_tests import SupersetTestCase +from tests.reports.utils import insert_report_schedule from superset.utils.core import get_example_database @@ -47,48 +48,6 @@ class TestReportSchedulesApi(SupersetTestCase): - def insert_report_schedule( - self, - type: str, - name: str, - crontab: str, - sql: Optional[str] = None, - description: Optional[str] = None, - chart: Optional[Slice] = None, - dashboard: Optional[Dashboard] = None, - database: Optional[Database] = None, - owners: Optional[List[User]] = None, - validator_type: Optional[str] = None, - validator_config_json: Optional[str] = None, - log_retention: Optional[int] = None, - 
grace_period: Optional[int] = None, - recipients: Optional[List[ReportRecipients]] = None, - logs: Optional[List[ReportExecutionLog]] = None, - ) -> ReportSchedule: - owners = owners or [] - recipients = recipients or [] - logs = logs or [] - report_schedule = ReportSchedule( - type=type, - name=name, - crontab=crontab, - sql=sql, - description=description, - chart=chart, - dashboard=dashboard, - database=database, - owners=owners, - validator_type=validator_type, - validator_config_json=validator_config_json, - log_retention=log_retention, - grace_period=grace_period, - recipients=recipients, - logs=logs, - ) - db.session.add(report_schedule) - db.session.commit() - return report_schedule - @pytest.fixture() def create_report_schedules(self): with self.create_app().app_context(): @@ -116,7 +75,7 @@ def create_report_schedules(self): ) ) report_schedules.append( - self.insert_report_schedule( + insert_report_schedule( type=ReportScheduleType.ALERT, name=f"name{cx}", crontab=f"*/{cx} * * * *", @@ -169,10 +128,6 @@ def test_get_report_schedule(self): "last_value_row_json": report_schedule.last_value_row_json, "log_retention": report_schedule.log_retention, "name": report_schedule.name, - "owners": [ - {"first_name": "admin", "id": 1, "last_name": "user"}, - {"first_name": "alpha", "id": 5, "last_name": "user"}, - ], "recipients": [ { "id": report_schedule.recipients[0].id, @@ -184,7 +139,16 @@ def test_get_report_schedule(self): "validator_config_json": report_schedule.validator_config_json, "validator_type": report_schedule.validator_type, } - assert data["result"] == expected_result + for key in expected_result: + assert data["result"][key] == expected_result[key] + # needed because order may vary + assert {"first_name": "admin", "id": 1, "last_name": "user"} in data["result"][ + "owners" + ] + assert {"first_name": "alpha", "id": 5, "last_name": "user"} in data["result"][ + "owners" + ] + assert len(data["result"]["owners"]) == 2 def test_info_report_schedule(self): """ diff --git a/tests/reports/commands_tests.py b/tests/reports/commands_tests.py new file mode 100644 index 0000000000000..d5566944e537f --- /dev/null +++ b/tests/reports/commands_tests.py @@ -0,0 +1,531 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import json +from datetime import datetime +from typing import List, Optional +from unittest.mock import patch + +import pytest +from contextlib2 import contextmanager +from freezegun import freeze_time +from sqlalchemy.sql import func + +from superset import db +from superset.models.core import Database +from superset.models.dashboard import Dashboard +from superset.models.reports import ( + ReportExecutionLog, + ReportLogState, + ReportRecipients, + ReportRecipientType, + ReportSchedule, + ReportScheduleType, + ReportScheduleValidatorType, +) +from superset.models.slice import Slice +from superset.reports.commands.exceptions import ( + AlertQueryMultipleColumnsError, + AlertQueryMultipleRowsError, + ReportScheduleNotFoundError, + ReportScheduleNotificationError, + ReportSchedulePreviousWorkingError, +) +from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand +from superset.utils.core import get_example_database +from tests.reports.utils import insert_report_schedule +from tests.test_app import app +from tests.utils import read_fixture + + +def get_target_from_report_schedule(report_schedule) -> List[str]: + return [ + json.loads(recipient.recipient_config_json)["target"] + for recipient in report_schedule.recipients + ] + + +def assert_log(state: str, error_message: Optional[str] = None): + db.session.commit() + logs = db.session.query(ReportExecutionLog).all() + assert len(logs) == 1 + assert logs[0].error_message == error_message + assert logs[0].state == state + + +def create_report_notification( + email_target: Optional[str] = None, + slack_channel: Optional[str] = None, + chart: Optional[Slice] = None, + dashboard: Optional[Dashboard] = None, + database: Optional[Database] = None, + sql: Optional[str] = None, + report_type: Optional[str] = None, + validator_type: Optional[str] = None, + validator_config_json: Optional[str] = None, +) -> ReportSchedule: + report_type = report_type or ReportScheduleType.REPORT + target = email_target or slack_channel + config_json = {"target": target} + if slack_channel: + recipient = ReportRecipients( + type=ReportRecipientType.SLACK, + recipient_config_json=json.dumps(config_json), + ) + else: + recipient = ReportRecipients( + type=ReportRecipientType.EMAIL, + recipient_config_json=json.dumps(config_json), + ) + + report_schedule = insert_report_schedule( + type=report_type, + name=f"report", + crontab=f"0 9 * * *", + description=f"Daily report", + sql=sql, + chart=chart, + dashboard=dashboard, + database=database, + recipients=[recipient], + validator_type=validator_type, + validator_config_json=validator_config_json, + ) + return report_schedule + + +@pytest.yield_fixture() +def create_report_email_chart(): + with app.app_context(): + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + email_target="target@email.com", chart=chart + ) + yield report_schedule + + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.yield_fixture() +def create_report_email_dashboard(): + with app.app_context(): + dashboard = db.session.query(Dashboard).first() + report_schedule = create_report_notification( + email_target="target@email.com", dashboard=dashboard + ) + yield report_schedule + + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.yield_fixture() +def create_report_slack_chart(): + with app.app_context(): + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + slack_channel="slack_channel", chart=chart + ) + yield 
report_schedule + + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.yield_fixture() +def create_report_slack_chart_working(): + with app.app_context(): + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + slack_channel="slack_channel", chart=chart + ) + report_schedule.last_state = ReportLogState.WORKING + db.session.commit() + yield report_schedule + + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.yield_fixture( + params=["alert1", "alert2", "alert3", "alert4", "alert5", "alert6", "alert7"] +) +def create_alert_email_chart(request): + param_config = { + "alert1": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": ">", "threshold": 9}', + }, + "alert2": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": ">=", "threshold": 10}', + }, + "alert3": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<", "threshold": 11}', + }, + "alert4": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<=", "threshold": 10}', + }, + "alert5": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "!=", "threshold": 11}', + }, + "alert6": { + "sql": "SELECT 'something' as metric", + "validator_type": ReportScheduleValidatorType.NOT_NULL, + "validator_config_json": "{}", + }, + "alert7": { + "sql": "SELECT {{ 5 + 5 }} as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "!=", "threshold": 11}', + }, + } + with app.app_context(): + chart = db.session.query(Slice).first() + example_database = get_example_database() + + report_schedule = create_report_notification( + email_target="target@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param]["validator_config_json"], + ) + yield report_schedule + + db.session.delete(report_schedule) + db.session.commit() + + +@contextmanager +def create_test_table_context(database: Database): + database.get_sqla_engine().execute( + "CREATE TABLE test_table AS SELECT 1 as first, 2 as second" + ) + database.get_sqla_engine().execute( + "INSERT INTO test_table (first, second) VALUES (1, 2)" + ) + database.get_sqla_engine().execute( + "INSERT INTO test_table (first, second) VALUES (3, 4)" + ) + + yield db.session + database.get_sqla_engine().execute("DROP TABLE test_table") + + +@pytest.yield_fixture( + params=["alert1", "alert2", "alert3", "alert4", "alert5", "alert6"] +) +def create_no_alert_email_chart(request): + param_config = { + "alert1": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<", "threshold": 10}', + }, + "alert2": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": ">=", "threshold": 11}', + }, + "alert3": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<", "threshold": 10}', + }, + "alert4": { + "sql": "SELECT 10 as metric", + 
"validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<=", "threshold": 9}', + }, + "alert5": { + "sql": "SELECT 10 as metric", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "!=", "threshold": 10}', + }, + "alert6": { + "sql": "SELECT first from test_table where first=0", + "validator_type": ReportScheduleValidatorType.NOT_NULL, + "validator_config_json": "{}", + }, + } + with app.app_context(): + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): + + report_schedule = create_report_notification( + email_target="target@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param][ + "validator_config_json" + ], + ) + yield report_schedule + + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.yield_fixture(params=["alert1", "alert2"]) +def create_mul_alert_email_chart(request): + param_config = { + "alert1": { + "sql": "SELECT first from test_table", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<", "threshold": 10}', + }, + "alert2": { + "sql": "SELECT first, second from test_table", + "validator_type": ReportScheduleValidatorType.OPERATOR, + "validator_config_json": '{"op": "<", "threshold": 10}', + }, + } + with app.app_context(): + chart = db.session.query(Slice).first() + example_database = get_example_database() + with create_test_table_context(example_database): + + report_schedule = create_report_notification( + email_target="target@email.com", + chart=chart, + report_type=ReportScheduleType.ALERT, + database=example_database, + sql=param_config[request.param]["sql"], + validator_type=param_config[request.param]["validator_type"], + validator_config_json=param_config[request.param][ + "validator_config_json" + ], + ) + yield report_schedule + + # needed for MySQL + logs = ( + db.session.query(ReportExecutionLog) + .filter(ReportExecutionLog.report_schedule == report_schedule) + .all() + ) + for log in logs: + db.session.delete(log) + db.session.commit() + db.session.delete(report_schedule) + db.session.commit() + + +@pytest.mark.usefixtures("create_report_email_chart") +@patch("superset.reports.notifications.email.send_email_smtp") +@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache") +def test_email_chart_report_schedule( + screenshot_mock, email_mock, create_report_email_chart +): + """ + ExecuteReport Command: Test chart email report schedule + """ + # setup screenshot mock + screenshot = read_fixture("sample.png") + screenshot_mock.return_value = screenshot + + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + create_report_email_chart.id, datetime.utcnow() + ).run() + + notification_targets = get_target_from_report_schedule( + create_report_email_chart + ) + # Assert the email smtp address + assert email_mock.call_args[0][0] == notification_targets[0] + # Assert the email inline screenshot + smtp_images = email_mock.call_args[1]["images"] + assert smtp_images[list(smtp_images.keys())[0]] == screenshot + # Assert logs are correct + assert_log(ReportLogState.SUCCESS) + + +@pytest.mark.usefixtures("create_report_email_dashboard") +@patch("superset.reports.notifications.email.send_email_smtp") 
+@patch("superset.utils.screenshots.DashboardScreenshot.compute_and_cache") +def test_email_dashboard_report_schedule( + screenshot_mock, email_mock, create_report_email_dashboard +): + """ + ExecuteReport Command: Test dashboard email report schedule + """ + # setup screenshot mock + screenshot = read_fixture("sample.png") + screenshot_mock.return_value = screenshot + + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + create_report_email_dashboard.id, datetime.utcnow() + ).run() + + notification_targets = get_target_from_report_schedule( + create_report_email_dashboard + ) + # Assert the email smtp address + assert email_mock.call_args[0][0] == notification_targets[0] + # Assert the email inline screenshot + smtp_images = email_mock.call_args[1]["images"] + assert smtp_images[list(smtp_images.keys())[0]] == screenshot + # Assert logs are correct + assert_log(ReportLogState.SUCCESS) + + +@pytest.mark.usefixtures("create_report_slack_chart") +@patch("superset.reports.notifications.slack.WebClient.files_upload") +@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache") +def test_slack_chart_report_schedule( + screenshot_mock, file_upload_mock, create_report_slack_chart +): + """ + ExecuteReport Command: Test chart slack report schedule + """ + # setup screenshot mock + screenshot = read_fixture("sample.png") + screenshot_mock.return_value = screenshot + + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + create_report_slack_chart.id, datetime.utcnow() + ).run() + + notification_targets = get_target_from_report_schedule( + create_report_slack_chart + ) + assert file_upload_mock.call_args[1]["channels"] == notification_targets[0] + assert file_upload_mock.call_args[1]["file"] == screenshot + + # Assert logs are correct + assert_log(ReportLogState.SUCCESS) + + +@pytest.mark.usefixtures("create_report_slack_chart") +def test_report_schedule_not_found(create_report_slack_chart): + """ + ExecuteReport Command: Test report schedule not found + """ + max_id = db.session.query(func.max(ReportSchedule.id)).scalar() + with pytest.raises(ReportScheduleNotFoundError): + AsyncExecuteReportScheduleCommand(max_id + 1, datetime.utcnow()).run() + + +@pytest.mark.usefixtures("create_report_slack_chart_working") +def test_report_schedule_working(create_report_slack_chart_working): + """ + ExecuteReport Command: Test report schedule still working + """ + # setup screenshot mock + with pytest.raises(ReportSchedulePreviousWorkingError): + AsyncExecuteReportScheduleCommand( + create_report_slack_chart_working.id, datetime.utcnow() + ).run() + + assert_log( + ReportLogState.ERROR, error_message=ReportSchedulePreviousWorkingError.message + ) + assert create_report_slack_chart_working.last_state == ReportLogState.WORKING + + +@pytest.mark.usefixtures("create_report_email_dashboard") +@patch("superset.reports.notifications.email.send_email_smtp") +@patch("superset.utils.screenshots.DashboardScreenshot.compute_and_cache") +def test_email_dashboard_report_fails( + screenshot_mock, email_mock, create_report_email_dashboard +): + """ + ExecuteReport Command: Test dashboard email report schedule notification fails + """ + # setup screenshot mock + from smtplib import SMTPException + + screenshot = read_fixture("sample.png") + screenshot_mock.return_value = screenshot + email_mock.side_effect = SMTPException("Could not connect to SMTP XPTO") + + with pytest.raises(ReportScheduleNotificationError): + AsyncExecuteReportScheduleCommand( + 
create_report_email_dashboard.id, datetime.utcnow()
+        ).run()
+
+    assert_log(ReportLogState.ERROR, error_message="Could not connect to SMTP XPTO")
+
+
+@pytest.mark.usefixtures("create_alert_email_chart")
+@patch("superset.reports.notifications.email.send_email_smtp")
+@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
+def test_slack_chart_alert(screenshot_mock, email_mock, create_alert_email_chart):
+    """
+    ExecuteReport Command: Test chart email alert
+    """
+    # setup screenshot mock
+    screenshot = read_fixture("sample.png")
+    screenshot_mock.return_value = screenshot
+
+    with freeze_time("2020-01-01T00:00:00Z"):
+        AsyncExecuteReportScheduleCommand(
+            create_alert_email_chart.id, datetime.utcnow()
+        ).run()
+
+        notification_targets = get_target_from_report_schedule(create_alert_email_chart)
+        # Assert the email smtp address
+        assert email_mock.call_args[0][0] == notification_targets[0]
+        # Assert the email inline screenshot
+        smtp_images = email_mock.call_args[1]["images"]
+        assert smtp_images[list(smtp_images.keys())[0]] == screenshot
+        # Assert logs are correct
+        assert_log(ReportLogState.SUCCESS)
+
+
+@pytest.mark.usefixtures("create_no_alert_email_chart")
+def test_email_chart_no_alert(create_no_alert_email_chart):
+    """
+    ExecuteReport Command: Test chart email no alert
+    """
+    with freeze_time("2020-01-01T00:00:00Z"):
+        AsyncExecuteReportScheduleCommand(
+            create_no_alert_email_chart.id, datetime.utcnow()
+        ).run()
+    assert_log(ReportLogState.NOOP)
+
+
+@pytest.mark.usefixtures("create_mul_alert_email_chart")
+def test_email_mul_alert(create_mul_alert_email_chart):
+    """
+    ExecuteReport Command: Test alert queries returning multiple rows or columns
+    """
+    with freeze_time("2020-01-01T00:00:00Z"):
+        with pytest.raises(
+            (AlertQueryMultipleRowsError, AlertQueryMultipleColumnsError)
+        ):
+            AsyncExecuteReportScheduleCommand(
+                create_mul_alert_email_chart.id, datetime.utcnow()
+            ).run()
diff --git a/tests/reports/utils.py b/tests/reports/utils.py
new file mode 100644
index 0000000000000..841ae4d975207
--- /dev/null
+++ b/tests/reports/utils.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +from typing import List, Optional + +from flask_appbuilder.security.sqla.models import User + +from superset import db +from superset.models.core import Database +from superset.models.dashboard import Dashboard +from superset.models.reports import ReportExecutionLog, ReportRecipients, ReportSchedule +from superset.models.slice import Slice + + +def insert_report_schedule( + type: str, + name: str, + crontab: str, + sql: Optional[str] = None, + description: Optional[str] = None, + chart: Optional[Slice] = None, + dashboard: Optional[Dashboard] = None, + database: Optional[Database] = None, + owners: Optional[List[User]] = None, + validator_type: Optional[str] = None, + validator_config_json: Optional[str] = None, + log_retention: Optional[int] = None, + grace_period: Optional[int] = None, + recipients: Optional[List[ReportRecipients]] = None, + logs: Optional[List[ReportExecutionLog]] = None, +) -> ReportSchedule: + owners = owners or [] + recipients = recipients or [] + logs = logs or [] + report_schedule = ReportSchedule( + type=type, + name=name, + crontab=crontab, + sql=sql, + description=description, + chart=chart, + dashboard=dashboard, + database=database, + owners=owners, + validator_type=validator_type, + validator_config_json=validator_config_json, + log_retention=log_retention, + grace_period=grace_period, + recipients=recipients, + logs=logs, + ) + db.session.add(report_schedule) + db.session.commit() + return report_schedule
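Deployment note: `reports.scheduler` has to run every minute so its cron window can fan out
`reports.execute` tasks with the right `eta`, and `reports.prune_log` needs its own periodic
entry. A sketch of the expected beat wiring in superset_config.py (cadences here are
illustrative, not mandated by this diff):

    from celery.schedules import crontab

    CELERYBEAT_SCHEDULE = {
        "reports.scheduler": {
            "task": "reports.scheduler",
            "schedule": crontab(minute="*", hour="*"),  # every minute
        },
        "reports.prune_log": {
            "task": "reports.prune_log",
            "schedule": crontab(minute=0, hour=0),  # once a day
        },
    }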