From 592e7b9798d09463c7301da57ecc6a7fcc52ce25 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Thu, 13 Apr 2023 09:23:16 +1200 Subject: [PATCH] chore(db_engine_specs): Refactor get_index (#23656) (cherry picked from commit b35b5a6e0557b53207940935b34f57c76253bb16) --- superset/db_engine_specs/base.py | 22 +++++ superset/db_engine_specs/bigquery.py | 22 +++++ superset/db_engine_specs/presto.py | 16 +++- superset/models/core.py | 3 +- .../db_engine_specs/base_engine_spec_tests.py | 23 +++++ .../db_engine_specs/bigquery_tests.py | 83 +++++++++++++++---- 6 files changed, 147 insertions(+), 22 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index e816ce1e3f660..20d523b902a05 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -43,6 +43,7 @@ import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin +from deprecation import deprecated from flask import current_app from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ @@ -724,6 +725,7 @@ def get_datatype(cls, type_code: Any) -> Optional[str]: return None @classmethod + @deprecated(deprecated_in="3.0") def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -1089,6 +1091,26 @@ def get_view_names( # pylint: disable=unused-argument views = {re.sub(f"^{schema}\\.", "", view) for view in views} return views + @classmethod + def get_indexes( + cls, + database: Database, # pylint: disable=unused-argument + inspector: Inspector, + table_name: str, + schema: Optional[str], + ) -> List[Dict[str, Any]]: + """ + Get the indexes associated with the specified schema/table. + + :param database: The database to inspect + :param inspector: The SQLAlchemy inspector + :param table_name: The table to inspect + :param schema: The schema to inspect + :returns: The indexes + """ + + return inspector.get_indexes(table_name, schema) + @classmethod def get_table_comment( cls, inspector: Inspector, table_name: str, schema: Optional[str] diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 171dad4732507..5d87a4b8ff21c 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -23,6 +23,7 @@ import pandas as pd from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin +from deprecation import deprecated from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError @@ -259,6 +260,7 @@ def _truncate_label(cls, label: str) -> str: return "_" + md5_sha_from_str(label) @classmethod + @deprecated(deprecated_in="3.0") def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -277,6 +279,26 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any] normalized_idxs.append(ix) return normalized_idxs + @classmethod + def get_indexes( + cls, + database: "Database", + inspector: Inspector, + table_name: str, + schema: Optional[str], + ) -> List[Dict[str, Any]]: + """ + Get the indexes associated with the specified schema/table. + + :param database: The database to inspect + :param inspector: The SQLAlchemy inspector + :param table_name: The table to inspect + :param schema: The schema to inspect + :returns: The indexes + """ + + return cls.normalize_indexes(inspector.get_indexes(table_name, schema)) + @classmethod def extra_table_metadata( cls, database: "Database", table_name: str, schema_name: Optional[str] diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 72931a85b420c..81c071b386f90 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -534,10 +534,18 @@ def latest_partition( ) column_names = indexes[0]["column_names"] - part_fields = [(column_name, True) for column_name in column_names] - sql = cls._partition_query(table_name, database, 1, part_fields) - df = database.get_df(sql, schema) - return column_names, cls._latest_partition_from_df(df) + + return column_names, cls._latest_partition_from_df( + df=database.get_df( + sql=cls._partition_query( + table_name, + database, + limit=1, + order_by=[(column_name, True) for column_name in column_names], + ), + schema=schema, + ) + ) @classmethod def latest_sub_partition( diff --git a/superset/models/core.py b/superset/models/core.py index 9c67a2efa6d2b..a2c9de32c274e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -800,8 +800,7 @@ def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: with self.get_inspector_with_context() as inspector: - indexes = inspector.get_indexes(table_name, schema) - return self.db_engine_spec.normalize_indexes(indexes) + return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 87de98db1c1d2..71ddc36ca8b4a 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -520,3 +520,26 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid): }, ) ] + + +def test_get_indexes(): + indexes = [ + { + "name": "partition", + "column_names": ["a", "b"], + "unique": False, + }, + ] + + inspector = mock.Mock() + inspector.get_indexes = mock.Mock(return_value=indexes) + + assert ( + BaseEngineSpec.get_indexes( + database=mock.Mock(), + inspector=inspector, + table_name="bar", + schema="foo", + ) + == indexes + ) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 574a2b75e32ca..6bac4649e35f5 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -143,27 +143,78 @@ def test_extra_table_metadata(self): ) self.assertEqual(result, expected_result) - def test_normalize_indexes(self): - """ - DB Eng Specs (bigquery): Test extra table metadata - """ - indexes = [{"name": "partition", "column_names": [None], "unique": False}] - normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes) - self.assertEqual(normalized_idx, []) + def test_get_indexes(self): + database = mock.Mock() + inspector = mock.Mock() + schema = "foo" + table_name = "bar" - indexes = [{"name": "partition", "column_names": ["dttm"], "unique": False}] - normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes) - self.assertEqual(normalized_idx, indexes) + inspector.get_indexes = mock.Mock( + return_value=[ + { + "name": "partition", + "column_names": [None], + "unique": False, + } + ] + ) - indexes = [ - {"name": "partition", "column_names": ["dttm", None], "unique": False} + assert ( + BigQueryEngineSpec.get_indexes( + database, + inspector, + table_name, + schema, + ) + == [] + ) + + inspector.get_indexes = mock.Mock( + return_value=[ + { + "name": "partition", + "column_names": ["dttm"], + "unique": False, + } + ] + ) + + assert BigQueryEngineSpec.get_indexes( + database, + inspector, + table_name, + schema, + ) == [ + { + "name": "partition", + "column_names": ["dttm"], + "unique": False, + } ] - normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes) - self.assertEqual( - normalized_idx, - [{"name": "partition", "column_names": ["dttm"], "unique": False}], + + inspector.get_indexes = mock.Mock( + return_value=[ + { + "name": "partition", + "column_names": ["dttm", None], + "unique": False, + } + ] ) + assert BigQueryEngineSpec.get_indexes( + database, + inspector, + table_name, + schema, + ) == [ + { + "name": "partition", + "column_names": ["dttm"], + "unique": False, + } + ] + @mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine") def test_df_to_sql(self, mock_get_engine): """