From 4919b58685887a8f2f3fd8a3ff2246bbd2af05f0 Mon Sep 17 00:00:00 2001 From: Rob Moore Date: Fri, 28 Jul 2023 15:23:46 +0100 Subject: [PATCH] fix(sqllab): Force trino client async execution We are currently unable to stop trino queries, because the underlying trino client blocks until the query completes, and doesn't make the query ID or any other info available in the meantime. Unfortunately it doesn't look like they plan to change that any time soon, either. Make the following changes: - Add a new method execute_with_cursor to db_engine_spec which combines execute with handle_cursor, factoring it out of the one place it's used, deep in the execute query logic. - Make handle_cursor poll the cursor for query ID, as it's going to be populated asynchronously. Add warnings that the trino impl will require using execute_with_cursor. Currently nothing is directly calling handle_cursor, with the one original call eliminated. - Override execute_with_cursor for the trino engine and execute the two tasks in parallel to allow us to poll for the query ID while the query is still blocking. --- superset/db_engine_specs/base.py | 18 +++++ superset/db_engine_specs/trino.py | 66 +++++++++++++++++-- superset/sql_lab.py | 7 +- .../unit_tests/db_engine_specs/test_trino.py | 31 ++++++++- tests/unit_tests/sql_lab_test.py | 10 ++- 5 files changed, 114 insertions(+), 18 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 2e1d5598ff35f..76b1ee14ab270 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1026,6 +1026,24 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: query object""" # TODO: Fix circular import error caused by importing sql_lab.Query + @classmethod + def execute_with_cursor( + cls, cursor: Any, sql: str, query: Query, session: Session + ) -> None: + """ + Trigger execution of a query and handle the resulting cursor. + + For most implementations this just makes calls to `execute` and + `handle_cursor` consecutively, but in some engines (e.g. Trino) we may + need to handle client limitations such as lack of async support and + perform a more complicated operation to get information from the cursor + in a timely manner and facilitate operations such as query stop + """ + logger.debug("Query %d: Running query: %s", query.id, sql) + cls.execute(cursor, sql, async_=True) + logger.debug("Query %d: Handling cursor", query.id) + cls.handle_cursor(cursor, query, session) + @classmethod def extract_error_message(cls, ex: Exception) -> str: return f"{cls.engine} error: {cls._extract_error_message(ex)}" diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index f05bd67ec35ab..b38a1e1f5b108 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -17,6 +17,8 @@ from __future__ import annotations import logging +import threading +import time from typing import Any, TYPE_CHECKING import simplejson as json @@ -149,14 +151,21 @@ def get_tracking_url(cls, cursor: Cursor) -> str | None: @classmethod def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: - if tracking_url := cls.get_tracking_url(cursor): - query.tracking_url = tracking_url + """ + Handle a trino client cursor. + + WARNING: if you execute a query, it will block until complete and you + will not be able to handle the cursor until complete. Use + `execute_with_cursor` instead, to handle this asynchronously. + """ # Adds the executed query id to the extra payload so the query can be cancelled - query.set_extra_json_key( - key=QUERY_CANCEL_KEY, - value=(cancel_query_id := cursor.stats["queryId"]), - ) + cancel_query_id = cursor.query_id + logger.debug("Query %d: queryId %s found in cursor", query.id, cancel_query_id) + query.set_extra_json_key(key=QUERY_CANCEL_KEY, value=cancel_query_id) + + if tracking_url := cls.get_tracking_url(cursor): + query.tracking_url = tracking_url session.commit() @@ -171,6 +180,51 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: super().handle_cursor(cursor=cursor, query=query, session=session) + @classmethod + def execute_with_cursor( + cls, cursor: Any, sql: str, query: Query, session: Session + ) -> None: + """ + Trigger execution of a query and handle the resulting cursor. + + Trino's client blocks until the query is complete, so we need to run it + in another thread and invoke `handle_cursor` to poll for the query ID + to appear on the cursor in parallel. + """ + execute_result: dict[str, Any] = {} + + def _execute(results: dict[str, Any]) -> None: + logger.debug("Query %d: Running query: %s", query.id, sql) + + # Pass result / exception information back to the parent thread + try: + cls.execute(cursor, sql) + results["complete"] = True + except Exception as ex: # pylint: disable=broad-except + results["complete"] = True + results["error"] = ex + + execute_thread = threading.Thread(target=_execute, args=(execute_result,)) + execute_thread.start() + + # Wait for a query ID to be available before handling the cursor, as + # it's required by that method; it may never become available on error. + while not cursor.query_id and not execute_result.get("complete"): + time.sleep(0.1) + + logger.debug("Query %d: Handling cursor", query.id) + cls.handle_cursor(cursor, query, session) + + # Block until the query completes; same behaviour as the client itself + logger.debug("Query %d: Waiting for query to complete", query.id) + while not execute_result.get("complete"): + time.sleep(0.5) + + # Unfortunately we'll mangle the stack trace due to the thread, but + # throwing the original exception allows mapping database errors as normal + if err := execute_result.get("error"): + raise err + @classmethod def prepare_cancel_query(cls, query: Query, session: Session) -> None: if QUERY_CANCEL_KEY not in query.extra: diff --git a/superset/sql_lab.py b/superset/sql_lab.py index cb8da3ef0444c..82e8f941d9d6a 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -191,7 +191,7 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query, session) -def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statements +def execute_sql_statement( # pylint: disable=too-many-arguments sql_statement: str, query: Query, session: Session, @@ -270,10 +270,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem ) session.commit() with stats_timing("sqllab.query.time_executing_query", stats_logger): - logger.debug("Query %d: Running query: %s", query.id, sql) - db_engine_spec.execute(cursor, sql, async_=True) - logger.debug("Query %d: Handling cursor", query.id) - db_engine_spec.handle_cursor(cursor, query, session) + db_engine_spec.execute_with_cursor(cursor, sql, query, session) with stats_timing("sqllab.query.time_fetching_results", stats_logger): logger.debug( diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 963953d18b48e..1b50a683a0841 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -352,7 +352,7 @@ def test_handle_cursor_early_cancel( query_id = "myQueryId" cursor_mock = engine_mock.return_value.__enter__.return_value - cursor_mock.stats = {"queryId": query_id} + cursor_mock.query_id = query_id session_mock = mocker.MagicMock() query = Query() @@ -366,3 +366,32 @@ def test_handle_cursor_early_cancel( assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id else: assert cancel_query_mock.call_args is None + + +def test_execute_with_cursor_in_parallel(mocker: MockerFixture): + """Test that `execute_with_cursor` fetches query ID from the cursor""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + query_id = "myQueryId" + + mock_cursor = mocker.MagicMock() + mock_cursor.query_id = None + + mock_query = mocker.MagicMock() + mock_session = mocker.MagicMock() + + def _mock_execute(*args, **kwargs): + mock_cursor.query_id = query_id + + mock_cursor.execute.side_effect = _mock_execute + + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + session=mock_session, + ) + + mock_query.set_extra_json_key.assert_called_once_with( + key=QUERY_CANCEL_KEY, value=query_id + ) diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 29f45eab682a0..edc1fd2ec4a5d 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -55,8 +55,8 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: ) database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True) - db_engine_spec.execute.assert_called_with( - cursor, "SELECT 42 AS answer LIMIT 2", async_=True + db_engine_spec.execute_with_cursor.assert_called_with( + cursor, "SELECT 42 AS answer LIMIT 2", query, session ) SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) @@ -106,10 +106,8 @@ def test_execute_sql_statement_with_rls( 101, force=True, ) - db_engine_spec.execute.assert_called_with( - cursor, - "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", - async_=True, + db_engine_spec.execute_with_cursor.assert_called_with( + cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query, session ) SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)