Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#13378): Ensure g.user is set for impersonation #13878

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
import logging
from typing import Any, cast, Dict, Optional

from flask import current_app
from flask import current_app, g

from superset import app
from superset.exceptions import SupersetVizException
from superset.extensions import async_query_manager, cache_manager, celery_app
from superset.extensions import (
async_query_manager,
cache_manager,
celery_app,
security_manager,
)
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.views.utils import get_datasource_info, get_viz

Expand All @@ -32,6 +37,12 @@
] # TODO: new config key


def ensure_user_is_set(user_id: Optional[int]) -> None:
user_is_set = hasattr(g, "user") and g.user is not None
if not user_is_set and user_id is not None:
g.user = security_manager.get_user_by_id(user_id)


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
Expand All @@ -42,6 +53,7 @@ def load_chart_data_into_cache(

with app.app_context(): # type: ignore
try:
ensure_user_is_set(job_metadata.get("user_id"))
command = ChartDataCommand()
command.set_query_context(form_data)
result = command.run(cache=True)
Expand Down Expand Up @@ -72,6 +84,7 @@ def load_explore_json_into_cache(
with app.app_context(): # type: ignore
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
ensure_user_is_set(job_metadata.get("user_id"))
datasource_id, datasource_type = get_datasource_info(None, None, form_data)

viz_obj = get_viz(
Expand Down
53 changes: 39 additions & 14 deletions tests/tasks/async_queries_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager
from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
Expand All @@ -48,17 +49,24 @@ class TestAsyncQueries(SupersetTestCase):
def test_load_chart_data_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}

load_chart_data_into_cache(job_metadata, query_context)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)

mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
ensure_user_is_set.assert_called_once_with(user.id)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)

@mock.patch.object(
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
Expand All @@ -67,25 +75,31 @@ def test_load_chart_data_into_cache(self, mock_update_job):
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}
with pytest.raises(ChartDataQueryFailedError):
load_chart_data_into_cache(job_metadata, query_context)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)
ensure_user_is_set.assert_called_once_with(user.id)

mock_run_command.assert_called_with(cache=True)
mock_run_command.assert_called_once_with(cache=True)
errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
user = security_manager.find_user("gamma")
form_data = {
"datasource": f"{table.id}__table",
"viz_type": "dist_bar",
Expand All @@ -100,29 +114,40 @@ def test_load_explore_json_into_cache(self, mock_update_job):
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}

load_explore_json_into_cache(job_metadata, form_data)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)

mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
ensure_user_is_set.assert_called_once_with(user.id)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)

@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache_error(self, mock_update_job):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"user_id": user.id,
"status": "pending",
"errors": [],
}

with pytest.raises(SupersetException):
load_explore_json_into_cache(job_metadata, form_data)
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)

errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)