From b0d08216d6c36f2a266c2df217e9872a1e539fcd Mon Sep 17 00:00:00 2001 From: Diego Medina Date: Mon, 30 Jan 2023 19:43:39 -0300 Subject: [PATCH 1/3] chore: Migrate /superset/csv/ to API v1 --- docs/static/resources/openapi.json | 271 ++++++++++++++---- .../src/SqlLab/components/ResultSet/index.tsx | 11 +- superset/sqllab/api.py | 73 ++++- superset/sqllab/commands/export.py | 130 +++++++++ superset/sqllab/schemas.py | 8 + superset/views/core.py | 1 + tests/integration_tests/sql_lab/api_tests.py | 41 ++- .../sql_lab/commands_tests.py | 270 +++++++++++++++-- 8 files changed, 729 insertions(+), 76 deletions(-) create mode 100644 superset/sqllab/commands/export.py diff --git a/docs/static/resources/openapi.json b/docs/static/resources/openapi.json index 8077af91c1906..d7aecdc4c8b27 100644 --- a/docs/static/resources/openapi.json +++ b/docs/static/resources/openapi.json @@ -345,7 +345,7 @@ "AnnotationLayerRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User" + "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1" }, "changed_on": { "format": "date-time", @@ -356,7 +356,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1" + "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User" }, "created_on": { "format": "date-time", @@ -502,13 +502,13 @@ "AnnotationRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/AnnotationRestApi.get_list.User" + "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1" }, "changed_on_delta_humanized": { "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1" + "$ref": "#/components/schemas/AnnotationRestApi.get_list.User" }, "end_dttm": { "format": "date-time", @@ -1783,7 +1783,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/ChartDataRestApi.get_list.User3" + "$ref": "#/components/schemas/ChartDataRestApi.get_list.User2" }, "created_on_delta_humanized": { "readOnly": true @@ -1833,7 +1833,7 @@ "$ref": "#/components/schemas/ChartDataRestApi.get_list.User" }, "owners": { - "$ref": "#/components/schemas/ChartDataRestApi.get_list.User2" + "$ref": "#/components/schemas/ChartDataRestApi.get_list.User3" }, "params": { "nullable": true, @@ -1942,16 +1942,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -1968,11 +1963,16 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -2575,7 +2575,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/ChartRestApi.get_list.User3" + "$ref": "#/components/schemas/ChartRestApi.get_list.User2" }, "created_on_delta_humanized": { "readOnly": true @@ -2625,7 +2625,7 @@ "$ref": "#/components/schemas/ChartRestApi.get_list.User" }, "owners": { - "$ref": "#/components/schemas/ChartRestApi.get_list.User2" + "$ref": "#/components/schemas/ChartRestApi.get_list.User3" }, "params": { "nullable": true, @@ -2734,16 +2734,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -2760,11 +2755,16 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -3027,13 +3027,13 @@ "CssTemplateRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User" + "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1" }, "changed_on_delta_humanized": { "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1" + "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User" }, "created_on": { "format": "date-time", @@ -3415,7 +3415,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/DashboardRestApi.get_list.User2" + "$ref": "#/components/schemas/DashboardRestApi.get_list.User1" }, "created_on_delta_humanized": { "readOnly": true @@ -3441,7 +3441,7 @@ "type": "string" }, "owners": { - "$ref": "#/components/schemas/DashboardRestApi.get_list.User1" + "$ref": "#/components/schemas/DashboardRestApi.get_list.User2" }, "position_json": { "nullable": true, @@ -3515,10 +3515,6 @@ }, "DashboardRestApi.get_list.User1": { "properties": { - "email": { - "maxLength": 64, - "type": "string" - }, "first_name": { "maxLength": 64, "type": "string" @@ -3530,22 +3526,20 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ - "email", "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, "DashboardRestApi.get_list.User2": { "properties": { + "email": { + "maxLength": 64, + "type": "string" + }, "first_name": { "maxLength": 64, "type": "string" @@ -3557,11 +3551,17 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ + "email", "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -4895,7 +4895,7 @@ "$ref": "#/components/schemas/DatasetRestApi.get.TableColumn" }, "created_by": { - "$ref": "#/components/schemas/DatasetRestApi.get.User2" + "$ref": "#/components/schemas/DatasetRestApi.get.User1" }, "created_on": { "format": "date-time", @@ -4959,7 +4959,7 @@ "type": "integer" }, "owners": { - "$ref": "#/components/schemas/DatasetRestApi.get.User1" + "$ref": "#/components/schemas/DatasetRestApi.get.User2" }, "schema": { "maxLength": 255, @@ -5173,23 +5173,14 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -5199,14 +5190,23 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -6949,7 +6949,7 @@ "type": "integer" }, "created_by": { - "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User2" + "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User1" }, "created_on": { "format": "date-time", @@ -6999,7 +6999,7 @@ "type": "string" }, "owners": { - "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User1" + "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User2" }, "recipients": { "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.ReportRecipients" @@ -7060,10 +7060,6 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, "last_name": { "maxLength": 64, "type": "string" @@ -7081,6 +7077,10 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, "last_name": { "maxLength": 64, "type": "string" @@ -9507,6 +9507,17 @@ }, "type": "object" }, + "sql_lab_export_csv_schema": { + "properties": { + "client_id": { + "type": "string" + } + }, + "required": [ + "client_id" + ], + "type": "object" + }, "sql_lab_get_results_schema": { "properties": { "key": { @@ -16686,6 +16697,99 @@ ] } }, + "/api/v1/datasource/{datasource_type}/{datasource_id}/column/{column_name}/values/": { + "get": { + "parameters": [ + { + "description": "The type of datasource", + "in": "path", + "name": "datasource_type", + "required": true, + "schema": { + "type": "string" + } + }, + { + "description": "The id of the datasource", + "in": "path", + "name": "datasource_id", + "required": true, + "schema": { + "type": "integer" + } + }, + { + "description": "The name of the column to get values for", + "in": "path", + "name": "column_name", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "properties": { + "result": { + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "object" + } + ] + }, + "type": "array" + } + }, + "type": "object" + } + } + }, + "description": "A List of distinct values for the column" + }, + "400": { + "$ref": "#/components/responses/400" + }, + "401": { + "$ref": "#/components/responses/401" + }, + "403": { + "$ref": "#/components/responses/403" + }, + "404": { + "$ref": "#/components/responses/404" + }, + "500": { + "$ref": "#/components/responses/500" + } + }, + "security": [ + { + "jwt": [] + } + ], + "summary": "Get possible values for a datasource column", + "tags": [ + "Datasources" + ] + } + }, "/api/v1/embedded_dashboard/{uuid}": { "get": { "description": "Get a report schedule log", @@ -19799,6 +19903,59 @@ ] } }, + "/api/v1/sqllab/export/": { + "get": { + "parameters": [ + { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/sql_lab_export_csv_schema" + } + } + }, + "in": "query", + "name": "q" + } + ], + "responses": { + "200": { + "content": { + "text/csv": { + "schema": { + "type": "string" + } + } + }, + "description": "SQL query results" + }, + "400": { + "$ref": "#/components/responses/400" + }, + "401": { + "$ref": "#/components/responses/401" + }, + "403": { + "$ref": "#/components/responses/403" + }, + "404": { + "$ref": "#/components/responses/404" + }, + "500": { + "$ref": "#/components/responses/500" + } + }, + "security": [ + { + "jwt": [] + } + ], + "summary": "Exports the SQL Query results to a CSV", + "tags": [ + "SQL Lab" + ] + } + }, "/api/v1/sqllab/results/": { "get": { "parameters": [ diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx index 81a4e47a11368..47f8d4acdd56a 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx @@ -17,6 +17,7 @@ * under the License. */ import React, { useCallback, useEffect, useState } from 'react'; +import rison from 'rison'; import { useDispatch } from 'react-redux'; import ButtonGroup from 'src/components/ButtonGroup'; import Alert from 'src/components/Alert'; @@ -219,6 +220,14 @@ const ResultSet = ({ } }; + const getExportCsvUrl = (clientId: string) => { + const params = rison.encode({ + client_id: clientId, + }); + + return `/api/v1/sqllab/export/?q=${params}`; + }; + const renderControls = () => { if (search || visualize || csv) { let { data } = query.results; @@ -257,7 +266,7 @@ const ResultSet = ({ /> )} {csv && ( - )} diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index 283c3ab638707..df6acbfd267d2 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -16,6 +16,7 @@ # under the License. import logging from typing import Any, cast, Dict, Optional +from urllib import parse import simplejson as json from flask import request @@ -32,6 +33,7 @@ from superset.sql_lab import get_sql_results from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.sqllab.commands.execute import CommandResult, ExecuteSqlCommand +from superset.sqllab.commands.export import SqlResultExportCommand from superset.sqllab.commands.results import SqlExecutionResultsCommand from superset.sqllab.exceptions import ( QueryIsForbiddenToAccessException, @@ -42,6 +44,7 @@ from superset.sqllab.schemas import ( ExecutePayloadSchema, QueryExecutionResponseSchema, + sql_lab_export_csv_schema, sql_lab_get_results_schema, ) from superset.sqllab.sql_json_executer import ( @@ -53,7 +56,7 @@ from superset.sqllab.validators import CanAccessQueryValidatorImpl from superset.superset_typing import FlaskResponse from superset.utils import core as utils -from superset.views.base import json_success +from superset.views.base import CsvResponse, generate_download_headers, json_success from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics config = app.config @@ -72,6 +75,7 @@ class SqlLabRestApi(BaseSupersetApi): apispec_parameter_schemas = { "sql_lab_get_results_schema": sql_lab_get_results_schema, + "sql_lab_export_csv_schema": sql_lab_export_csv_schema, } openapi_spec_tag = "SQL Lab" openapi_spec_component_schemas = ( @@ -79,6 +83,73 @@ class SqlLabRestApi(BaseSupersetApi): QueryExecutionResponseSchema, ) + @expose("/export/") + @protect() + @statsd_metrics + @rison(sql_lab_export_csv_schema) + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".export_csv", + log_to_statsd=False, + ) + def export_csv(self, **kwargs: Any) -> CsvResponse: + """Exports the SQL Query results to a CSV + --- + get: + summary: >- + Exports the SQL Query results to a CSV + parameters: + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/sql_lab_export_csv_schema' + responses: + 200: + description: SQL query results + content: + text/csv: + schema: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + params = kwargs["rison"] + client_id = params.get("client_id") + result = SqlResultExportCommand(client_id=client_id).run() + + query = result.get("query") + data = result.get("data") + row_count = result.get("row_count") + + quoted_csv_name = parse.quote(query.name) + response = CsvResponse( + data, headers=generate_download_headers("csv", quoted_csv_name) + ) + event_info = { + "event_type": "data_export", + "client_id": client_id, + "row_count": row_count, + "database": query.database.name, + "schema": query.schema, + "sql": query.sql, + "exported_format": "csv", + } + event_rep = repr(event_info) + logger.debug( + "CSV exported: %s", event_rep, extra={"superset_event": event_info} + ) + return response + @expose("/results/") @protect() @statsd_metrics diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py new file mode 100644 index 0000000000000..1c189d674e474 --- /dev/null +++ b/superset/sqllab/commands/export.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-few-public-methods, too-many-arguments +from __future__ import annotations + +import logging +from typing import Any, cast, Dict + +import pandas as pd +from flask_babel import gettext as __, lazy_gettext as _ + +from superset import app, db, results_backend, results_backend_use_msgpack +from superset.commands.base import BaseCommand +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetErrorException, SupersetSecurityException +from superset.models.sql_lab import Query +from superset.sql_parse import ParsedQuery +from superset.sqllab.limiting_factor import LimitingFactor +from superset.utils import core as utils, csv +from superset.utils.dates import now_as_float +from superset.views.utils import _deserialize_results_payload + +config = app.config + +logger = logging.getLogger(__name__) + + +class SqlResultExportCommand(BaseCommand): + _client_id: str + _query: Query + + def __init__( + self, + client_id: str, + ) -> None: + self._client_id = client_id + + def validate(self) -> None: + self._query = ( + db.session.query(Query).filter_by(client_id=self._client_id).one_or_none() + ) + if self._query is None: + raise SupersetErrorException( + SupersetError( + message=__( + "The query associated with these results could not be found. " + "You need to re-run the original query." + ), + error_type=SupersetErrorType.RESULTS_BACKEND_ERROR, + level=ErrorLevel.ERROR, + ), + status=404, + ) + + try: + self._query.raise_for_access() + except SupersetSecurityException: + raise SupersetErrorException( + SupersetError( + message=__("Cannot access the query"), + error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, + level=ErrorLevel.ERROR, + ), + status=403, + ) + + def run( + self, + ) -> Dict[str, Any]: + self.validate() + blob = None + if results_backend and self._query.results_key: + logger.info( + "Fetching CSV from results backend [%s]", self._query.results_key + ) + blob = results_backend.get(self._query.results_key) + if blob: + logger.info("Decompressing") + payload = utils.zlib_decompress( + blob, decode=not results_backend_use_msgpack + ) + obj = _deserialize_results_payload( + payload, self._query, cast(bool, results_backend_use_msgpack) + ) + + df = pd.DataFrame( + data=obj["data"], + dtype=object, + columns=[c["name"] for c in obj["columns"]], + ) + + logger.info("Using pandas to convert to CSV") + else: + logger.info("Running a query to turn into CSV") + if self._query.select_sql: + sql = self._query.select_sql + limit = None + else: + sql = self._query.executed_sql + limit = ParsedQuery(sql).limit + if limit is not None and self._query.limiting_factor in { + LimitingFactor.QUERY, + LimitingFactor.DROPDOWN, + LimitingFactor.QUERY_AND_DROPDOWN, + }: + # remove extra row from `increased_limit` + limit -= 1 + df = self._query.database.get_df(sql, self._query.schema)[:limit] + + csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"]) + + return { + "query": self._query, + "count": len(df.index), + "data": csv_data, + } diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index f238fda5c918f..428cdb89bb3e3 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -24,6 +24,14 @@ "required": ["key"], } +sql_lab_export_csv_schema = { + "type": "object", + "properties": { + "client_id": {"type": "string"}, + }, + "required": ["client_id"], +} + class ExecutePayloadSchema(Schema): database_id = fields.Integer(required=True) diff --git a/superset/views/core.py b/superset/views/core.py index 8d632dcde21bf..283ea5df0de72 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2392,6 +2392,7 @@ def _create_response_from_execution_context( # pylint: disable=invalid-name, no @has_access @event_logger.log_this @expose("/csv/") + @deprecated() def csv(self, client_id: str) -> FlaskResponse: # pylint: disable=no-self-use """Download the query results as csv.""" logger.info("Exporting CSV file [%s]", client_id) diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 4c2080ad4cc2f..52668593213b2 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -19,6 +19,9 @@ import datetime import json import random +import csv +import pandas as pd +import io import pytest import prison @@ -26,7 +29,7 @@ from unittest import mock from tests.integration_tests.test_app import app -from superset import sql_lab +from superset import db, sql_lab from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.utils.database import get_example_database, get_main_database @@ -176,3 +179,39 @@ def test_get_results_with_display_limit(self): self.assertEqual(result_limited, expected_limited) app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack + + @mock.patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @mock.patch("superset.models.core.Database.get_df") + def test_export_results(self, get_df_mock: mock.Mock) -> None: + self.login() + + database = Database( + database_name="my_export_database", sqlalchemy_uri="sqlite://" + ) + query_obj = Query( + client_id="test", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql=None, + executed_sql="select * from bar limit 2", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="test_abc2", + ) + + db.session.add(database) + db.session.add(query_obj) + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + arguments = {"client_id": "test"} + resp = self.get_resp(f"/api/v1/sqllab/export/?q={prison.dumps(arguments)}") + data = csv.reader(io.StringIO(resp)) + expected_data = csv.reader(io.StringIO(f"foo\n1\n2")) + + self.assertEqual(list(expected_data), list(data)) + db.session.rollback() diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 74c1fe7082103..4e2e6642a66d7 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -15,23 +15,259 @@ # specific language governing permissions and limitations # under the License. from unittest import mock, skip -from unittest.mock import patch +from unittest.mock import Mock, patch +import pandas as pd import pytest from superset import db, sql_lab from superset.common.db_query_status import QueryStatus -from superset.errors import SupersetErrorType -from superset.exceptions import SerializationError, SupersetErrorException +from superset.errors import ErrorLevel, SupersetErrorType +from superset.exceptions import ( + SerializationError, + SupersetError, + SupersetErrorException, + SupersetSecurityException, +) from superset.models.core import Database from superset.models.sql_lab import Query -from superset.sqllab.commands import results +from superset.sqllab.commands import export, results +from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import core as utils from tests.integration_tests.base_tests import SupersetTestCase +class TestSqlResultExportCommand(SupersetTestCase): + def test_validation_query_not_found(self) -> None: + command = export.SqlResultExportCommand("asdf") + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test1", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc1", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + with pytest.raises(SupersetErrorException) as ex_info: + command.run() + assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR + + def test_validation_invalid_access(self) -> None: + command = export.SqlResultExportCommand("test2") + + database = Database(database_name="my_database2", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test2", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc2", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + with mock.patch( + "superset.security_manager.raise_for_access", + side_effect=SupersetSecurityException( + SupersetError( + "dummy", + SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + ErrorLevel.ERROR, + ) + ), + ): + with pytest.raises(SupersetErrorException) as ex_info: + command.run() + assert ( + ex_info.value.error.error_type + == SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR + ) + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.models.core.Database.get_df") + def test_run_no_results_backend_select_sql(self, get_df_mock: Mock) -> None: + command = export.SqlResultExportCommand("test3") + + database = Database(database_name="my_database3", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test3", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc3", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n1\n2\n3\n" + assert count == 3 + assert query.client_id == "test3" + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.models.core.Database.get_df") + def test_run_no_results_backend_executed_sql(self, get_df_mock: Mock) -> None: + command = export.SqlResultExportCommand("test4") + + database = Database(database_name="my_database4", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test4", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql=None, + executed_sql="select * from bar limit 2", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc4", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n1\n2\n" + assert count == 2 + assert query.client_id == "test4" + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.models.core.Database.get_df") + def test_run_no_results_backend_executed_sql_limiting_factor( + self, get_df_mock: Mock + ) -> None: + command = export.SqlResultExportCommand("test5") + + database = Database(database_name="my_database5", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test5", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql=None, + executed_sql="select * from bar limit 2", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc5", + limiting_factor=LimitingFactor.DROPDOWN, + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n1\n" + assert count == 1 + assert query.client_id == "test5" + + @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) + @patch("superset.sqllab.commands.export.results_backend_use_msgpack", False) + def test_run_with_results_backend(self) -> None: + command = export.SqlResultExportCommand("test6") + + database = Database(database_name="my_database6", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test6", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc6", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + data = [{"foo": i} for i in range(5)] + payload = { + "columns": [{"name": "foo"}], + "data": data, + } + serialized_payload = sql_lab._serialize_payload(payload, False) + compressed = utils.zlib_compress(serialized_payload) + + export.results_backend = mock.Mock() + export.results_backend.get.return_value = compressed + + result = command.run() + + data = result.get("data") + count = result.get("count") + query = result.get("query") + + assert data == "foo\n0\n1\n2\n3\n4\n" + assert count == 5 + assert query.client_id == "test6" + + class TestSqlExecutionResultsCommand(SupersetTestCase): - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_validation_no_results_backend(self) -> None: results.results_backend = None @@ -44,7 +280,7 @@ def test_validation_no_results_backend(self) -> None: == SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR ) - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_validation_data_cannot_be_retrieved(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = None @@ -55,8 +291,8 @@ def test_validation_data_cannot_be_retrieved(self) -> None: command.run() assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) - def test_validation_query_not_found(self) -> None: + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + def test_validation_data_not_found(self) -> None: data = [{"col_0": i} for i in range(100)] payload = { "status": QueryStatus.SUCCESS, @@ -75,8 +311,8 @@ def test_validation_query_not_found(self) -> None: command.run() assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) - def test_validation_query_not_found2(self) -> None: + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + def test_validation_query_not_found(self) -> None: data = [{"col_0": i} for i in range(104)] payload = { "status": QueryStatus.SUCCESS, @@ -89,9 +325,9 @@ def test_validation_query_not_found2(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = compressed - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database7", sqlalchemy_uri="sqlite://") query_obj = Query( - client_id="foo", + client_id="test8", database=database, tab_name="test_tab", sql_editor_id="test_editor_id", @@ -102,11 +338,12 @@ def test_validation_query_not_found2(self) -> None: select_as_cta=False, rows=104, error_message="none", - results_key="test_abc", + results_key="abc7", ) db.session.add(database) db.session.add(query_obj) + db.session.commit() with mock.patch( "superset.views.utils._deserialize_results_payload", @@ -120,7 +357,7 @@ def test_validation_query_not_found2(self) -> None: == SupersetErrorType.RESULTS_BACKEND_ERROR ) - @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_run_succeeds(self) -> None: data = [{"col_0": i} for i in range(104)] payload = { @@ -134,9 +371,9 @@ def test_run_succeeds(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = compressed - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database8", sqlalchemy_uri="sqlite://") query_obj = Query( - client_id="foo", + client_id="test9", database=database, tab_name="test_tab", sql_editor_id="test_editor_id", @@ -152,6 +389,7 @@ def test_run_succeeds(self) -> None: db.session.add(database) db.session.add(query_obj) + db.session.commit() command = results.SqlExecutionResultsCommand("test_abc", 1000) result = command.run() From 5d5e756dc14e9e18a57f61848d5394354d870d55 Mon Sep 17 00:00:00 2001 From: Diego Medina Date: Tue, 7 Feb 2023 00:10:30 -0300 Subject: [PATCH 2/3] improvements --- docs/static/resources/openapi.json | 168 ++++++------ .../src/SqlLab/components/ResultSet/index.tsx | 10 +- superset/sqllab/api.py | 26 +- superset/sqllab/commands/export.py | 10 +- superset/sqllab/schemas.py | 8 - tests/integration_tests/sql_lab/api_tests.py | 7 +- .../sql_lab/commands_tests.py | 258 ++++++------------ 7 files changed, 189 insertions(+), 298 deletions(-) diff --git a/docs/static/resources/openapi.json b/docs/static/resources/openapi.json index d0865ce966776..18ea7a47f8f19 100644 --- a/docs/static/resources/openapi.json +++ b/docs/static/resources/openapi.json @@ -345,7 +345,7 @@ "AnnotationLayerRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1" + "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User" }, "changed_on": { "format": "date-time", @@ -356,7 +356,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User" + "$ref": "#/components/schemas/AnnotationLayerRestApi.get_list.User1" }, "created_on": { "format": "date-time", @@ -502,13 +502,13 @@ "AnnotationRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1" + "$ref": "#/components/schemas/AnnotationRestApi.get_list.User" }, "changed_on_delta_humanized": { "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/AnnotationRestApi.get_list.User" + "$ref": "#/components/schemas/AnnotationRestApi.get_list.User1" }, "end_dttm": { "format": "date-time", @@ -1223,6 +1223,7 @@ "example": false }, "periods": { + "description": "Time periods (in units of `time_grain`) to predict into the future", "example": 7, "format": "int32", "type": "integer" @@ -1578,6 +1579,7 @@ "type": "string" }, "from_dttm": { + "description": "Start timestamp of time range", "format": "int32", "nullable": true, "type": "integer" @@ -1603,6 +1605,7 @@ "type": "integer" }, "stacktrace": { + "description": "Stacktrace if there was an error", "nullable": true, "type": "string" }, @@ -1620,6 +1623,7 @@ "type": "string" }, "to_dttm": { + "description": "End timestamp of time range", "format": "int32", "nullable": true, "type": "integer" @@ -1833,7 +1837,7 @@ "$ref": "#/components/schemas/ChartDataRestApi.get_list.User3" }, "owners": { - "$ref": "#/components/schemas/ChartDataRestApi.get_list.User3" + "$ref": "#/components/schemas/ChartDataRestApi.get_list.User1" }, "params": { "nullable": true, @@ -1921,11 +1925,16 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -1963,16 +1972,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -2232,6 +2236,7 @@ "type": "string" }, "rolling_type_options": { + "description": "Optional options to pass to rolling method. Needed for e.g. quantile operation.", "example": {}, "type": "object" }, @@ -2625,7 +2630,7 @@ "$ref": "#/components/schemas/ChartRestApi.get_list.User3" }, "owners": { - "$ref": "#/components/schemas/ChartRestApi.get_list.User3" + "$ref": "#/components/schemas/ChartRestApi.get_list.User1" }, "params": { "nullable": true, @@ -2713,11 +2718,16 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -2755,16 +2765,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -3027,13 +3032,13 @@ "CssTemplateRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1" + "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User" }, "changed_on_delta_humanized": { "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User" + "$ref": "#/components/schemas/CssTemplateRestApi.get_list.User1" }, "created_on": { "format": "date-time", @@ -3415,7 +3420,7 @@ "readOnly": true }, "created_by": { - "$ref": "#/components/schemas/DashboardRestApi.get_list.User1" + "$ref": "#/components/schemas/DashboardRestApi.get_list.User2" }, "created_on_delta_humanized": { "readOnly": true @@ -3441,7 +3446,7 @@ "type": "string" }, "owners": { - "$ref": "#/components/schemas/DashboardRestApi.get_list.User2" + "$ref": "#/components/schemas/DashboardRestApi.get_list.User1" }, "position_json": { "nullable": true, @@ -3515,6 +3520,10 @@ }, "DashboardRestApi.get_list.User1": { "properties": { + "email": { + "maxLength": 64, + "type": "string" + }, "first_name": { "maxLength": 64, "type": "string" @@ -3526,20 +3535,22 @@ "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ + "email", "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, "DashboardRestApi.get_list.User2": { "properties": { - "email": { - "maxLength": 64, - "type": "string" - }, "first_name": { "maxLength": 64, "type": "string" @@ -3551,17 +3562,11 @@ "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ - "email", "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, @@ -4912,7 +4917,7 @@ "$ref": "#/components/schemas/DatasetRestApi.get.TableColumn" }, "created_by": { - "$ref": "#/components/schemas/DatasetRestApi.get.User1" + "$ref": "#/components/schemas/DatasetRestApi.get.User2" }, "created_on": { "format": "date-time", @@ -4976,7 +4981,7 @@ "type": "integer" }, "owners": { - "$ref": "#/components/schemas/DatasetRestApi.get.User2" + "$ref": "#/components/schemas/DatasetRestApi.get.User1" }, "schema": { "maxLength": 255, @@ -5190,14 +5195,23 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, "last_name": { "maxLength": 64, "type": "string" + }, + "username": { + "maxLength": 64, + "type": "string" } }, "required": [ "first_name", - "last_name" + "last_name", + "username" ], "type": "object" }, @@ -5207,30 +5221,21 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, "last_name": { "maxLength": 64, "type": "string" - }, - "username": { - "maxLength": 64, - "type": "string" } }, "required": [ "first_name", - "last_name", - "username" + "last_name" ], "type": "object" }, "DatasetRestApi.get_list": { "properties": { "changed_by": { - "$ref": "#/components/schemas/DatasetRestApi.get_list.User1" + "$ref": "#/components/schemas/DatasetRestApi.get_list.User" }, "changed_by_name": { "readOnly": true @@ -5273,7 +5278,7 @@ "readOnly": true }, "owners": { - "$ref": "#/components/schemas/DatasetRestApi.get_list.User" + "$ref": "#/components/schemas/DatasetRestApi.get_list.User1" }, "schema": { "maxLength": 255, @@ -5317,14 +5322,6 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, - "last_name": { - "maxLength": 64, - "type": "string" - }, "username": { "maxLength": 64, "type": "string" @@ -5332,7 +5329,6 @@ }, "required": [ "first_name", - "last_name", "username" ], "type": "object" @@ -5343,6 +5339,14 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, + "last_name": { + "maxLength": 64, + "type": "string" + }, "username": { "maxLength": 64, "type": "string" @@ -5350,6 +5354,7 @@ }, "required": [ "first_name", + "last_name", "username" ], "type": "object" @@ -6966,7 +6971,7 @@ "type": "integer" }, "created_by": { - "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User1" + "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User2" }, "created_on": { "format": "date-time", @@ -7016,7 +7021,7 @@ "type": "string" }, "owners": { - "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User2" + "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.User1" }, "recipients": { "$ref": "#/components/schemas/ReportScheduleRestApi.get_list.ReportRecipients" @@ -7077,6 +7082,10 @@ "maxLength": 64, "type": "string" }, + "id": { + "format": "int32", + "type": "integer" + }, "last_name": { "maxLength": 64, "type": "string" @@ -7094,10 +7103,6 @@ "maxLength": 64, "type": "string" }, - "id": { - "format": "int32", - "type": "integer" - }, "last_name": { "maxLength": 64, "type": "string" @@ -9538,17 +9543,6 @@ }, "type": "object" }, - "sql_lab_export_csv_schema": { - "properties": { - "client_id": { - "type": "string" - } - }, - "required": [ - "client_id" - ], - "type": "object" - }, "sql_lab_get_results_schema": { "properties": { "key": { @@ -20008,19 +20002,17 @@ ] } }, - "/api/v1/sqllab/export/": { + "/api/v1/sqllab/export/{client_id}/": { "get": { "parameters": [ { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/sql_lab_export_csv_schema" - } - } - }, - "in": "query", - "name": "q" + "description": "The SQL query result identifier", + "in": "path", + "name": "client_id", + "required": true, + "schema": { + "type": "integer" + } } ], "responses": { @@ -20055,7 +20047,7 @@ "jwt": [] } ], - "summary": "Exports the SQL Query results to a CSV", + "summary": "Exports the SQL query results to a CSV", "tags": [ "SQL Lab" ] diff --git a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx index cb24b5e68b867..fad6c98bc94b8 100644 --- a/superset-frontend/src/SqlLab/components/ResultSet/index.tsx +++ b/superset-frontend/src/SqlLab/components/ResultSet/index.tsx @@ -17,7 +17,6 @@ * under the License. */ import React, { useCallback, useEffect, useState } from 'react'; -import rison from 'rison'; import { useDispatch } from 'react-redux'; import ButtonGroup from 'src/components/ButtonGroup'; import Alert from 'src/components/Alert'; @@ -220,13 +219,8 @@ const ResultSet = ({ } }; - const getExportCsvUrl = (clientId: string) => { - const params = rison.encode({ - client_id: clientId, - }); - - return `/api/v1/sqllab/export/?q=${params}`; - }; + const getExportCsvUrl = (clientId: string) => + `/api/v1/sqllab/export/${clientId}/`; const renderControls = () => { if (search || visualize || csv) { diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index df6acbfd267d2..f73ef749d4936 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -44,7 +44,6 @@ from superset.sqllab.schemas import ( ExecutePayloadSchema, QueryExecutionResponseSchema, - sql_lab_export_csv_schema, sql_lab_get_results_schema, ) from superset.sqllab.sql_json_executer import ( @@ -75,7 +74,6 @@ class SqlLabRestApi(BaseSupersetApi): apispec_parameter_schemas = { "sql_lab_get_results_schema": sql_lab_get_results_schema, - "sql_lab_export_csv_schema": sql_lab_export_csv_schema, } openapi_spec_tag = "SQL Lab" openapi_spec_component_schemas = ( @@ -83,28 +81,26 @@ class SqlLabRestApi(BaseSupersetApi): QueryExecutionResponseSchema, ) - @expose("/export/") + @expose("/export//") @protect() @statsd_metrics - @rison(sql_lab_export_csv_schema) @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".export_csv", log_to_statsd=False, ) - def export_csv(self, **kwargs: Any) -> CsvResponse: - """Exports the SQL Query results to a CSV + def export_csv(self, client_id: str) -> CsvResponse: + """Exports the SQL query results to a CSV --- get: summary: >- - Exports the SQL Query results to a CSV + Exports the SQL query results to a CSV parameters: - - in: query - name: q - content: - application/json: - schema: - $ref: '#/components/schemas/sql_lab_export_csv_schema' + - in: path + schema: + type: integer + name: client_id + description: The SQL query result identifier responses: 200: description: SQL query results @@ -123,13 +119,11 @@ def export_csv(self, **kwargs: Any) -> CsvResponse: 500: $ref: '#/components/responses/500' """ - params = kwargs["rison"] - client_id = params.get("client_id") result = SqlResultExportCommand(client_id=client_id).run() query = result.get("query") data = result.get("data") - row_count = result.get("row_count") + row_count = result.get("count") quoted_csv_name = parse.quote(query.name) response = CsvResponse( diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py index 1c189d674e474..feca664225c7e 100644 --- a/superset/sqllab/commands/export.py +++ b/superset/sqllab/commands/export.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import Any, cast, Dict +from typing import Any, cast, List, TypedDict import pandas as pd from flask_babel import gettext as __, lazy_gettext as _ @@ -39,6 +39,12 @@ logger = logging.getLogger(__name__) +class SqlExportResult(TypedDict): + query: Query + count: int + data: List[Any] + + class SqlResultExportCommand(BaseCommand): _client_id: str _query: Query @@ -80,7 +86,7 @@ def validate(self) -> None: def run( self, - ) -> Dict[str, Any]: + ) -> SqlExportResult: self.validate() blob = None if results_backend and self._query.results_key: diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index 428cdb89bb3e3..f238fda5c918f 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -24,14 +24,6 @@ "required": ["key"], } -sql_lab_export_csv_schema = { - "type": "object", - "properties": { - "client_id": {"type": "string"}, - }, - "required": ["client_id"], -} - class ExecutePayloadSchema(Schema): database_id = fields.Integer(required=True) diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 52668593213b2..23b6f2c9deec0 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -200,7 +200,7 @@ def test_export_results(self, get_df_mock: mock.Mock) -> None: select_as_cta=False, rows=104, error_message="none", - results_key="test_abc2", + results_key="test_abc", ) db.session.add(database) @@ -208,10 +208,9 @@ def test_export_results(self, get_df_mock: mock.Mock) -> None: get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) - arguments = {"client_id": "test"} - resp = self.get_resp(f"/api/v1/sqllab/export/?q={prison.dumps(arguments)}") + resp = self.get_resp("/api/v1/sqllab/export/test/") data = csv.reader(io.StringIO(resp)) - expected_data = csv.reader(io.StringIO(f"foo\n1\n2")) + expected_data = csv.reader(io.StringIO("foo\n1\n2")) self.assertEqual(list(expected_data), list(data)) db.session.rollback() diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 4e2e6642a66d7..2293b20a5fcbb 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -38,55 +38,46 @@ class TestSqlResultExportCommand(SupersetTestCase): - def test_validation_query_not_found(self) -> None: - command = export.SqlResultExportCommand("asdf") + @pytest.fixture() + def create_database_and_query(self): + with self.create_app().app_context(): + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc_query", + ) - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test1", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql="select * from bar", - executed_sql="select * from bar", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc1", - ) + db.session.add(database) + db.session.add(query_obj) + db.session.commit() - db.session.add(database) - db.session.add(query_obj) - db.session.commit() + yield + + db.session.delete(query_obj) + db.session.delete(database) + db.session.commit() + + @pytest.mark.usefixtures("create_database_and_query") + def test_validation_query_not_found(self) -> None: + command = export.SqlResultExportCommand("asdf") with pytest.raises(SupersetErrorException) as ex_info: command.run() assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR + @pytest.mark.usefixtures("create_database_and_query") def test_validation_invalid_access(self) -> None: - command = export.SqlResultExportCommand("test2") - - database = Database(database_name="my_database2", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test2", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql="select * from bar", - executed_sql="select * from bar", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc2", - ) - - db.session.add(database) - db.session.add(query_obj) - db.session.commit() + command = export.SqlResultExportCommand("test") with mock.patch( "superset.security_manager.raise_for_access", @@ -105,33 +96,13 @@ def test_validation_invalid_access(self) -> None: == SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR ) + @pytest.mark.usefixtures("create_database_and_query") @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) @patch("superset.models.core.Database.get_df") def test_run_no_results_backend_select_sql(self, get_df_mock: Mock) -> None: - command = export.SqlResultExportCommand("test3") - - database = Database(database_name="my_database3", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test3", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql="select * from bar", - executed_sql="select * from bar", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc3", - ) - - db.session.add(database) - db.session.add(query_obj) - db.session.commit() + command = export.SqlResultExportCommand("test") get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) - result = command.run() data = result.get("data") @@ -140,35 +111,20 @@ def test_run_no_results_backend_select_sql(self, get_df_mock: Mock) -> None: assert data == "foo\n1\n2\n3\n" assert count == 3 - assert query.client_id == "test3" + assert query.client_id == "test" + @pytest.mark.usefixtures("create_database_and_query") @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) @patch("superset.models.core.Database.get_df") def test_run_no_results_backend_executed_sql(self, get_df_mock: Mock) -> None: - command = export.SqlResultExportCommand("test4") - - database = Database(database_name="my_database4", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test4", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql=None, - executed_sql="select * from bar limit 2", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc4", - ) - - db.session.add(database) - db.session.add(query_obj) + query_obj = db.session.query(Query).filter_by(client_id="test").one() + query_obj.executed_sql = "select * from bar limit 2" + query_obj.select_sql = None db.session.commit() - get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) + command = export.SqlResultExportCommand("test") + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) result = command.run() data = result.get("data") @@ -177,36 +133,22 @@ def test_run_no_results_backend_executed_sql(self, get_df_mock: Mock) -> None: assert data == "foo\n1\n2\n" assert count == 2 - assert query.client_id == "test4" + assert query.client_id == "test" + @pytest.mark.usefixtures("create_database_and_query") @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) @patch("superset.models.core.Database.get_df") def test_run_no_results_backend_executed_sql_limiting_factor( self, get_df_mock: Mock ) -> None: - command = export.SqlResultExportCommand("test5") - - database = Database(database_name="my_database5", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test5", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql=None, - executed_sql="select * from bar limit 2", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc5", - limiting_factor=LimitingFactor.DROPDOWN, - ) - - db.session.add(database) - db.session.add(query_obj) + query_obj = db.session.query(Query).filter_by(results_key="abc_query").one() + query_obj.executed_sql = "select * from bar limit 2" + query_obj.select_sql = None + query_obj.limiting_factor = LimitingFactor.DROPDOWN db.session.commit() + command = export.SqlResultExportCommand("test") + get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) result = command.run() @@ -217,32 +159,13 @@ def test_run_no_results_backend_executed_sql_limiting_factor( assert data == "foo\n1\n" assert count == 1 - assert query.client_id == "test5" + assert query.client_id == "test" + @pytest.mark.usefixtures("create_database_and_query") @patch("superset.models.sql_lab.Query.raise_for_access", lambda _: None) @patch("superset.sqllab.commands.export.results_backend_use_msgpack", False) def test_run_with_results_backend(self) -> None: - command = export.SqlResultExportCommand("test6") - - database = Database(database_name="my_database6", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test6", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql="select * from bar", - executed_sql="select * from bar", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc6", - ) - - db.session.add(database) - db.session.add(query_obj) - db.session.commit() + command = export.SqlResultExportCommand("test") data = [{"foo": i} for i in range(5)] payload = { @@ -263,10 +186,39 @@ def test_run_with_results_backend(self) -> None: assert data == "foo\n0\n1\n2\n3\n4\n" assert count == 5 - assert query.client_id == "test6" + assert query.client_id == "test" class TestSqlExecutionResultsCommand(SupersetTestCase): + @pytest.fixture() + def create_database_and_query(self): + with self.create_app().app_context(): + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + query_obj = Query( + client_id="test", + database=database, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=104, + error_message="none", + results_key="abc_query", + ) + + db.session.add(database) + db.session.add(query_obj) + db.session.commit() + + yield + + db.session.delete(query_obj) + db.session.delete(database) + db.session.commit() + @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_validation_no_results_backend(self) -> None: results.results_backend = None @@ -311,6 +263,7 @@ def test_validation_data_not_found(self) -> None: command.run() assert ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR + @pytest.mark.usefixtures("create_database_and_query") @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_validation_query_not_found(self) -> None: data = [{"col_0": i} for i in range(104)] @@ -325,38 +278,19 @@ def test_validation_query_not_found(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = compressed - database = Database(database_name="my_database7", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test8", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql="select * from bar", - executed_sql="select * from bar", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="abc7", - ) - - db.session.add(database) - db.session.add(query_obj) - db.session.commit() - with mock.patch( "superset.views.utils._deserialize_results_payload", side_effect=SerializationError(), ): with pytest.raises(SupersetErrorException) as ex_info: - command = results.SqlExecutionResultsCommand("test", 1000) + command = results.SqlExecutionResultsCommand("test_other", 1000) command.run() assert ( ex_info.value.error.error_type == SupersetErrorType.RESULTS_BACKEND_ERROR ) + @pytest.mark.usefixtures("create_database_and_query") @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_run_succeeds(self) -> None: data = [{"col_0": i} for i in range(104)] @@ -371,27 +305,7 @@ def test_run_succeeds(self) -> None: results.results_backend = mock.Mock() results.results_backend.get.return_value = compressed - database = Database(database_name="my_database8", sqlalchemy_uri="sqlite://") - query_obj = Query( - client_id="test9", - database=database, - tab_name="test_tab", - sql_editor_id="test_editor_id", - sql="select * from bar", - select_sql="select * from bar", - executed_sql="select * from bar", - limit=100, - select_as_cta=False, - rows=104, - error_message="none", - results_key="test_abc", - ) - - db.session.add(database) - db.session.add(query_obj) - db.session.commit() - - command = results.SqlExecutionResultsCommand("test_abc", 1000) + command = results.SqlExecutionResultsCommand("abc_query", 1000) result = command.run() assert result.get("status") == "success" From d48e9e703cf6e1b31af1c6420cf6dd941648cb2c Mon Sep 17 00:00:00 2001 From: Diego Medina Date: Wed, 8 Feb 2023 23:09:50 -0300 Subject: [PATCH 3/3] improvements --- tests/integration_tests/sql_lab/api_tests.py | 9 ++++----- tests/integration_tests/sql_lab/commands_tests.py | 9 +++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 23b6f2c9deec0..93beb380f0db6 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -185,9 +185,7 @@ def test_get_results_with_display_limit(self): def test_export_results(self, get_df_mock: mock.Mock) -> None: self.login() - database = Database( - database_name="my_export_database", sqlalchemy_uri="sqlite://" - ) + database = get_example_database() query_obj = Query( client_id="test", database=database, @@ -203,8 +201,8 @@ def test_export_results(self, get_df_mock: mock.Mock) -> None: results_key="test_abc", ) - db.session.add(database) db.session.add(query_obj) + db.session.commit() get_df_mock.return_value = pd.DataFrame({"foo": [1, 2, 3]}) @@ -213,4 +211,5 @@ def test_export_results(self, get_df_mock: mock.Mock) -> None: expected_data = csv.reader(io.StringIO("foo\n1\n2")) self.assertEqual(list(expected_data), list(data)) - db.session.rollback() + db.session.delete(query_obj) + db.session.commit() diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 2293b20a5fcbb..edb71552370b7 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -34,6 +34,7 @@ from superset.sqllab.commands import export, results from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import core as utils +from superset.utils.database import get_example_database from tests.integration_tests.base_tests import SupersetTestCase @@ -41,7 +42,7 @@ class TestSqlResultExportCommand(SupersetTestCase): @pytest.fixture() def create_database_and_query(self): with self.create_app().app_context(): - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = get_example_database() query_obj = Query( client_id="test", database=database, @@ -57,14 +58,12 @@ def create_database_and_query(self): results_key="abc_query", ) - db.session.add(database) db.session.add(query_obj) db.session.commit() yield db.session.delete(query_obj) - db.session.delete(database) db.session.commit() @pytest.mark.usefixtures("create_database_and_query") @@ -193,7 +192,7 @@ class TestSqlExecutionResultsCommand(SupersetTestCase): @pytest.fixture() def create_database_and_query(self): with self.create_app().app_context(): - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = get_example_database() query_obj = Query( client_id="test", database=database, @@ -209,14 +208,12 @@ def create_database_and_query(self): results_key="abc_query", ) - db.session.add(database) db.session.add(query_obj) db.session.commit() yield db.session.delete(query_obj) - db.session.delete(database) db.session.commit() @patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)