diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index f05bd67ec35ab..c3bdccc7753a8 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -86,9 +86,10 @@ def extra_table_metadata( } if database.has_view_by_name(table_name, schema_name): - metadata["view"] = database.inspector.get_view_definition( - table_name, schema_name - ) + with database.get_inspector_with_context() as inspector: + metadata["view"] = inspector.get_view_definition( + table_name, schema_name + ) return metadata diff --git a/superset/models/core.py b/superset/models/core.py index e76da0dcd5512..f59cd1159b63b 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -563,7 +563,8 @@ def get_df( # pylint: disable=too-many-locals mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) - engine = self._get_sqla_engine(schema) + with self.get_sqla_engine_with_context(schema) as engine: + engine_url = engine.url mutate_after_split = config["MUTATE_AFTER_SPLIT"] sql_query_mutator = config["SQL_QUERY_MUTATOR"] @@ -577,7 +578,7 @@ def needs_conversion(df_series: pd.Series) -> bool: def _log_query(sql: str) -> None: if log_query: log_query( - engine.url, + engine_url, sql, schema, __name__, @@ -624,13 +625,12 @@ def _log_query(sql: str) -> None: return df def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: - engine = self._get_sqla_engine(schema=schema) + with self.get_sqla_engine_with_context(schema) as engine: + sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) - sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) - - # pylint: disable=protected-access - if engine.dialect.identifier_preparer._double_percents: # noqa - sql = sql.replace("%%", "%") + # pylint: disable=protected-access + if engine.dialect.identifier_preparer._double_percents: # noqa + sql = sql.replace("%%", "%") return sql @@ -645,18 +645,18 @@ def select_star( # pylint: disable=too-many-arguments cols: list[ResultSetColumnType] | None = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" - eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) - return self.db_engine_spec.select_star( - self, - table_name, - schema=schema, - engine=eng, - limit=limit, - show_cols=show_cols, - indent=indent, - latest_partition=latest_partition, - cols=cols, - ) + with self.get_sqla_engine_with_context(schema) as engine: + return self.db_engine_spec.select_star( + self, + table_name, + schema=schema, + engine=engine, + limit=limit, + show_cols=show_cols, + indent=indent, + latest_partition=latest_partition, + cols=cols, + ) def apply_limit_to_sql( self, sql: str, limit: int = 1000, force: bool = False @@ -668,11 +668,6 @@ def apply_limit_to_sql( def safe_sqlalchemy_uri(self) -> str: return self.sqlalchemy_uri - @property - def inspector(self) -> Inspector: - engine = self._get_sqla_engine() - return sqla.inspect(engine) - @cache_util.memoized_func( key="db:{self.id}:schema:{schema}:table_list", cache=cache_manager.cache, @@ -955,8 +950,10 @@ def _has_view( return view_name in view_names def has_view(self, view_name: str, schema: str | None = None) -> bool: - engine = self._get_sqla_engine() - return engine.run_callable(self._has_view, engine.dialect, view_name, schema) + with self.get_sqla_engine_with_context(schema) as engine: + return engine.run_callable( + self._has_view, engine.dialect, view_name, schema + ) def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool: return self.has_view(view_name=view_name, schema=schema) diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 8693a888879d4..29a1f7a66afa0 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -120,9 +120,8 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: def quote_f(value: Optional[str]): if not value: return value - return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier( - value - ) + with get_example_database().get_inspector_with_context() as inspector: + return inspector.engine.dialect.identifier_preparer.quote_identifier(value) def cta_result(ctas_method: CtasMethod): diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index da3a28f1ba81b..dc82026986245 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -113,9 +113,10 @@ def get_expected_row_count(self, client_id: str) -> int: def quote_name(self, name: str): if get_main_database().backend in {"presto", "hive"}: - return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier( - name - ) + with get_example_database().get_inspector_with_context() as inspector: # E: Ne + return inspector.engine.dialect.identifier_preparer.quote_identifier( + name + ) return name diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 3a5f7c0a77a1c..5222c1cb34ef1 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -296,7 +296,8 @@ def test_select_star(self): db = get_example_database() table_name = "energy_usage" sql = db.select_star(table_name, show_cols=False, latest_partition=False) - quote = db.inspector.engine.dialect.identifier_preparer.quote_identifier + with db.get_sqla_engine_with_context() as engine: + quote = engine.dialect.identifier_preparer.quote_identifier expected = ( textwrap.dedent( f"""\