Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make schema name for the CTA queries and limit configurable #8867

Merged
merged 13 commits into from
Mar 3, 2020
11 changes: 10 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ jobs:
- redis-server
before_script:
- mysql -u root -e "DROP DATABASE IF EXISTS superset; CREATE DATABASE superset DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
- mysql -u root -e "DROP DATABASE IF EXISTS sqllab_test_db; CREATE DATABASE sqllab_test_db DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s not clear from this PR why we need these additional two databases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@john-bodley this is purely for testing, e.g. there is a need to have two different schemas in MySQL & Postgres to test the CTA behavior

- mysql -u root -e "DROP DATABASE IF EXISTS admin_database; CREATE DATABASE admin_database DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci"
- mysql -u root -e "CREATE USER 'mysqluser'@'localhost' IDENTIFIED BY 'mysqluserpassword';"
- mysql -u root -e "GRANT ALL ON superset.* TO 'mysqluser'@'localhost';"
- mysql -u root -e "GRANT ALL ON *.* TO 'mysqluser'@'localhost';"
- language: python
env: TOXENV=javascript
before_install:
Expand All @@ -91,8 +93,15 @@ jobs:
- postgresql
- redis-server
before_script:
- psql -U postgres -c "DROP DATABASE IF EXISTS superset;"
- psql -U postgres -c "CREATE DATABASE superset;"
- psql -U postgres superset -c "DROP SCHEMA IF EXISTS sqllab_test_db;"
- psql -U postgres superset -c "CREATE SCHEMA sqllab_test_db;"
- psql -U postgres superset -c "DROP SCHEMA IF EXISTS admin_database;"
- psql -U postgres superset -c "CREATE SCHEMA admin_database;"
- psql -U postgres -c "CREATE USER postgresuser WITH PASSWORD 'pguserpassword';"
- psql -U postgres superset -c "GRANT ALL PRIVILEGES ON SCHEMA sqllab_test_db to postgresuser";
- psql -U postgres superset -c "GRANT ALL PRIVILEGES ON SCHEMA admin_database to postgresuser";
- language: python
python: 3.6
env: TOXENV=pylint
Expand Down
Empty file.
31 changes: 30 additions & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sys
from collections import OrderedDict
from datetime import date
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from celery.schedules import crontab
from dateutil import tz
Expand All @@ -41,6 +41,9 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from flask_appbuilder.security.sqla import models # pylint: disable=unused-import
from superset.models.core import Database # pylint: disable=unused-import

# Realtime stats logger, a StatsD implementation exists
STATS_LOGGER = DummyStatsLogger()
Expand Down Expand Up @@ -523,6 +526,32 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# timeout.
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = 10 # seconds

# Flag that controls whether a row limit should be enforced on CTA (CREATE TABLE AS) queries.
SQLLAB_CTAS_NO_LIMIT = False

# This allows you to define custom logic around the "CREATE TABLE AS" or CTAS feature
# in SQL Lab that defines where the target schema should be for a given user.
# A database's `CTAS Schema` setting takes precedence over this one.
# The example below returns the username, so CTA queries will write tables into
# a schema named after the user.
# SQLLAB_CTAS_SCHEMA_NAME_FUNC = lambda database, user, schema, sql: user.username
# A more involved example, where depending on the database you can leverage
# available data to assign a schema for the CTA query:
# def compute_schema_name(database: Database, user: User, schema: str, sql: str) -> str:
# if database.name == 'mysql_payments_slave':
# return 'tmp_superset_schema'
# if database.name == 'presto_gold':
# return user.username
# if database.name == 'analytics':
# if 'analytics' in [r.name for r in user.roles]:
# return 'analytics_cta'
# else:
# return f'tmp_{schema}'
# Function accepts database object, user object, schema name and sql that will be run.
SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
Callable[["Database", "models.User", str, str], str]
] = None

# An instantiated derivative of werkzeug.contrib.cache.BaseCache
# if enabled, it can be used to store the results of long-running queries
# in SQL Lab by using the "Run Async" button/feature
Expand Down
12 changes: 7 additions & 5 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
return database.compile_sqla_query(qry)
elif LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.ParsedQuery(sql)
sql = parsed_query.get_query_with_new_limit(limit)
sql = parsed_query.set_or_update_query_limit(limit)
return sql

@classmethod
Expand All @@ -351,7 +351,7 @@ def get_limit_from_sql(cls, sql: str) -> Optional[int]:
return parsed_query.limit

@classmethod
def get_query_with_new_limit(cls, sql: str, limit: int) -> str:
def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
"""
Create a query based on original query but with new limit clause

Expand All @@ -360,7 +360,7 @@ def get_query_with_new_limit(cls, sql: str, limit: int) -> str:
:return: Query with new limit
"""
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.get_query_with_new_limit(limit)
return parsed_query.set_or_update_query_limit(limit)

