diff --git a/1 b/1 new file mode 100644 index 0000000000000..56866025994a3 --- /dev/null +++ b/1 @@ -0,0 +1,688 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +"""Unit tests for Superset""" +import json +import unittest +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, + load_birth_names_data, +) + +import pytest +from flask import g +from sqlalchemy.orm.session import make_transient + +from tests.integration_tests.fixtures.energy_dashboard import ( + load_energy_table_with_slice, + load_energy_table_data, +) +from tests.integration_tests.test_app import app +from superset.commands.dashboard.importers.v0 import decode_dashboards +from superset import db, security_manager + +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.commands.dashboard.importers.v0 import import_chart, import_dashboard +from superset.commands.dataset.importers.v0 import import_dataset +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.utils.core import DatasourceType, get_example_default_schema +from superset.utils.database import get_example_database + +from 
tests.integration_tests.fixtures.world_bank_dashboard import ( + load_world_bank_dashboard_with_slices, + load_world_bank_data, +) +from .base_tests import SupersetTestCase + + +def delete_imports(): + with app.app_context(): + # Imported data clean up + session = db.session + for slc in session.query(Slice): + if "remote_id" in slc.params_dict: + session.delete(slc) + for dash in session.query(Dashboard): + if "remote_id" in dash.params_dict: + session.delete(dash) + for table in session.query(SqlaTable): + if "remote_id" in table.params_dict: + session.delete(table) + session.commit() + + +@pytest.fixture(autouse=True, scope="module") +def clean_imports(): + yield + delete_imports() + + +class TestImportExport(SupersetTestCase): + """Testing export import functionality for dashboards""" + + def create_slice( + self, + name, + ds_id=None, + id=None, + db_name="examples", + table_name="wb_health_population", + schema=None, + ): + params = { + "num_period_compare": "10", + "remote_id": id, + "datasource_name": table_name, + "database_name": db_name, + "schema": schema, + # Test for trailing commas + "metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"], + } + + if table_name and not ds_id: + table = self.get_table(schema=schema, name=table_name) + if table: + ds_id = table.id + + return Slice( + slice_name=name, + datasource_type=DatasourceType.TABLE, + viz_type="bubble", + params=json.dumps(params), + datasource_id=ds_id, + id=id, + ) + + def create_dashboard(self, title, id=0, slcs=[]): + json_metadata = {"remote_id": id} + return Dashboard( + id=id, + dashboard_title=title, + slices=slcs, + position_json='{"size_y": 2, "size_x": 2}', + slug=f"{title.lower()}_imported", + json_metadata=json.dumps(json_metadata), + published=False, + ) + + def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]): + params = {"remote_id": id, "database_name": "examples"} + table = SqlaTable( + id=id, + schema=schema, + table_name=name, + 
params=json.dumps(params), + ) + for col_name in cols_names: + table.columns.append(TableColumn(column_name=col_name)) + for metric_name in metric_names: + table.metrics.append(SqlMetric(metric_name=metric_name, expression="")) + return table + + def get_slice(self, slc_id): + return db.session.query(Slice).filter_by(id=slc_id).first() + + def get_slice_by_name(self, name): + return db.session.query(Slice).filter_by(slice_name=name).first() + + def get_dash(self, dash_id): + return db.session.query(Dashboard).filter_by(id=dash_id).first() + + def assert_dash_equals( + self, expected_dash, actual_dash, check_position=True, check_slugs=True + ): + if check_slugs: + self.assertEqual(expected_dash.slug, actual_dash.slug) + self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title) + self.assertEqual(len(expected_dash.slices), len(actual_dash.slices)) + expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "") + actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "") + for e_slc, a_slc in zip(expected_slices, actual_slices): + self.assert_slice_equals(e_slc, a_slc) + if check_position: + self.assertEqual(expected_dash.position_json, actual_dash.position_json) + + def assert_table_equals(self, expected_ds, actual_ds): + self.assertEqual(expected_ds.table_name, actual_ds.table_name) + self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col) + self.assertEqual(expected_ds.schema, actual_ds.schema) + self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) + self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) + self.assertEqual( + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, + ) + self.assertEqual( + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, + ) + + def assert_datasource_equals(self, expected_ds, actual_ds): + self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name) 
+ self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col) + self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) + self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) + self.assertEqual( + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, + ) + self.assertEqual( + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, + ) + + def assert_slice_equals(self, expected_slc, actual_slc): + # to avoid bad slice data (no slice_name) + expected_slc_name = expected_slc.slice_name or "" + actual_slc_name = actual_slc.slice_name or "" + self.assertEqual(expected_slc_name, actual_slc_name) + self.assertEqual(expected_slc.datasource_type, actual_slc.datasource_type) + self.assertEqual(expected_slc.viz_type, actual_slc.viz_type) + exp_params = json.loads(expected_slc.params) + actual_params = json.loads(actual_slc.params) + diff_params_keys = ( + "schema", + "database_name", + "datasource_name", + "remote_id", + "import_time", + ) + for k in diff_params_keys: + if k in actual_params: + actual_params.pop(k) + if k in exp_params: + exp_params.pop(k) + self.assertEqual(exp_params, actual_params) + + def assert_only_exported_slc_fields(self, expected_dash, actual_dash): + """only exported json has this params + imported/created dashboard has relationships to other models instead + """ + expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "") + actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "") + for e_slc, a_slc in zip(expected_slices, actual_slices): + params = a_slc.params_dict + self.assertEqual(e_slc.datasource.name, params["datasource_name"]) + self.assertEqual(e_slc.datasource.schema, params["schema"]) + self.assertEqual(e_slc.datasource.database.name, params["database_name"]) + + @unittest.skip("Schema needs to be updated") + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def 
test_export_1_dashboard(self): + self.login("admin") + birth_dash = self.get_dash_by_slug("births") + id_ = birth_dash.id + export_dash_url = f"/dashboard/export_dashboards_form?id={id_}&action=go" + resp = self.client.get(export_dash_url) + exported_dashboards = json.loads( + resp.data.decode("utf-8"), object_hook=decode_dashboards + )["dashboards"] + + birth_dash = self.get_dash_by_slug("births") + self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0]) + self.assert_dash_equals(birth_dash, exported_dashboards[0]) + self.assertEqual( + id_, + json.loads( + exported_dashboards[0].json_metadata, object_hook=decode_dashboards + )["remote_id"], + ) + + exported_tables = json.loads( + resp.data.decode("utf-8"), object_hook=decode_dashboards + )["datasources"] + self.assertEqual(1, len(exported_tables)) + self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0]) + + @unittest.skip("Schema needs to be updated") + @pytest.mark.usefixtures( + "load_world_bank_dashboard_with_slices", + "load_birth_names_dashboard_with_slices", + ) + def test_export_2_dashboards(self): + self.login("admin") + birth_dash = self.get_dash_by_slug("births") + world_health_dash = self.get_dash_by_slug("world_health") + export_dash_url = ( + "/dashboard/export_dashboards_form?id={}&id={}&action=go".format( + birth_dash.id, world_health_dash.id + ) + ) + resp = self.client.get(export_dash_url) + resp_data = json.loads(resp.data.decode("utf-8"), object_hook=decode_dashboards) + exported_dashboards = sorted( + resp_data.get("dashboards"), key=lambda d: d.dashboard_title + ) + self.assertEqual(2, len(exported_dashboards)) + + birth_dash = self.get_dash_by_slug("births") + self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0]) + self.assert_dash_equals(birth_dash, exported_dashboards[0]) + self.assertEqual( + birth_dash.id, json.loads(exported_dashboards[0].json_metadata)["remote_id"] + ) + + world_health_dash = 
self.get_dash_by_slug("world_health") + self.assert_only_exported_slc_fields(world_health_dash, exported_dashboards[1]) + self.assert_dash_equals(world_health_dash, exported_dashboards[1]) + self.assertEqual( + world_health_dash.id, + json.loads(exported_dashboards[1].json_metadata)["remote_id"], + ) + + exported_tables = sorted( + resp_data.get("datasources"), key=lambda t: t.table_name + ) + self.assertEqual(2, len(exported_tables)) + self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0]) + self.assert_table_equals( + self.get_table(name="wb_health_population"), exported_tables[1] + ) + + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + def test_import_1_slice(self): + expected_slice = self.create_slice( + "Import Me", id=10001, schema=get_example_default_schema() + ) + slc_id = import_chart(expected_slice, None, import_time=1989) + slc = self.get_slice(slc_id) + self.assertEqual(slc.datasource.perm, slc.perm) + self.assert_slice_equals(expected_slice, slc) + + table_id = self.get_table(name="wb_health_population").id + self.assertEqual(table_id, self.get_slice(slc_id).datasource_id) + + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + def test_import_2_slices_for_same_table(self): + schema = get_example_default_schema() + table_id = self.get_table(name="wb_health_population").id + slc_1 = self.create_slice( + "Import Me 1", ds_id=table_id, id=10002, schema=schema + ) + slc_id_1 = import_chart(slc_1, None) + slc_2 = self.create_slice( + "Import Me 2", ds_id=table_id, id=10003, schema=schema + ) + slc_id_2 = import_chart(slc_2, None) + + imported_slc_1 = self.get_slice(slc_id_1) + imported_slc_2 = self.get_slice(slc_id_2) + self.assertEqual(table_id, imported_slc_1.datasource_id) + self.assert_slice_equals(slc_1, imported_slc_1) + self.assertEqual(imported_slc_1.datasource.perm, imported_slc_1.perm) + + self.assertEqual(table_id, imported_slc_2.datasource_id) + self.assert_slice_equals(slc_2, 
imported_slc_2) + self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm) + + def test_import_slices_override(self): + schema = get_example_default_schema() + slc = self.create_slice("Import Me New", id=10005, schema=schema) + slc_1_id = import_chart(slc, None, import_time=1990) + slc.slice_name = "Import Me New" + imported_slc_1 = self.get_slice(slc_1_id) + slc_2 = self.create_slice("Import Me New", id=10005, schema=schema) + slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990) + self.assertEqual(slc_1_id, slc_2_id) + imported_slc_2 = self.get_slice(slc_2_id) + self.assert_slice_equals(slc, imported_slc_2) + + def test_import_empty_dashboard(self): + empty_dash = self.create_dashboard("empty_dashboard", id=10001) + imported_dash_id = import_dashboard(empty_dash, import_time=1989) + imported_dash = self.get_dash(imported_dash_id) + self.assert_dash_equals(empty_dash, imported_dash, check_position=False) + + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + def test_import_dashboard_1_slice(self): + slc = self.create_slice( + "health_slc", id=10006, schema=get_example_default_schema() + ) + dash_with_1_slice = self.create_dashboard( + "dash_with_1_slice", slcs=[slc], id=10002 + ) + dash_with_1_slice.position_json = """ + {{"DASHBOARD_VERSION_KEY": "v2", + "DASHBOARD_CHART_TYPE-{0}": {{ + "type": "CHART", + "id": {0}, + "children": [], + "meta": {{ + "width": 4, + "height": 50, + "chartId": {0} + }} + }} + }} + """.format( + slc.id + ) + imported_dash_id = import_dashboard(dash_with_1_slice, import_time=1990) + imported_dash = self.get_dash(imported_dash_id) + + expected_dash = self.create_dashboard("dash_with_1_slice", slcs=[slc], id=10002) + make_transient(expected_dash) + self.assert_dash_equals( + expected_dash, imported_dash, check_position=False, check_slugs=False + ) + self.assertEqual( + { + "remote_id": 10002, + "import_time": 1990, + "native_filter_configuration": [], + }, + 
json.loads(imported_dash.json_metadata), + ) + + expected_position = dash_with_1_slice.position + # new slice id (auto-incremental) assigned on insert + # id from json is used only for updating position with new id + meta = expected_position["DASHBOARD_CHART_TYPE-10006"]["meta"] + meta["chartId"] = imported_dash.slices[0].id + self.assertEqual(expected_position, imported_dash.position) + + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_import_dashboard_2_slices(self): + schema = get_example_default_schema() + e_slc = self.create_slice( + "e_slc", id=10007, table_name="energy_usage", schema=schema + ) + b_slc = self.create_slice( + "b_slc", id=10008, table_name="birth_names", schema=schema + ) + dash_with_2_slices = self.create_dashboard( + "dash_with_2_slices", slcs=[e_slc, b_slc], id=10003 + ) + dash_with_2_slices.json_metadata = json.dumps( + { + "remote_id": 10003, + "expanded_slices": { + f"{e_slc.id}": True, + f"{b_slc.id}": False, + }, + # mocked legacy filter_scope metadata + "filter_scopes": { + str(e_slc.id): { + "region": {"scope": ["ROOT_ID"], "immune": [b_slc.id]} + } + }, + } + ) + + imported_dash_id = import_dashboard(dash_with_2_slices, import_time=1991) + imported_dash = self.get_dash(imported_dash_id) + + expected_dash = self.create_dashboard( + "dash_with_2_slices", slcs=[e_slc, b_slc], id=10003 + ) + make_transient(expected_dash) + self.assert_dash_equals( + imported_dash, expected_dash, check_position=False, check_slugs=False + ) + i_e_slc = self.get_slice_by_name("e_slc") + i_b_slc = self.get_slice_by_name("b_slc") + expected_json_metadata = { + "remote_id": 10003, + "import_time": 1991, + "expanded_slices": { + f"{i_e_slc.id}": True, + f"{i_b_slc.id}": False, + }, + "native_filter_configuration": [], + } + self.assertEqual( + expected_json_metadata, json.loads(imported_dash.json_metadata) + ) + + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_import_override_dashboard_2_slices(self): + schema = 
get_example_default_schema() + e_slc = self.create_slice( + "e_slc", id=10009, table_name="energy_usage", schema=schema + ) + b_slc = self.create_slice( + "b_slc", id=10010, table_name="birth_names", schema=schema + ) + dash_to_import = self.create_dashboard( + "override_dashboard", slcs=[e_slc, b_slc], id=10004 + ) + imported_dash_id_1 = import_dashboard(dash_to_import, import_time=1992) + + # create new instances of the slices + e_slc = self.create_slice( + "e_slc", id=10009, table_name="energy_usage", schema=schema + ) + b_slc = self.create_slice( + "b_slc", id=10010, table_name="birth_names", schema=schema + ) + c_slc = self.create_slice( + "c_slc", id=10011, table_name="birth_names", schema=schema + ) + dash_to_import_override = self.create_dashboard( + "override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004 + ) + imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992) + + # override doesn't change the id + self.assertEqual(imported_dash_id_1, imported_dash_id_2) + expected_dash = self.create_dashboard( + "override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004 + ) + make_transient(expected_dash) + imported_dash = self.get_dash(imported_dash_id_2) + self.assert_dash_equals( + expected_dash, imported_dash, check_position=False, check_slugs=False + ) + self.assertEqual( + { + "remote_id": 10004, + "import_time": 1992, + "native_filter_configuration": [], + }, + json.loads(imported_dash.json_metadata), + ) + + def test_import_new_dashboard_slice_reset_ownership(self): + admin_user = security_manager.find_user(username="admin") + self.assertTrue(admin_user) + gamma_user = security_manager.find_user(username="gamma") + self.assertTrue(gamma_user) + g.user = gamma_user + + dash_with_1_slice = self._create_dashboard_for_import(id_=10200) + # set another user as an owner of importing dashboard + dash_with_1_slice.created_by = admin_user + dash_with_1_slice.changed_by = admin_user + dash_with_1_slice.owners = [admin_user] + + 
imported_dash_id = import_dashboard(dash_with_1_slice) + imported_dash = self.get_dash(imported_dash_id) + self.assertEqual(imported_dash.created_by, gamma_user) + self.assertEqual(imported_dash.changed_by, gamma_user) + self.assertEqual(imported_dash.owners, [gamma_user]) + + imported_slc = imported_dash.slices[0] + self.assertEqual(imported_slc.created_by, gamma_user) + self.assertEqual(imported_slc.changed_by, gamma_user) + self.assertEqual(imported_slc.owners, [gamma_user]) + + def test_import_override_dashboard_slice_reset_ownership(self): + admin_user = security_manager.find_user(username="admin") + self.assertTrue(admin_user) + gamma_user = security_manager.find_user(username="gamma") + self.assertTrue(gamma_user) + g.user = gamma_user + + dash_with_1_slice = self._create_dashboard_for_import(id_=10300) + + imported_dash_id = import_dashboard(dash_with_1_slice) + imported_dash = self.get_dash(imported_dash_id) + self.assertEqual(imported_dash.created_by, gamma_user) + self.assertEqual(imported_dash.changed_by, gamma_user) + self.assertEqual(imported_dash.owners, [gamma_user]) + + imported_slc = imported_dash.slices[0] + self.assertEqual(imported_slc.created_by, gamma_user) + self.assertEqual(imported_slc.changed_by, gamma_user) + self.assertEqual(imported_slc.owners, [gamma_user]) + + # re-import with another user shouldn't change the permissions + g.user = admin_user + + dash_with_1_slice = self._create_dashboard_for_import(id_=10300) + + imported_dash_id = import_dashboard(dash_with_1_slice) + imported_dash = self.get_dash(imported_dash_id) + self.assertEqual(imported_dash.created_by, gamma_user) + self.assertEqual(imported_dash.changed_by, gamma_user) + self.assertEqual(imported_dash.owners, [gamma_user]) + + imported_slc = imported_dash.slices[0] + self.assertEqual(imported_slc.created_by, gamma_user) + self.assertEqual(imported_slc.changed_by, gamma_user) + self.assertEqual(imported_slc.owners, [gamma_user]) + + def _create_dashboard_for_import(self, 
id_=10100): + slc = self.create_slice( + "health_slc" + str(id_), id=id_ + 1, schema=get_example_default_schema() + ) + dash_with_1_slice = self.create_dashboard( + "dash_with_1_slice" + str(id_), slcs=[slc], id=id_ + 2 + ) + dash_with_1_slice.position_json = """ + {{"DASHBOARD_VERSION_KEY": "v2", + "DASHBOARD_CHART_TYPE-{0}": {{ + "type": "CHART", + "id": {0}, + "children": [], + "meta": {{ + "width": 4, + "height": 50, + "chartId": {0} + }} + }} + }} + """.format( + slc.id + ) + return dash_with_1_slice + + def test_import_table_no_metadata(self): + schema = get_example_default_schema() + db_id = get_example_database().id + table = self.create_table("pure_table", id=10001, schema=schema) + imported_id = import_dataset(table, db_id, import_time=1989) + imported = self.get_table_by_id(imported_id) + self.assert_table_equals(table, imported) + + def test_import_table_1_col_1_met(self): + schema = get_example_default_schema() + table = self.create_table( + "table_1_col_1_met", + id=10002, + cols_names=["col1"], + metric_names=["metric1"], + schema=schema, + ) + db_id = get_example_database().id + imported_id = import_dataset(table, db_id, import_time=1990) + imported = self.get_table_by_id(imported_id) + self.assert_table_equals(table, imported) + self.assertEqual( + { + "remote_id": 10002, + "import_time": 1990, + "database_name": "examples", + }, + json.loads(imported.params), + ) + + def test_import_table_2_col_2_met(self): + schema = get_example_default_schema() + table = self.create_table( + "table_2_col_2_met", + id=10003, + cols_names=["c1", "c2"], + metric_names=["m1", "m2"], + schema=schema, + ) + db_id = get_example_database().id + imported_id = import_dataset(table, db_id, import_time=1991) + + imported = self.get_table_by_id(imported_id) + self.assert_table_equals(table, imported) + + def test_import_table_override(self): + schema = get_example_default_schema() + table = self.create_table( + "table_override", + id=10003, + cols_names=["col1"], + 
metric_names=["m1"], + schema=schema, + ) + db_id = get_example_database().id + imported_id = import_dataset(table, db_id, import_time=1991) + + table_over = self.create_table( + "table_override", + id=10003, + cols_names=["new_col1", "col2", "col3"], + metric_names=["new_metric1"], + schema=schema, + ) + imported_over_id = import_dataset(table_over, db_id, import_time=1992) + + imported_over = self.get_table_by_id(imported_over_id) + self.assertEqual(imported_id, imported_over.id) + expected_table = self.create_table( + "table_override", + id=10003, + metric_names=["new_metric1", "m1"], + cols_names=["col1", "new_col1", "col2", "col3"], + schema=schema, + ) + self.assert_table_equals(expected_table, imported_over) + + def test_import_table_override_identical(self): + schema = get_example_default_schema() + table = self.create_table( + "copy_cat", + id=10004, + cols_names=["new_col1", "col2", "col3"], + metric_names=["new_metric1"], + schema=schema, + ) + db_id = get_example_database().id + imported_id = import_dataset(table, db_id, import_time=1993) + + copy_table = self.create_table( + "copy_cat", + id=10004, + cols_names=["new_col1", "col2", "col3"], + metric_names=["new_metric1"], + schema=schema, + ) + imported_id_copy = import_dataset(copy_table, db_id, import_time=1994) + + self.assertEqual(imported_id, imported_id_copy) + self.assert_table_equals(copy_table, self.get_table_by_id(imported_id)) + + +if __name__ == "__main__": + unittest.main() diff --git a/superset/commands/report/execute.py b/superset/commands/report/execute.py index 349ef0c85b9a7..5cacb66134a5d 100644 --- a/superset/commands/report/execute.py +++ b/superset/commands/report/execute.py @@ -22,9 +22,8 @@ import pandas as pd from celery.exceptions import SoftTimeLimitExceeded -from sqlalchemy.orm import Session -from superset import app, security_manager +from superset import app, db, security_manager from superset.commands.base import BaseCommand from 
superset.commands.dashboard.permalink.create import CreateDashboardPermalinkCommand from superset.commands.exceptions import CommandException @@ -68,7 +67,6 @@ from superset.reports.notifications.base import NotificationContent from superset.reports.notifications.exceptions import NotificationError from superset.tasks.utils import get_executor -from superset.utils.celery import session_scope from superset.utils.core import HeaderDataType, override_user from superset.utils.csv import get_chart_csv_data, get_chart_dataframe from superset.utils.decorators import logs_context @@ -85,12 +83,10 @@ class BaseReportState: @logs_context() def __init__( self, - session: Session, report_schedule: ReportSchedule, scheduled_dttm: datetime, execution_id: UUID, ) -> None: - self._session = session self._report_schedule = report_schedule self._scheduled_dttm = scheduled_dttm self._start_dttm = datetime.utcnow() @@ -123,7 +119,7 @@ def update_report_schedule(self, state: ReportState) -> None: self._report_schedule.last_state = state self._report_schedule.last_eval_dttm = datetime.utcnow() - self._session.commit() + db.session.commit() def create_log(self, error_message: Optional[str] = None) -> None: """ @@ -140,8 +136,8 @@ def create_log(self, error_message: Optional[str] = None) -> None: report_schedule=self._report_schedule, uuid=self._execution_id, ) - self._session.add(log) - self._session.commit() + db.session.add(log) + db.session.commit() def _get_url( self, @@ -485,9 +481,7 @@ def is_in_grace_period(self) -> bool: """ Checks if an alert is in it's grace period """ - last_success = ReportScheduleDAO.find_last_success_log( - self._report_schedule, session=self._session - ) + last_success = ReportScheduleDAO.find_last_success_log(self._report_schedule) return ( last_success is not None and self._report_schedule.grace_period @@ -501,7 +495,7 @@ def is_in_error_grace_period(self) -> bool: Checks if an alert/report on error is in it's notification grace period """ last_success = 
ReportScheduleDAO.find_last_error_notification( - self._report_schedule, session=self._session + self._report_schedule ) if not last_success: return False @@ -518,7 +512,7 @@ def is_on_working_timeout(self) -> bool: Checks if an alert is in a working timeout """ last_working = ReportScheduleDAO.find_last_entered_working_log( - self._report_schedule, session=self._session + self._report_schedule ) if not last_working: return False @@ -668,12 +662,10 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods def __init__( self, - session: Session, task_uuid: UUID, report_schedule: ReportSchedule, scheduled_dttm: datetime, ): - self._session = session self._execution_id = task_uuid self._report_schedule = report_schedule self._scheduled_dttm = scheduled_dttm @@ -684,7 +676,6 @@ def run(self) -> None: self._report_schedule.last_state in state_cls.current_states ): state_cls( - self._session, self._report_schedule, self._scheduled_dttm, self._execution_id, @@ -708,31 +699,30 @@ def __init__(self, task_id: str, model_id: int, scheduled_dttm: datetime): self._execution_id = UUID(task_id) def run(self) -> None: - with session_scope(nullpool=True) as session: - try: - self.validate(session=session) - if not self._model: - raise ReportScheduleExecuteUnexpectedError() - _, username = get_executor( - executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"], - model=self._model, + try: + self.validate() + if not self._model: + raise ReportScheduleExecuteUnexpectedError() + _, username = get_executor( + executor_types=app.config["ALERT_REPORTS_EXECUTE_AS"], + model=self._model, + ) + user = security_manager.find_user(username) + with override_user(user): + logger.info( + "Running report schedule %s as user %s", + self._execution_id, + username, ) - user = security_manager.find_user(username) - with override_user(user): - logger.info( - "Running report schedule %s as user %s", - self._execution_id, - username, - ) - ReportScheduleStateMachine( - session, 
self._execution_id, self._model, self._scheduled_dttm - ).run() - except CommandException as ex: - raise ex - except Exception as ex: - raise ReportScheduleUnexpectedError(str(ex)) from ex + ReportScheduleStateMachine( + self._execution_id, self._model, self._scheduled_dttm + ).run() + except CommandException as ex: + raise ex + except Exception as ex: + raise ReportScheduleUnexpectedError(str(ex)) from ex - def validate(self, session: Session = None) -> None: + def validate(self) -> None: # Validate/populate model exists logger.info( "session is validated: id %s, executionid: %s", @@ -740,7 +730,7 @@ def validate(self, session: Session = None) -> None: self._execution_id, ) self._model = ( - session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none() + db.session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none() ) if not self._model: raise ReportScheduleNotFoundError() diff --git a/superset/commands/report/log_prune.py b/superset/commands/report/log_prune.py index 3a9883c9f1009..610519ba90803 100644 --- a/superset/commands/report/log_prune.py +++ b/superset/commands/report/log_prune.py @@ -17,12 +17,12 @@ import logging from datetime import datetime, timedelta +from superset import db from superset.commands.base import BaseCommand from superset.commands.report.exceptions import ReportSchedulePruneLogError from superset.daos.exceptions import DAODeleteFailedError from superset.daos.report import ReportScheduleDAO from superset.reports.models import ReportSchedule -from superset.utils.celery import session_scope logger = logging.getLogger(__name__) @@ -36,28 +36,27 @@ def __init__(self, worker_context: bool = True): self._worker_context = worker_context def run(self) -> None: - with session_scope(nullpool=True) as session: - self.validate() - prune_errors = [] - - for report_schedule in session.query(ReportSchedule).all(): - if report_schedule.log_retention is not None: - from_date = datetime.utcnow() - timedelta( - 
days=report_schedule.log_retention + self.validate() + prune_errors = [] + + for report_schedule in db.session.query(ReportSchedule).all(): + if report_schedule.log_retention is not None: + from_date = datetime.utcnow() - timedelta( + days=report_schedule.log_retention + ) + try: + row_count = ReportScheduleDAO.bulk_delete_logs( + report_schedule, from_date, commit=False ) - try: - row_count = ReportScheduleDAO.bulk_delete_logs( - report_schedule, from_date, session=session, commit=False - ) - logger.info( - "Deleted %s logs for report schedule id: %s", - str(row_count), - str(report_schedule.id), - ) - except DAODeleteFailedError as ex: - prune_errors.append(str(ex)) - if prune_errors: - raise ReportSchedulePruneLogError(";".join(prune_errors)) + logger.info( + "Deleted %s logs for report schedule id: %s", + str(row_count), + str(report_schedule.id), + ) + except DAODeleteFailedError as ex: + prune_errors.append(str(ex)) + if prune_errors: + raise ReportSchedulePruneLogError(";".join(prune_errors)) def validate(self) -> None: pass diff --git a/superset/daos/report.py b/superset/daos/report.py index e7470623f00ee..b5db391ec4880 100644 --- a/superset/daos/report.py +++ b/superset/daos/report.py @@ -22,7 +22,6 @@ from typing import Any from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session from superset.daos.base import BaseDAO from superset.daos.exceptions import DAODeleteFailedError @@ -204,27 +203,25 @@ def update( return super().update(item, attributes, commit) @staticmethod - def find_active(session: Session | None = None) -> list[ReportSchedule]: + def find_active() -> list[ReportSchedule]: """ - Find all active reports. If session is passed it will be used instead of the - default `db.session`, this is useful when on a celery worker session context + Find all active reports. 
""" - session = session or db.session return ( - session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all() + db.session.query(ReportSchedule) + .filter(ReportSchedule.active.is_(True)) + .all() ) @staticmethod def find_last_success_log( report_schedule: ReportSchedule, - session: Session | None = None, ) -> ReportExecutionLog | None: """ Finds last success execution log for a given report """ - session = session or db.session return ( - session.query(ReportExecutionLog) + db.session.query(ReportExecutionLog) .filter( ReportExecutionLog.state == ReportState.SUCCESS, ReportExecutionLog.report_schedule == report_schedule, @@ -236,14 +233,12 @@ def find_last_success_log( @staticmethod def find_last_entered_working_log( report_schedule: ReportSchedule, - session: Session | None = None, ) -> ReportExecutionLog | None: """ Finds last success execution log for a given report """ - session = session or db.session return ( - session.query(ReportExecutionLog) + db.session.query(ReportExecutionLog) .filter( ReportExecutionLog.state == ReportState.WORKING, ReportExecutionLog.report_schedule == report_schedule, @@ -256,14 +251,12 @@ def find_last_entered_working_log( @staticmethod def find_last_error_notification( report_schedule: ReportSchedule, - session: Session | None = None, ) -> ReportExecutionLog | None: """ Finds last error email sent """ - session = session or db.session last_error_email_log = ( - session.query(ReportExecutionLog) + db.session.query(ReportExecutionLog) .filter( ReportExecutionLog.error_message == REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER, @@ -276,7 +269,7 @@ def find_last_error_notification( return None # Checks that only errors have occurred since the last email report_from_last_email = ( - session.query(ReportExecutionLog) + db.session.query(ReportExecutionLog) .filter( ReportExecutionLog.state.notin_( [ReportState.ERROR, ReportState.WORKING] @@ -293,13 +286,11 @@ def find_last_error_notification( def bulk_delete_logs( model: 
ReportSchedule, from_date: datetime, - session: Session | None = None, commit: bool = True, ) -> int | None: - session = session or db.session try: row_count = ( - session.query(ReportExecutionLog) + db.session.query(ReportExecutionLog) .filter( ReportExecutionLog.report_schedule == model, ReportExecutionLog.end_dttm < from_date, @@ -307,8 +298,8 @@ def bulk_delete_logs( .delete(synchronize_session="fetch") ) if commit: - session.commit() + db.session.commit() return row_count except SQLAlchemyError as ex: - session.rollback() + db.session.rollback() raise DAODeleteFailedError(str(ex)) from ex diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 6bc8c444d69bb..48e44064acfdf 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -50,7 +50,6 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm import Session from sqlalchemy.sql import literal_column, quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine @@ -1071,7 +1070,7 @@ def convert_dttm( # pylint: disable=unused-argument return None @classmethod - def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Any, query: Query) -> None: """Handle a live cursor between the execute and fetchall calls The flow works without this method doing anything, but it allows @@ -1080,9 +1079,7 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: # TODO: Fix circular import error caused by importing sql_lab.Query @classmethod - def execute_with_cursor( - cls, cursor: Any, sql: str, query: Query, session: Session - ) -> None: + def execute_with_cursor(cls, cursor: Any, sql: str, query: Query) -> None: """ Trigger execution of a query and handle the resulting cursor. 
@@ -1095,7 +1092,7 @@ def execute_with_cursor( logger.debug("Query %d: Running query: %s", query.id, sql) cls.execute(cursor, sql, async_=True) logger.debug("Query %d: Handling cursor", query.id) - cls.handle_cursor(cursor, query, session) + cls.handle_cursor(cursor, query) @classmethod def extract_error_message(cls, ex: Exception) -> str: @@ -1841,7 +1838,7 @@ def get_sqla_column_type( # pylint: disable=unused-argument @classmethod - def prepare_cancel_query(cls, query: Query, session: Session) -> None: + def prepare_cancel_query(cls, query: Query) -> None: """ Some databases may acquire the query cancelation id after the query cancelation request has been received. For those cases, the db engine spec diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index bd303f928d625..9222d55db0171 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -34,9 +34,9 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL -from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select +from superset import db from superset.common.db_query_status import QueryStatus from superset.constants import TimeGrain from superset.databases.utils import make_url_safe @@ -334,7 +334,7 @@ def get_tracking_url_from_logs(cls, log_lines: list[str]) -> str | None: @classmethod def handle_cursor( # pylint: disable=too-many-locals - cls, cursor: Any, query: Query, session: Session + cls, cursor: Any, query: Query ) -> None: """Updates progress information""" # pylint: disable=import-outside-toplevel @@ -353,8 +353,8 @@ def handle_cursor( # pylint: disable=too-many-locals # Queries don't terminate when user clicks the STOP button on SQL LAB. # Refresh session so that the `query.status` modified in stop_query in # views/core.py is reflected here. 
- session.refresh(query) - query = session.query(type(query)).filter_by(id=query_id).one() + db.session.refresh(query) + query = db.session.query(type(query)).filter_by(id=query_id).one() if query.status == QueryStatus.STOPPED: cursor.cancel() break @@ -396,7 +396,7 @@ def handle_cursor( # pylint: disable=too-many-locals logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l) last_log_line = len(log_lines) if needs_commit: - session.commit() + db.session.commit() if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"): logger.warning( "HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead" diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index c10cf679355ea..9e5f728a6f84a 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -23,8 +23,8 @@ from flask import current_app from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.orm import Session +from superset import db from superset.constants import QUERY_EARLY_CANCEL_KEY, TimeGrain from superset.db_engine_specs.base import BaseEngineSpec from superset.models.sql_lab import Query @@ -101,7 +101,7 @@ def execute( raise cls.get_dbapi_mapped_exception(ex) @classmethod - def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Any, query: Query) -> None: """Stop query and updates progress information""" query_id = query.id @@ -113,8 +113,8 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: try: status = cursor.status() while status in unfinished_states: - session.refresh(query) - query = session.query(Query).filter_by(id=query_id).one() + db.session.refresh(query) + query = db.session.query(Query).filter_by(id=query_id).one() # if query cancelation was requested prior to the handle_cursor call, but # the query was still executed # modified in stop_query in 
views / core.py is reflected here. @@ -145,7 +145,7 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: needs_commit = True if needs_commit: - session.commit() + db.session.commit() sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get( cls.engine, 5 ) diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index cb2d1cbcb641b..77c906fe779c8 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -23,7 +23,6 @@ from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.orm import Session with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be installed # Ensure pyocient inherits Superset's logging level @@ -372,13 +371,13 @@ def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: return "DUMMY_VALUE" @classmethod - def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Any, query: Query) -> None: with OcientEngineSpec.query_id_mapping_lock: OcientEngineSpec.query_id_mapping[query.id] = cursor.query_id # Add the query id to the cursor setattr(cursor, "superset_query_id", query.id) - return super().handle_cursor(cursor, query, session) + return super().handle_cursor(cursor, query) @classmethod def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 27e86a7980875..44f8f9668a224 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -39,10 +39,9 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import Row as ResultRow from sqlalchemy.engine.url import URL -from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select -from superset import cache_manager, is_feature_enabled +from superset import cache_manager, db, 
is_feature_enabled from superset.common.db_query_status import QueryStatus from superset.constants import TimeGrain from superset.databases.utils import make_url_safe @@ -1288,11 +1287,11 @@ def get_tracking_url(cls, cursor: Cursor) -> str | None: return None @classmethod - def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Cursor, query: Query) -> None: """Updates progress information""" if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url - session.commit() + db.session.commit() query_id = query.id poll_interval = query.database.connect_args.get( @@ -1308,7 +1307,7 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: # Update the object and wait for the kill signal. stats = polled.get("stats", {}) - query = session.query(type(query)).filter_by(id=query_id).one() + query = db.session.query(type(query)).filter_by(id=query_id).one() if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]: cursor.cancel() break @@ -1332,7 +1331,7 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: ) if progress > query.progress: query.progress = progress - session.commit() + db.session.commit() time.sleep(poll_interval) logger.info("Query %i: Polling the cursor for progress", query_id) polled = cursor.poll() diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 1dc711880d729..6bdeae9d7fcea 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -27,8 +27,8 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError -from sqlalchemy.orm import Session +from superset import db from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec @@ -155,7 +155,7 @@ 
def get_tracking_url(cls, cursor: Cursor) -> str | None: return None @classmethod - def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Cursor, query: Query) -> None: """ Handle a trino client cursor. @@ -172,7 +172,7 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url - session.commit() + db.session.commit() # if query cancelation was requested prior to the handle_cursor call, but # the query was still executed, trigger the actual query cancelation now @@ -183,12 +183,10 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: cancel_query_id=cancel_query_id, ) - super().handle_cursor(cursor=cursor, query=query, session=session) + super().handle_cursor(cursor=cursor, query=query) @classmethod - def execute_with_cursor( - cls, cursor: Cursor, sql: str, query: Query, session: Session - ) -> None: + def execute_with_cursor(cls, cursor: Cursor, sql: str, query: Query) -> None: """ Trigger execution of a query and handle the resulting cursor. 
@@ -225,7 +223,7 @@ def _execute(results: dict[str, Any], event: threading.Event) -> None: time.sleep(0.1) logger.debug("Query %d: Handling cursor", query_id) - cls.handle_cursor(cursor, query, session) + cls.handle_cursor(cursor, query) # Block until the query completes; same behaviour as the client itself logger.debug("Query %d: Waiting for query to complete", query_id) @@ -237,10 +235,10 @@ def _execute(results: dict[str, Any], event: threading.Event) -> None: raise err @classmethod - def prepare_cancel_query(cls, query: Query, session: Session) -> None: + def prepare_cancel_query(cls, query: Query) -> None: if QUERY_CANCEL_KEY not in query.extra: query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True) - session.commit() + db.session.commit() @classmethod def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool: diff --git a/superset/sql_lab.py b/superset/sql_lab.py index efbef6560a366..1029ff402ca3c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -25,10 +25,8 @@ import backoff import msgpack import simplejson as json -from celery import Task from celery.exceptions import SoftTimeLimitExceeded from flask_babel import gettext as __ -from sqlalchemy.orm import Session from superset import ( app, @@ -56,7 +54,6 @@ ) from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer -from superset.utils.celery import session_scope from superset.utils.core import ( json_iso_dttm_ser, override_user, @@ -92,7 +89,6 @@ class SqlLabQueryStoppedException(SqlLabException): def handle_query_error( ex: Exception, query: Query, - session: Session, payload: Optional[dict[str, Any]] = None, prefix_message: str = "", ) -> dict[str, Any]: @@ -120,7 +116,7 @@ def handle_query_error( if errors: query.set_extra_json_key("errors", errors_payload) - session.commit() + db.session.commit() payload.update({"status": query.status, "error": msg, "errors": errors_payload}) if troubleshooting_link := 
config["TROUBLESHOOTING_LINK"]: payload["link"] = troubleshooting_link @@ -150,22 +146,20 @@ def get_query_giveup_handler(_: Any) -> None: on_giveup=get_query_giveup_handler, max_tries=5, ) -def get_query(query_id: int, session: Session) -> Query: +def get_query(query_id: int) -> Query: """attempts to get the query and retry if it cannot""" try: - return session.query(Query).filter_by(id=query_id).one() + return db.session.query(Query).filter_by(id=query_id).one() except Exception as ex: raise SqlLabException("Failed at getting query") from ex @celery_app.task( name="sql_lab.get_sql_results", - bind=True, time_limit=SQLLAB_HARD_TIMEOUT, soft_time_limit=SQLLAB_TIMEOUT, ) def get_sql_results( # pylint: disable=too-many-arguments - ctask: Task, query_id: int, rendered_query: str, return_results: bool = True, @@ -176,30 +170,27 @@ def get_sql_results( # pylint: disable=too-many-arguments log_params: Optional[dict[str, Any]] = None, ) -> Optional[dict[str, Any]]: """Executes the sql query returns the results.""" - with session_scope(not ctask.request.called_directly) as session: - with override_user(security_manager.find_user(username)): - try: - return execute_sql_statements( - query_id, - rendered_query, - return_results, - store_results, - session=session, - start_time=start_time, - expand_data=expand_data, - log_params=log_params, - ) - except Exception as ex: # pylint: disable=broad-except - logger.debug("Query %d: %s", query_id, ex) - stats_logger.incr("error_sqllab_unhandled") - query = get_query(query_id, session) - return handle_query_error(ex, query, session) + with override_user(security_manager.find_user(username)): + try: + return execute_sql_statements( + query_id, + rendered_query, + return_results, + store_results, + start_time=start_time, + expand_data=expand_data, + log_params=log_params, + ) + except Exception as ex: # pylint: disable=broad-except + logger.debug("Query %d: %s", query_id, ex) + stats_logger.incr("error_sqllab_unhandled") + query = 
get_query(query_id) + return handle_query_error(ex, query) -def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-locals +def execute_sql_statement( sql_statement: str, query: Query, - session: Session, cursor: Any, log_params: Optional[dict[str, Any]], apply_ctas: bool = False, @@ -284,9 +275,9 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local security_manager, log_params, ) - session.commit() + db.session.commit() with stats_timing("sqllab.query.time_executing_query", stats_logger): - db_engine_spec.execute_with_cursor(cursor, sql, query, session) + db_engine_spec.execute_with_cursor(cursor, sql, query) with stats_timing("sqllab.query.time_fetching_results", stats_logger): logger.debug( @@ -319,7 +310,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local except Exception as ex: # query is stopped in another thread/worker # stopping raises expected exceptions which we should skip - session.refresh(query) + db.session.refresh(query) if query.status == QueryStatus.STOPPED: raise SqlLabQueryStoppedException() from ex @@ -393,7 +384,6 @@ def execute_sql_statements( rendered_query: str, return_results: bool, store_results: bool, - session: Session, start_time: Optional[float], expand_data: bool, log_params: Optional[dict[str, Any]], @@ -403,7 +393,7 @@ def execute_sql_statements( # only asynchronous queries stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time) - query = get_query(query_id, session) + query = get_query(query_id) payload: dict[str, Any] = {"query_id": query_id} database = query.database db_engine_spec = database.db_engine_spec @@ -432,7 +422,7 @@ def execute_sql_statements( logger.info("Query %s: Set query to 'running'", str(query_id)) query.status = QueryStatus.RUNNING query.start_running_time = now_as_float() - session.commit() + db.session.commit() # Should we create a table or view from the select? 
if ( @@ -476,11 +466,11 @@ def execute_sql_statements( cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) if cancel_query_id is not None: query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id) - session.commit() + db.session.commit() statement_count = len(statements) for i, statement in enumerate(statements): # Check if stopped - session.refresh(query) + db.session.refresh(query) if query.status == QueryStatus.STOPPED: payload.update({"status": query.status}) return payload @@ -497,12 +487,11 @@ def execute_sql_statements( ) logger.info("Query %s: %s", str(query_id), msg) query.set_extra_json_key("progress", msg) - session.commit() + db.session.commit() try: result_set = execute_sql_statement( statement, query, - session, cursor, log_params, apply_ctas, @@ -521,9 +510,7 @@ def execute_sql_statements( if statement_count > 1 else "" ) - payload = handle_query_error( - ex, query, session, payload, prefix_message - ) + payload = handle_query_error(ex, query, payload, prefix_message) return payload # Commit the connection so CTA queries will create the table and any DML. @@ -593,7 +580,7 @@ def execute_sql_statements( query.results_key = key query.status = QueryStatus.SUCCESS - session.commit() + db.session.commit() if return_results: # since we're returning results we need to create non-arrow data @@ -634,7 +621,7 @@ def cancel_query(query: Query) -> bool: return True # Some databases may need to make preparations for query cancellation - query.database.db_engine_spec.prepare_cancel_query(query, db.session) + query.database.db_engine_spec.prepare_cancel_query(query) if query.extra.get(QUERY_EARLY_CANCEL_KEY): # Query has been cancelled prior to being able to set the cancel key. 
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 569797ba27ab0..aeeb86aab3bbe 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -94,14 +94,7 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dummy" def get_payloads(self) -> list[dict[str, int]]: - session = db.create_scoped_session() - - try: - charts = session.query(Slice).all() - finally: - session.close() - - return [get_payload(chart) for chart in charts] + return [get_payload(chart) for chart in db.session.query(Slice).all()] class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -130,28 +123,24 @@ def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None: self.since = parse_human_datetime(since) if since else None def get_payloads(self) -> list[dict[str, int]]: - payloads = [] - session = db.create_scoped_session() + records = ( + db.session.query(Log.dashboard_id, func.count(Log.dashboard_id)) + .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) + .group_by(Log.dashboard_id) + .order_by(func.count(Log.dashboard_id).desc()) + .limit(self.top_n) + .all() + ) + dash_ids = [record.dashboard_id for record in records] + dashboards = ( + db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() + ) - try: - records = ( - session.query(Log.dashboard_id, func.count(Log.dashboard_id)) - .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) - .group_by(Log.dashboard_id) - .order_by(func.count(Log.dashboard_id).desc()) - .limit(self.top_n) - .all() - ) - dash_ids = [record.dashboard_id for record in records] - dashboards = ( - session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() - ) - for dashboard in dashboards: - for chart in dashboard.slices: - payloads.append(get_payload(chart, dashboard)) - finally: - session.close() - return payloads + return [ + get_payload(chart, dashboard) + for dashboard in dashboards + for chart in dashboard.slices + ] class 
DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods @@ -178,48 +167,44 @@ def __init__(self, tags: Optional[list[str]] = None) -> None: def get_payloads(self) -> list[dict[str, int]]: payloads = [] - session = db.create_scoped_session() - - try: - tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all() - tag_ids = [tag.id for tag in tags] - - # add dashboards that are tagged - tagged_objects = ( - session.query(TaggedObject) - .filter( - and_( - TaggedObject.object_type == "dashboard", - TaggedObject.tag_id.in_(tag_ids), - ) + tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all() + tag_ids = [tag.id for tag in tags] + + # add dashboards that are tagged + tagged_objects = ( + db.session.query(TaggedObject) + .filter( + and_( + TaggedObject.object_type == "dashboard", + TaggedObject.tag_id.in_(tag_ids), ) - .all() - ) - dash_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_dashboards = session.query(Dashboard).filter( - Dashboard.id.in_(dash_ids) ) - for dashboard in tagged_dashboards: - for chart in dashboard.slices: - payloads.append(get_payload(chart)) - - # add charts that are tagged - tagged_objects = ( - session.query(TaggedObject) - .filter( - and_( - TaggedObject.object_type == "chart", - TaggedObject.tag_id.in_(tag_ids), - ) + .all() + ) + dash_ids = [tagged_object.object_id for tagged_object in tagged_objects] + tagged_dashboards = db.session.query(Dashboard).filter( + Dashboard.id.in_(dash_ids) + ) + for dashboard in tagged_dashboards: + for chart in dashboard.slices: + payloads.append(get_payload(chart)) + + # add charts that are tagged + tagged_objects = ( + db.session.query(TaggedObject) + .filter( + and_( + TaggedObject.object_type == "chart", + TaggedObject.tag_id.in_(tag_ids), ) - .all() ) - chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids)) - for chart in tagged_charts: - 
payloads.append(get_payload(chart)) - finally: - session.close() + .all() + ) + chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] + tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids)) + for chart in tagged_charts: + payloads.append(get_payload(chart)) + return payloads diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 850709bfb4866..e9ab10c0500b7 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -21,7 +21,7 @@ """ from typing import Any -from celery.signals import worker_process_init +from celery.signals import task_postrun, worker_process_init # Superset framework imports from superset import create_app @@ -43,3 +43,27 @@ def reset_db_connection_pool(**kwargs: Any) -> None: # pylint: disable=unused-a with flask_app.app_context(): # https://docs.sqlalchemy.org/en/14/core/connections.html#engine-disposal db.engine.dispose() + + +@task_postrun.connect +def teardown( # pylint: disable=unused-argument + retval: Any, + *args: Any, + **kwargs: Any, +) -> None: + """ + After each Celery task, tear down the Flask-SQLAlchemy session. + + Note that for eager requests Flask-SQLAlchemy will perform the teardown.
+ + :param retval: The return value of the task + :see: https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun + :see: https://gist.github.com/twolfson/a1b329e9353f9b575131 + """ + + if flask_app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN"): + if not isinstance(retval, Exception): + db.session.commit() + + if not flask_app.config.get("CELERY_ALWAYS_EAGER"): + db.session.remove() diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py index 7b1350a07d3af..cb55dc9f69ad8 100644 --- a/superset/tasks/scheduler.py +++ b/superset/tasks/scheduler.py @@ -29,7 +29,6 @@ from superset.extensions import celery_app from superset.stats_logger import BaseStatsLogger from superset.tasks.cron_util import cron_schedule_window -from superset.utils.celery import session_scope from superset.utils.core import LoggerLevel from superset.utils.log import get_logger_from_status @@ -46,35 +45,32 @@ def scheduler() -> None: if not is_feature_enabled("ALERT_REPORTS"): return - with session_scope(nullpool=True) as session: - active_schedules = ReportScheduleDAO.find_active(session) - triggered_at = ( - datetime.fromisoformat(scheduler.request.expires) - - app.config["CELERY_BEAT_SCHEDULER_EXPIRES"] - if scheduler.request.expires - else datetime.utcnow() - ) - for active_schedule in active_schedules: - for schedule in cron_schedule_window( - triggered_at, active_schedule.crontab, active_schedule.timezone + active_schedules = ReportScheduleDAO.find_active() + triggered_at = ( + datetime.fromisoformat(scheduler.request.expires) + - app.config["CELERY_BEAT_SCHEDULER_EXPIRES"] + if scheduler.request.expires + else datetime.utcnow() + ) + for active_schedule in active_schedules: + for schedule in cron_schedule_window( + triggered_at, active_schedule.crontab, active_schedule.timezone + ): + logger.info("Scheduling alert %s eta: %s", active_schedule.name, schedule) + async_options = {"eta": schedule} + if ( + active_schedule.working_timeout is not None + and 
app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] ): - logger.info( - "Scheduling alert %s eta: %s", active_schedule.name, schedule + async_options["time_limit"] = ( + active_schedule.working_timeout + + app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"] + ) + async_options["soft_time_limit"] = ( + active_schedule.working_timeout + + app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"] ) - async_options = {"eta": schedule} - if ( - active_schedule.working_timeout is not None - and app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] - ): - async_options["time_limit"] = ( - active_schedule.working_timeout - + app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"] - ) - async_options["soft_time_limit"] = ( - active_schedule.working_timeout - + app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"] - ) - execute.apply_async((active_schedule.id,), **async_options) + execute.apply_async((active_schedule.id,), **async_options) @celery_app.task(name="reports.execute", bind=True) diff --git a/superset/utils/celery.py b/superset/utils/celery.py deleted file mode 100644 index 35771791456ce..0000000000000 --- a/superset/utils/celery.py +++ /dev/null @@ -1,59 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import logging -from collections.abc import Iterator -from contextlib import contextmanager - -from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy.pool import NullPool - -from superset import app, db - -logger = logging.getLogger(__name__) - - -# Null pool is used for the celery workers due process forking side effects. -# For more info see: https://github.com/apache/superset/issues/10530 -@contextmanager -def session_scope(nullpool: bool) -> Iterator[Session]: - """Provide a transactional scope around a series of operations.""" - database_uri = app.config["SQLALCHEMY_DATABASE_URI"] - if "sqlite" in database_uri: - logger.warning( - "SQLite Database support for metadata databases will be removed \ - in a future version of Superset." - ) - if nullpool: - engine = create_engine(database_uri, poolclass=NullPool) - session_class = sessionmaker() - session_class.configure(bind=engine) - session = session_class() - else: - session = db.session() - session.commit() # HACK - - try: - yield session - session.commit() - except SQLAlchemyError as ex: - session.rollback() - logger.exception(ex) - raise - finally: - session.close() diff --git a/tests/integration_tests/sql_lab/test_execute_sql_statements.py b/tests/integration_tests/sql_lab/test_execute_sql_statements.py index 48fcfe31f03cb..7a08f35d3ddcb 100644 --- a/tests/integration_tests/sql_lab/test_execute_sql_statements.py +++ b/tests/integration_tests/sql_lab/test_execute_sql_statements.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from superset import app, db +from superset import app from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.sql_lab import Query @@ -29,7 +29,6 @@ def test_non_async_execute(non_async_example_db: Database, example_query: Query) "select 1 as foo;", store_results=False, return_results=True, - session=db.session, start_time=now_as_float(), expand_data=True, log_params=dict(), diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 03c3bdfc235e9..4410f1978260d 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -575,16 +575,22 @@ def test_sql_json_parameter_forbidden(self): ) assert data["errors"][0]["error_type"] == "GENERIC_BACKEND_ERROR" + @mock.patch("superset.sql_lab.db") @mock.patch("superset.sql_lab.get_query") @mock.patch("superset.sql_lab.execute_sql_statement") - def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query): + def test_execute_sql_statements( + self, + mock_execute_sql_statement, + mock_get_query, + mock_db, + ): sql = """ -- comment SET @value = 42; SELECT @value AS foo; -- comment """ - mock_session = mock.MagicMock() + mock_db = mock.MagicMock() mock_query = mock.MagicMock() mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() @@ -599,7 +605,6 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query rendered_query=sql, return_results=True, store_results=False, - session=mock_session, start_time=None, expand_data=False, log_params=None, @@ -609,7 +614,6 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query mock.call( "SET @value = 42", mock_query, - mock_session, mock_cursor, None, False, @@ -617,7 +621,6 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query mock.call( "SELECT @value AS foo", mock_query, - mock_session, mock_cursor, None, False, @@ -637,7 
+640,6 @@ def test_execute_sql_statements_no_results_backend( SELECT @value AS foo; -- comment """ - mock_session = mock.MagicMock() mock_query = mock.MagicMock() mock_query.database.allow_run_async = True mock_cursor = mock.MagicMock() @@ -653,7 +655,6 @@ def test_execute_sql_statements_no_results_backend( rendered_query=sql, return_results=True, store_results=False, - session=mock_session, start_time=None, expand_data=False, log_params=None, @@ -676,10 +677,14 @@ def test_execute_sql_statements_no_results_backend( }, ) + @mock.patch("superset.sql_lab.db") @mock.patch("superset.sql_lab.get_query") @mock.patch("superset.sql_lab.execute_sql_statement") def test_execute_sql_statements_ctas( - self, mock_execute_sql_statement, mock_get_query + self, + mock_execute_sql_statement, + mock_get_query, + mock_db, ): sql = """ -- comment @@ -687,7 +692,7 @@ def test_execute_sql_statements_ctas( SELECT @value AS foo; -- comment """ - mock_session = mock.MagicMock() + mock_db = mock.MagicMock() mock_query = mock.MagicMock() mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() @@ -706,7 +711,6 @@ def test_execute_sql_statements_ctas( rendered_query=sql, return_results=True, store_results=False, - session=mock_session, start_time=None, expand_data=False, log_params=None, @@ -716,7 +720,6 @@ def test_execute_sql_statements_ctas( mock.call( "SET @value = 42", mock_query, - mock_session, mock_cursor, None, False, @@ -724,7 +727,6 @@ def test_execute_sql_statements_ctas( mock.call( "SELECT @value AS foo", mock_query, - mock_session, mock_cursor, None, True, # apply_ctas @@ -740,7 +742,6 @@ def test_execute_sql_statements_ctas( rendered_query=sql, return_results=True, store_results=False, - session=mock_session, start_time=None, expand_data=False, log_params=None, @@ -773,7 +774,6 @@ def test_execute_sql_statements_ctas( rendered_query=sql, return_results=True, store_results=False, - session=mock_session, start_time=None, expand_data=False, log_params=None, diff 
--git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index fd553b241c8f1..df1457dcad7d4 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -352,9 +352,8 @@ def test_prepare_cancel_query( from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query - session_mock = mocker.MagicMock() query = Query(extra_json=json.dumps(initial_extra)) - TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock) + TrinoEngineSpec.prepare_cancel_query(query=query) assert query.extra == final_extra @@ -374,14 +373,13 @@ def test_handle_cursor_early_cancel( cursor_mock = engine_mock.return_value.__enter__.return_value cursor_mock.query_id = query_id - session_mock = mocker.MagicMock() query = Query() if cancel_early: - TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock) + TrinoEngineSpec.prepare_cancel_query(query=query) - TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock) + TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) if cancel_early: assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id @@ -399,7 +397,6 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture): mock_cursor.query_id = None mock_query = mocker.MagicMock() - mock_session = mocker.MagicMock() def _mock_execute(*args, **kwargs): mock_cursor.query_id = query_id @@ -410,7 +407,6 @@ def _mock_execute(*args, **kwargs): cursor=mock_cursor, sql="SELECT 1 FROM foo", query=mock_query, - session=mock_session, ) mock_query.set_extra_json_key.assert_called_once_with( diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 200ee091ec558..82652773727da 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -20,6 +20,7 @@ from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session +from 
superset import db from superset.utils.core import override_user @@ -41,14 +42,12 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: db_engine_spec.is_select_query.return_value = True db_engine_spec.fetch_data.return_value = [(42,)] - session = mocker.MagicMock() cursor = mocker.MagicMock() SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") execute_sql_statement( sql_statement, query, - session=session, cursor=cursor, log_params={}, apply_ctas=False, @@ -56,7 +55,7 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True) db_engine_spec.execute_with_cursor.assert_called_with( - cursor, "SELECT 42 AS answer LIMIT 2", query, session + cursor, "SELECT 42 AS answer LIMIT 2", query ) SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) @@ -83,7 +82,6 @@ def test_execute_sql_statement_with_rls( db_engine_spec.is_select_query.return_value = True db_engine_spec.fetch_data.return_value = [(42,)] - session = mocker.MagicMock() cursor = mocker.MagicMock() SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") mocker.patch( @@ -95,7 +93,6 @@ def test_execute_sql_statement_with_rls( execute_sql_statement( sql_statement, query, - session=session, cursor=cursor, log_params={}, apply_ctas=False, @@ -107,7 +104,7 @@ def test_execute_sql_statement_with_rls( force=True, ) db_engine_spec.execute_with_cursor.assert_called_with( - cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query, session + cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query ) SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) @@ -162,7 +159,6 @@ def test_sql_lab_insert_rls_as_subquery( superset_result_set = execute_sql_statement( sql_statement=query.sql, query=query, - session=session, cursor=cursor, log_params=None, apply_ctas=False, @@ -198,7 +194,6 @@ 
def test_sql_lab_insert_rls_as_subquery( superset_result_set = execute_sql_statement( sql_statement=query.sql, query=query, - session=session, cursor=cursor, log_params=None, apply_ctas=False,