From c716b6325f20d9f55937a17dd90ea07e54a35357 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Mon, 13 Jun 2022 17:30:13 -0700 Subject: [PATCH] fix: catch some potential errors on dual write (#20351) * catch some potential errors on dual write * fix test for sqlite (cherry picked from commit 5a137820d0fd192fe8466e9448a59e327d13eeb5) --- superset/connectors/sqla/models.py | 42 +++++++---- superset/connectors/sqla/utils.py | 11 ++- tests/integration_tests/datasets/api_tests.py | 6 ++ .../integration_tests/datasets/model_tests.py | 69 +++++++++++++++++++ .../integration_tests/fixtures/datasource.py | 52 +++++++++++++- 5 files changed, 163 insertions(+), 17 deletions(-) create mode 100644 tests/integration_tests/datasets/model_tests.py diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 60eff5e6304ab..3b404743317a0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -66,6 +66,7 @@ ) from sqlalchemy.engine.base import Connection from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session +from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.mapper import Mapper from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table @@ -933,7 +934,8 @@ def mutate_query_from_config(self, sql: str) -> str: if sql_query_mutator: sql = sql_query_mutator( sql, - user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. + # TODO(john-bodley): Deprecate in 3.0. + user_name=get_username(), security_manager=security_manager, database=self.database, ) @@ -2115,7 +2117,7 @@ def get_sl_columns(self) -> List[NewColumn]: ] @staticmethod - def update_table( # pylint: disable=unused-argument + def update_column( # pylint: disable=unused-argument mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn] ) -> None: """ @@ -2130,7 +2132,7 @@ def update_table( # pylint: disable=unused-argument # table is updated. This busts the cache key for all charts that use the table. session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id)) - # if table itself has changed, shadow-writing will happen in `after_udpate` anyway + # if table itself has changed, shadow-writing will happen in `after_update` anyway if target.table not in session.dirty: dataset: NewDataset = ( session.query(NewDataset) @@ -2146,17 +2148,27 @@ def update_table( # pylint: disable=unused-argument # update changed_on timestamp session.execute(update(NewDataset).where(NewDataset.id == dataset.id)) - - # update `Column` model as well - session.add( - target.to_sl_column( - { - target.uuid: session.query(NewColumn) - .filter_by(uuid=target.uuid) - .one_or_none() - } + try: + column = session.query(NewColumn).filter_by(uuid=target.uuid).one() + # update `Column` model as well + session.merge(target.to_sl_column({target.uuid: column})) + except NoResultFound: + logger.warning("No column was found for %s", target) + # see if the column is in cache + column = next( + find_cached_objects_in_session( + session, NewColumn, uuids=[target.uuid] + ), + None, ) - ) + + if not column: + # to be safe, use a different uuid and create a new column + uuid = uuid4() + target.uuid = uuid + column = NewColumn(uuid=uuid) + + session.add(target.to_sl_column({column.uuid: column})) @staticmethod def after_insert( @@ -2441,9 +2453,9 @@ def write_shadow_dataset( sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert) sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete) sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) -sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table) +sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column) sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete) -sa.event.listen(TableColumn, "after_update", SqlaTable.update_table) +sa.event.listen(TableColumn, "after_update", SqlaTable.update_column) sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete) RLSFilterRoles = Table( diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 1786c5bf17169..69a983156eafb 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import logging from contextlib import closing from typing import ( Any, @@ -35,6 +36,7 @@ from sqlalchemy.exc import NoSuchTableError from sqlalchemy.ext.declarative import DeclarativeMeta from sqlalchemy.orm import Session +from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.sql.type_api import TypeEngine from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -191,6 +193,7 @@ def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]: DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta) +logger = logging.getLogger(__name__) def find_cached_objects_in_session( @@ -209,9 +212,15 @@ def find_cached_objects_in_session( if not ids and not uuids: return iter([]) uuids = uuids or [] + try: + items = set(session) + except ObjectDeletedError: + logger.warning("ObjectDeletedError", exc_info=True) + return iter(()) + return ( item # `session` is an iterator of all known items - for item in set(session) + for item in items if isinstance(item, cls) and (item.id in ids if ids else item.uuid in uuids) ) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 87d0da3cad827..b46b69184f6c1 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -33,6 +33,7 @@ DAODeleteFailedError, DAOUpdateFailedError, ) +from superset.datasets.models import Dataset from superset.extensions import db, security_manager from superset.models.core import Database from superset.utils.core import backend, get_example_default_schema @@ -1611,16 +1612,21 @@ def test_import_dataset(self): database = ( db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() ) + shadow_dataset = ( + db.session.query(Dataset).filter_by(uuid=dataset_config["uuid"]).one() + ) assert database.database_name == "imported_database" assert len(database.tables) == 1 dataset = database.tables[0] assert dataset.table_name == "imported_dataset" assert str(dataset.uuid) == dataset_config["uuid"] + assert str(shadow_dataset.uuid) == dataset_config["uuid"] dataset.owners = [] database.owners = [] db.session.delete(dataset) + db.session.delete(shadow_dataset) db.session.delete(database) db.session.commit() diff --git a/tests/integration_tests/datasets/model_tests.py b/tests/integration_tests/datasets/model_tests.py new file mode 100644 index 0000000000000..31abce5494370 --- /dev/null +++ b/tests/integration_tests/datasets/model_tests.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest import mock + +import pytest +from sqlalchemy.orm.exc import NoResultFound + +from superset.connectors.sqla.models import SqlaTable, TableColumn +from superset.extensions import db +from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.fixtures.datasource import load_dataset_with_columns + + +class SqlaTableModelTest(SupersetTestCase): + @pytest.mark.usefixtures("load_dataset_with_columns") + def test_dual_update_column(self) -> None: + """ + Test that when updating a sqla ``TableColumn`` + That the shadow ``Column`` is also updated + """ + dataset = db.session.query(SqlaTable).filter_by(table_name="students").first() + column = dataset.columns[0] + column_name = column.column_name + column.column_name = "new_column_name" + SqlaTable.update_column(None, None, target=column) + + # refetch + dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one() + assert dataset.columns[0].column_name == "new_column_name" + + # reset + column.column_name = column_name + SqlaTable.update_column(None, None, target=column) + + @pytest.mark.usefixtures("load_dataset_with_columns") + @mock.patch("superset.columns.models.Column") + def test_dual_update_column_not_found(self, column_mock) -> None: + """ + Test that when updating a sqla ``TableColumn`` + That the shadow ``Column`` is also updated + """ + dataset = db.session.query(SqlaTable).filter_by(table_name="students").first() + column = dataset.columns[0] + column_uuid = column.uuid + with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound): + SqlaTable.update_column(None, None, target=column) + + # refetch + dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one() + # it should create a new uuid + assert dataset.columns[0].uuid != column_uuid + + # reset + column.uuid = column_uuid + SqlaTable.update_column(None, None, target=column) diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index b6f2476f662c1..574f43d52bbcb 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -15,10 +15,20 @@ # specific language governing permissions and limitations # under the License. """Fixtures for test_datasource.py""" -from typing import Any, Dict +from typing import Any, Dict, Generator +import pytest +from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table +from sqlalchemy.ext.declarative.api import declarative_base + +from superset.columns.models import Column as Sl_Column +from superset.connectors.sqla.models import SqlaTable, TableColumn +from superset.extensions import db +from superset.models.core import Database +from superset.tables.models import Table as Sl_Table from superset.utils.core import get_example_default_schema from superset.utils.database import get_example_database +from tests.integration_tests.test_app import app def get_datasource_post() -> Dict[str, Any]: @@ -159,3 +169,43 @@ def get_datasource_post() -> Dict[str, Any]: }, ], } + + +@pytest.fixture() +def load_dataset_with_columns() -> Generator[SqlaTable, None, None]: + with app.app_context(): + engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) + meta = MetaData() + session = db.session + + students = Table( + "students", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(255)), + Column("lastname", String(255)), + Column("ds", Date), + ) + meta.create_all(engine) + + students.insert().values(name="George", ds="2021-01-01") + + dataset = SqlaTable( + database_id=db.session.query(Database).first().id, table_name="students" + ) + column = TableColumn(table_id=dataset.id, column_name="name") + dataset.columns = [column] + session.add(dataset) + session.commit() + yield dataset + + # cleanup + students_table = meta.tables.get("students") + if students_table is not None: + base = declarative_base() + # needed for sqlite + session.commit() + base.metadata.drop_all(engine, [students_table], checkfirst=True) + session.delete(dataset) + session.delete(column) + session.commit()