@staticmethod
def csv_to_df(**kwargs: Any) -> pd.DataFrame:
Expand Down Expand Up @@ -632,10 +632,12 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
"""
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.

WARNING: expects only unquoted table and schema names.

:param database: Database instance
:param table_name: Table name
:param table_name: Table name, unquoted
:param engine: SqlALchemy Engine instance
:param schema: Schema
:param schema: Schema, unquoted
:param limit: limit to impose on query
:param show_cols: Show columns in query; otherwise use "*"
:param indent: Add indentation to query
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Add tmp_schema_name to the query object.

Revision ID: 72428d1ea401
Revises: 0a6f12f60c73
Create Date: 2020-02-20 08:52:22.877902

"""

# revision identifiers, used by Alembic.
revision = "72428d1ea401"
down_revision = "0a6f12f60c73"

import sqlalchemy as sa
from alembic import op


def upgrade():
op.add_column(
"query", sa.Column("tmp_schema_name", sa.String(length=256), nullable=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not to other reviewers) At first I was confused by the choice of column name here, but turns out there was already tmp_table_name for the CTAS target target table. A more appropriate name might be target_table_name or ctas_table_name, but no point in changing the current convention in this PR.

)


def downgrade():
try:
# sqlite doesn't like dropping the columns
op.drop_column("query", "tmp_schema_name")
except Exception:
pass
1 change: 1 addition & 0 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Query(Model, ExtraJSONMixin):

# Store the tmp table into the DB only if the user asks for it.
tmp_table_name = Column(String(256))
tmp_schema_name = Column(String(256))
user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
status = Column(String(16), default=QueryStatus.PENDING)
tab_name = Column(String(256))
Expand Down
16 changes: 13 additions & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
SQL_MAX_ROW = config["SQL_MAX_ROW"]
SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
log_query = config["QUERY_LOGGER"]
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -207,9 +208,15 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
query.tmp_table_name = "tmp_{}_table_{}".format(
query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
)
sql = parsed_query.as_create_table(query.tmp_table_name)
sql = parsed_query.as_create_table(
query.tmp_table_name, schema_name=query.tmp_schema_name
)
query.select_as_cta_used = True
if parsed_query.is_select():

# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
if parsed_query.is_select() and not (
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
):
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
query.limit = SQL_MAX_ROW
if query.limit:
Expand Down Expand Up @@ -378,15 +385,18 @@ def execute_sql_statements(
payload = handle_query_error(msg, query, session, payload)
return payload

# Commit the connection so CTA queries will create the table.
conn.commit()

# Success, updating the query entry in database
query.rows = result_set.size
query.progress = 100
query.set_extra_json_key("progress", None)
if query.select_as_cta:
query.select_sql = database.select_star(
query.tmp_table_name,
schema=query.tmp_schema_name,
limit=query.limit,
schema=database.force_ctas_schema,
show_cols=False,
latest_partition=False,
)
Expand Down
31 changes: 22 additions & 9 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,28 @@ def __process_tokenlist(self, token_list: TokenList):
self._alias_names.add(token_list.tokens[0].value)
self.__extract_from_token(token_list)

def as_create_table(self, table_name: str, overwrite: bool = False) -> str:
def as_create_table(
self,
table_name: str,
schema_name: Optional[str] = None,
overwrite: bool = False,
) -> str:
"""Reformats the query into the create table as query.

Works only for the single select SQL statements, in all other cases
the sql query is not modified.
:param table_name: Table that will contain the results of the query execution
:param table_name: table that will contain the results of the query execution
:param schema_name: schema name for the target table
:param overwrite: table_name will be dropped if true
:return: Create table as query
"""
exec_sql = ""
sql = self.stripped()
# TODO(bkyryliuk): quote full_table_name
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn’t we address the TODO? Note the quoter needs to be dialect specific.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@john-bodley this would be an additional feature. I kept the logic as it was before.
It is worth to resolve this todo and it is existing bug in superset, but I think this PR is not a right place for the fix as it is already quite large & hard to review and comprehend.

if overwrite:
exec_sql = f"DROP TABLE IF EXISTS {table_name};\n"
exec_sql += f"CREATE TABLE {table_name} AS \n{sql}"
exec_sql = f"DROP TABLE IF EXISTS {full_table_name};\n"
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql

def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
Expand Down Expand Up @@ -205,10 +213,12 @@ def __extract_from_token(self, token: Token): # pylint: disable=too-many-branch
if not self.__is_identifier(token2):
self.__extract_from_token(item)

def get_query_with_new_limit(self, new_limit: int) -> str:
"""
returns the query with the specified limit.
Does not change the underlying query
def set_or_update_query_limit(self, new_limit: int) -> str:
"""Returns the query with the specified limit.

Does not change the underlying query if user did not apply the limit,
otherwise replaces the limit with the lower value between existing limit
in the query and new_limit.

:param new_limit: Limit to be incorporated into returned query
:return: The original query with new limit
Expand All @@ -223,7 +233,10 @@ def get_query_with_new_limit(self, new_limit: int) -> str:
limit_pos = pos
break
_, limit = statement.token_next(idx=limit_pos)
if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
# Override the limit only when it exceeds the configured value.
if limit.ttype == sqlparse.tokens.Literal.Number.Integer and new_limit < int(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmmh, this here in theory changes what the function is expected to do ("returns the query with the specified limit"), so either we change the name/docstring to reflect that, or either we move the conditional logic towards where the function is called.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the docstring, yeah there is a mismatch. It's hard to move this condition outside of this function as it needs to get the existing limit value from the query, which would require parsing the query twice. I can't think of a use case where we would want to override a lower user limit with the higher configured value; e.g. I would expect to see 1 row when I query select * from bla limit 1 rather than 100.

limit.value
):
limit.value = new_limit
elif limit.is_group:
limit.value = f"{next(limit.get_identifiers())}, {new_limit}"
Expand Down
32 changes: 26 additions & 6 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
from contextlib import closing
from datetime import datetime, timedelta
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, Callable, cast, Dict, List, Optional, Union
from urllib import parse

import backoff
Expand Down Expand Up @@ -73,6 +73,7 @@
SupersetTimeoutException,
)
from superset.jinja_context import get_template_processor
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.models.slice import Slice
Expand Down Expand Up @@ -243,6 +244,17 @@ def _deserialize_results_payload(
return json.loads(payload) # type: ignore


def get_cta_schema_name(
database: Database, user: ab_models.User, schema: str, sql: str
) -> Optional[str]:
func: Optional[Callable[[Database, ab_models.User, str, str], str]] = config[
"SQLLAB_CTAS_SCHEMA_NAME_FUNC"
]
if not func:
return None
return func(database, user, schema, sql)


class AccessRequestsModelView(SupersetModelView, DeleteMixin):
datamodel = SQLAInterface(DAR)
include_route_methods = RouteMethod.CRUD_SET
Expand Down Expand Up @@ -2334,9 +2346,14 @@ def sql_json_exec(
if not mydb:
return json_error_response(f"Database with id {database_id} is missing.")

# Set tmp_table_name for CTA
# Set tmp_schema_name for CTA
# TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from tmp_table_name if user enters
# <schema_name>.<table_name>
tmp_schema_name: Optional[str] = schema
if select_as_cta and mydb.force_ctas_schema:
tmp_table_name = f"{mydb.force_ctas_schema}.{tmp_table_name}"
tmp_schema_name = mydb.force_ctas_schema
elif select_as_cta:
tmp_schema_name = get_cta_schema_name(mydb, g.user, schema, sql)

# Save current query
query = Query(
Expand All @@ -2349,6 +2366,7 @@ def sql_json_exec(
status=status,
sql_editor_id=sql_editor_id,
tmp_table_name=tmp_table_name,
tmp_schema_name=tmp_schema_name,
user_id=g.user.get_id() if g.user else None,
client_id=client_id,
)
Expand Down Expand Up @@ -2389,9 +2407,11 @@ def sql_json_exec(
f"Query {query_id}: Template rendering failed: {error_msg}"
)

# set LIMIT after template processing
limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
query.limit = min(lim for lim in limits if lim is not None)
# Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set to True.
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and select_as_cta):
# set LIMIT after template processing
limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
query.limit = min(lim for lim in limits if lim is not None)

# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
Expand Down
25 changes: 15 additions & 10 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,27 @@ def run_sql(
query_limit=None,
database_name="examples",
sql_editor_id=None,
select_as_cta=False,
tmp_table_name=None,
):
if user_name:
self.logout()
self.login(username=(user_name or "admin"))
dbid = self._get_database_by_name(database_name).id
json_payload = {
"database_id": dbid,
"sql": sql,
"client_id": client_id,
"queryLimit": query_limit,
"sql_editor_id": sql_editor_id,
}
if tmp_table_name:
json_payload["tmp_table_name"] = tmp_table_name
if select_as_cta:
json_payload["select_as_cta"] = select_as_cta

resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
json_=dict(
database_id=dbid,
sql=sql,
select_as_create_as=False,
client_id=client_id,
queryLimit=query_limit,
sql_editor_id=sql_editor_id,
),
"/superset/sql_json/", raise_on_error=False, json_=json_payload
)
if raise_on_error and "error" in resp:
raise Exception("run_sql failed")
Expand Down
Loading