Skip to content

Commit

Permalink
feat(trino): add query cancellation (#21035)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored Aug 12, 2022
1 parent 2d1ba46 commit 5113b01
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
42 changes: 36 additions & 6 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING

Expand Down Expand Up @@ -90,7 +92,7 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
@classmethod
def get_table_names(
cls,
database: "Database",
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
Expand All @@ -103,7 +105,7 @@ def get_table_names(
@classmethod
def get_view_names(
cls,
database: "Database",
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
Expand All @@ -114,7 +116,7 @@ def get_view_names(
)

@classmethod
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
def get_tracking_url(cls, cursor: Cursor) -> Optional[str]:
try:
return cursor.info_uri
except AttributeError:
Expand All @@ -127,14 +129,42 @@ def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
return None

@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
"""Updates progress information"""
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
query.tracking_url = tracking_url
session.commit()

# Adds the executed query id to the extra payload so the query can be cancelled
query.set_extra_json_key("cancel_query", cursor.stats["queryId"])

session.commit()
BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)

@classmethod
def has_implicit_cancel(cls) -> bool:
return False

@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
"""
Cancel query in the underlying database.
:param cursor: New cursor instance to the db of the query
:param query: Query instance
:param cancel_query_id: Trino `queryId`
:return: True if query cancelled successfully, False otherwise
"""
try:
cursor.execute(
f"CALL system.runtime.kill_query(query_id => '{cancel_query_id}',"
"message => 'Query cancelled by Superset')"
)
cursor.fetchall() # needed to trigger the call
except Exception: # pylint: disable=broad-except
return False

return True

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from unittest import mock


@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query

query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True


@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query

query = Query()
cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False

0 comments on commit 5113b01

Please sign in to comment.