diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index c8ed840c7c5c7..885b6691b3653 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -42,7 +42,7 @@ from superset.charts.post_processing import apply_post_process from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType -from superset.connectors.base.models import BaseDatasource +from superset.connectors.sqla.models import BaseDatasource from superset.daos.exceptions import DatasourceNotFound from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py index 939714642fc5e..ebcae32f8f486 100644 --- a/superset/charts/post_processing.py +++ b/superset/charts/post_processing.py @@ -40,7 +40,7 @@ ) if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource from superset.models.sql_lab import Query diff --git a/superset/commands/utils.py b/superset/commands/utils.py index 02b6b5f383516..8cfeab3c1148d 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -33,7 +33,7 @@ from superset.utils.core import DatasourceType, get_user_id if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource def populate_owners( diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 22c778b77be67..d73a99d0271c1 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -24,7 +24,7 @@ from superset import app from superset.common.chart_data import ChartDataResultType from superset.common.db_query_status import QueryStatus -from superset.connectors.base.models import BaseDatasource +from superset.connectors.sqla.models import BaseDatasource from superset.exceptions import QueryObjectValidationError from superset.utils.core import ( extract_column_dtype, diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 1a8d3c518b07a..4f517cd90557a 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -30,7 +30,7 @@ from superset.models.slice import Slice if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource from superset.models.helpers import QueryResult diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index d6510ccd9a434..708907d4a91ab 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -29,7 +29,7 @@ from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource config = app.config diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 5a0468b671b39..7967313cd7678 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -36,7 +36,7 @@ get_since_until_from_query_object, get_since_until_from_time_range, ) -from superset.connectors.base.models import BaseDatasource +from superset.connectors.sqla.models import BaseDatasource from superset.constants import CacheRegion, TimeGrain from superset.daos.annotation import AnnotationLayerDAO from superset.daos.chart import ChartDAO diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 1e826761ecba4..989df5775b2e7 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -49,7 +49,7 @@ from superset.utils.hashing import md5_sha_from_dict if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource logger = logging.getLogger(__name__) diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index d993eca279093..d2aa140dfe933 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import sessionmaker - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource from superset.daos.datasource import DatasourceDAO diff --git a/superset/connectors/base/__init__.py b/superset/connectors/base/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/superset/connectors/base/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py deleted file mode 100644 index 1fc0fde5751c4..0000000000000 --- a/superset/connectors/base/models.py +++ /dev/null @@ -1,769 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import builtins -import json -import logging -from collections.abc import Hashable -from datetime import datetime -from json.decoder import JSONDecodeError -from typing import Any, TYPE_CHECKING - -from flask_appbuilder.security.sqla.models import User -from flask_babel import gettext as __ -from sqlalchemy import and_, Boolean, Column, Integer, String, Text -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import foreign, Query, relationship, RelationshipProperty, Session -from sqlalchemy.sql import literal_column - -from superset import security_manager -from superset.constants import EMPTY_STRING, NULL_STRING -from superset.datasets.commands.exceptions import DatasetNotFoundError -from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult -from superset.models.slice import Slice -from superset.superset_typing import ( - FilterValue, - FilterValues, - QueryObjectDict, - ResultSetColumnType, -) -from superset.utils import core as utils -from superset.utils.backports import StrEnum -from superset.utils.core import GenericDataType, MediumText - -if TYPE_CHECKING: - from superset.db_engine_specs.base import BaseEngineSpec - -logger = logging.getLogger(__name__) - -METRIC_FORM_DATA_PARAMS = [ - "metric", - "metric_2", - "metrics", - "metrics_b", - "percent_metrics", - "secondary_metric", - "size", - "timeseries_limit_metric", - "x", - "y", -] - -COLUMN_FORM_DATA_PARAMS = [ - "all_columns", - "all_columns_x", - "columns", - "entity", - "groupby", - "order_by_cols", - "series", -] - - -class DatasourceKind(StrEnum): - VIRTUAL = "virtual" - PHYSICAL = "physical" - - -class BaseDatasource( - AuditMixinNullable, ImportExportMixin -): # pylint: disable=too-many-public-methods - """A common interface to objects that are queryable - (tables and datasources)""" - - # --------------------------------------------------------------- - # class attributes to define when deriving BaseDatasource - # --------------------------------------------------------------- - __tablename__: str | None = None # {connector_name}_datasource - baselink: str | None = None # url portion pointing to ModelView endpoint - - @property - def column_class(self) -> type[BaseColumn]: - # link to derivative of BaseColumn - raise NotImplementedError() - - @property - def metric_class(self) -> type[BaseMetric]: - # link to derivative of BaseMetric - raise NotImplementedError() - - owner_class: User | None = None - - # Used to do code highlighting when displaying the query in the UI - query_language: str | None = None - - # Only some datasources support Row Level Security - is_rls_supported: bool = False - - @property - def name(self) -> str: - # can be a Column or a property pointing to one - raise NotImplementedError() - - # --------------------------------------------------------------- - - # Columns - id = Column(Integer, primary_key=True) - description = Column(Text) - default_endpoint = Column(Text) - is_featured = Column(Boolean, default=False) # TODO deprecating - filter_select_enabled = Column(Boolean, default=True) - offset = Column(Integer, default=0) - cache_timeout = Column(Integer) - params = Column(String(1000)) - perm = Column(String(1000)) - schema_perm = Column(String(1000)) - is_managed_externally = Column(Boolean, nullable=False, default=False) - external_url = Column(Text, nullable=True) - - sql: str | None = None - owners: list[User] - update_from_object_fields: list[str] - - extra_import_fields = ["is_managed_externally", "external_url"] - - @property - def kind(self) -> DatasourceKind: - return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL - - @property - def owners_data(self) -> list[dict[str, Any]]: - return [ - { - "first_name": o.first_name, - "last_name": o.last_name, - "username": o.username, - "id": o.id, - } - for o in self.owners - ] - - @property - def is_virtual(self) -> bool: - return self.kind == DatasourceKind.VIRTUAL - - @declared_attr - def slices(self) -> RelationshipProperty: - return relationship( - "Slice", - overlaps="table", - primaryjoin=lambda: and_( - foreign(Slice.datasource_id) == self.id, - foreign(Slice.datasource_type) == self.type, - ), - ) - - columns: list[BaseColumn] = [] - metrics: list[BaseMetric] = [] - - @property - def type(self) -> str: - raise NotImplementedError() - - @property - def uid(self) -> str: - """Unique id across datasource types""" - return f"{self.id}__{self.type}" - - @property - def column_names(self) -> list[str]: - return sorted([c.column_name for c in self.columns], key=lambda x: x or "") - - @property - def columns_types(self) -> dict[str, str]: - return {c.column_name: c.type for c in self.columns} - - @property - def main_dttm_col(self) -> str: - return "timestamp" - - @property - def datasource_name(self) -> str: - raise NotImplementedError() - - @property - def connection(self) -> str | None: - """String representing the context of the Datasource""" - return None - - @property - def schema(self) -> str | None: - """String representing the schema of the Datasource (if it applies)""" - return None - - @property - def filterable_column_names(self) -> list[str]: - return sorted([c.column_name for c in self.columns if c.filterable]) - - @property - def dttm_cols(self) -> list[str]: - return [] - - @property - def url(self) -> str: - return f"/{self.baselink}/edit/{self.id}" - - @property - def explore_url(self) -> str: - if self.default_endpoint: - return self.default_endpoint - return f"/explore/?datasource_type={self.type}&datasource_id={self.id}" - - @property - def column_formats(self) -> dict[str, str | None]: - return {m.metric_name: m.d3format for m in self.metrics if m.d3format} - - @property - def currency_formats(self) -> dict[str, dict[str, str | None] | None]: - return {m.metric_name: m.currency_json for m in self.metrics if m.currency_json} - - def add_missing_metrics(self, metrics: list[BaseMetric]) -> None: - existing_metrics = {m.metric_name for m in self.metrics} - for metric in metrics: - if metric.metric_name not in existing_metrics: - metric.table_id = self.id - self.metrics.append(metric) - - @property - def short_data(self) -> dict[str, Any]: - """Data representation of the datasource sent to the frontend""" - return { - "edit_url": self.url, - "id": self.id, - "uid": self.uid, - "schema": self.schema, - "name": self.name, - "type": self.type, - "connection": self.connection, - "creator": str(self.created_by), - } - - @property - def select_star(self) -> str | None: - pass - - @property - def order_by_choices(self) -> list[tuple[str, str]]: - choices = [] - # self.column_names return sorted column_names - for column_name in self.column_names: - column_name = str(column_name or "") - choices.append( - (json.dumps([column_name, True]), f"{column_name} " + __("[asc]")) - ) - choices.append( - (json.dumps([column_name, False]), f"{column_name} " + __("[desc]")) - ) - return choices - - @property - def verbose_map(self) -> dict[str, str]: - verb_map = {"__timestamp": "Time"} - verb_map.update( - {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics} - ) - verb_map.update( - {o.column_name: o.verbose_name or o.column_name for o in self.columns} - ) - return verb_map - - @property - def data(self) -> dict[str, Any]: - """Data representation of the datasource sent to the frontend""" - return { - # simple fields - "id": self.id, - "uid": self.uid, - "column_formats": self.column_formats, - "currency_formats": self.currency_formats, - "description": self.description, - "database": self.database.data, # pylint: disable=no-member - "default_endpoint": self.default_endpoint, - "filter_select": self.filter_select_enabled, # TODO deprecate - "filter_select_enabled": self.filter_select_enabled, - "name": self.name, - "datasource_name": self.datasource_name, - "table_name": self.datasource_name, - "type": self.type, - "schema": self.schema, - "offset": self.offset, - "cache_timeout": self.cache_timeout, - "params": self.params, - "perm": self.perm, - "edit_url": self.url, - # sqla-specific - "sql": self.sql, - # one to many - "columns": [o.data for o in self.columns], - "metrics": [o.data for o in self.metrics], - # TODO deprecate, move logic to JS - "order_by_choices": self.order_by_choices, - "owners": [owner.id for owner in self.owners], - "verbose_map": self.verbose_map, - "select_star": self.select_star, - } - - def data_for_slices( # pylint: disable=too-many-locals - self, slices: list[Slice] - ) -> dict[str, Any]: - """ - The representation of the datasource containing only the required data - to render the provided slices. - - Used to reduce the payload when loading a dashboard. - """ - data = self.data - metric_names = set() - column_names = set() - for slc in slices: - form_data = slc.form_data - # pull out all required metrics from the form_data - for metric_param in METRIC_FORM_DATA_PARAMS: - for metric in utils.as_list(form_data.get(metric_param) or []): - metric_names.add(utils.get_metric_name(metric)) - if utils.is_adhoc_metric(metric): - column = metric.get("column") or {} - if column_name := column.get("column_name"): - column_names.add(column_name) - - # Columns used in query filters - column_names.update( - filter_["subject"] - for filter_ in form_data.get("adhoc_filters") or [] - if filter_.get("clause") == "WHERE" and filter_.get("subject") - ) - - # columns used by Filter Box - column_names.update( - filter_config["column"] - for filter_config in form_data.get("filter_configs") or [] - if "column" in filter_config - ) - - # for legacy dashboard imports which have the wrong query_context in them - try: - query_context = slc.get_query_context() - except DatasetNotFoundError: - query_context = None - - # legacy charts don't have query_context charts - if query_context: - column_names.update( - [ - utils.get_column_name(column) - for query in query_context.queries - for column in query.columns - ] - or [] - ) - else: - _columns = [ - utils.get_column_name(column) - if utils.is_adhoc_column(column) - else column - for column_param in COLUMN_FORM_DATA_PARAMS - for column in utils.as_list(form_data.get(column_param) or []) - ] - column_names.update(_columns) - - filtered_metrics = [ - metric - for metric in data["metrics"] - if metric["metric_name"] in metric_names - ] - - filtered_columns: list[Column] = [] - column_types: set[GenericDataType] = set() - for column in data["columns"]: - generic_type = column.get("type_generic") - if generic_type is not None: - column_types.add(generic_type) - if column["column_name"] in column_names: - filtered_columns.append(column) - - data["column_types"] = list(column_types) - del data["description"] - data.update({"metrics": filtered_metrics}) - data.update({"columns": filtered_columns}) - verbose_map = {"__timestamp": "Time"} - verbose_map.update( - { - metric["metric_name"]: metric["verbose_name"] or metric["metric_name"] - for metric in filtered_metrics - } - ) - verbose_map.update( - { - column["column_name"]: column["verbose_name"] or column["column_name"] - for column in filtered_columns - } - ) - data["verbose_map"] = verbose_map - - return data - - @staticmethod - def filter_values_handler( # pylint: disable=too-many-arguments - values: FilterValues | None, - operator: str, - target_generic_type: GenericDataType, - target_native_type: str | None = None, - is_list_target: bool = False, - db_engine_spec: builtins.type[BaseEngineSpec] | None = None, - db_extra: dict[str, Any] | None = None, - ) -> FilterValues | None: - if values is None: - return None - - def handle_single_value(value: FilterValue | None) -> FilterValue | None: - if operator == utils.FilterOperator.TEMPORAL_RANGE: - return value - if ( - isinstance(value, (float, int)) - and target_generic_type == utils.GenericDataType.TEMPORAL - and target_native_type is not None - and db_engine_spec is not None - ): - value = db_engine_spec.convert_dttm( - target_type=target_native_type, - dttm=datetime.utcfromtimestamp(value / 1000), - db_extra=db_extra, - ) - value = literal_column(value) - if isinstance(value, str): - value = value.strip("\t\n") - - if ( - target_generic_type == utils.GenericDataType.NUMERIC - and operator - not in { - utils.FilterOperator.ILIKE, - utils.FilterOperator.LIKE, - } - ): - # For backwards compatibility and edge cases - # where a column data type might have changed - return utils.cast_to_num(value) - if value == NULL_STRING: - return None - if value == EMPTY_STRING: - return "" - if target_generic_type == utils.GenericDataType.BOOLEAN: - return utils.cast_to_boolean(value) - return value - - if isinstance(values, (list, tuple)): - values = [handle_single_value(v) for v in values] # type: ignore - else: - values = handle_single_value(values) - if is_list_target and not isinstance(values, (tuple, list)): - values = [values] # type: ignore - elif not is_list_target and isinstance(values, (tuple, list)): - values = values[0] if values else None - return values - - def external_metadata(self) -> list[ResultSetColumnType]: - """Returns column information from the external system""" - raise NotImplementedError() - - def get_query_str(self, query_obj: QueryObjectDict) -> str: - """Returns a query as a string - - This is used to be displayed to the user so that they can - understand what is taking place behind the scene""" - raise NotImplementedError() - - def query(self, query_obj: QueryObjectDict) -> QueryResult: - """Executes the query and returns a dataframe - - query_obj is a dictionary representing Superset's query interface. - Should return a ``superset.models.helpers.QueryResult`` - """ - raise NotImplementedError() - - @staticmethod - def default_query(qry: Query) -> Query: - return qry - - def get_column(self, column_name: str | None) -> BaseColumn | None: - if not column_name: - return None - for col in self.columns: - if col.column_name == column_name: - return col - return None - - @staticmethod - def get_fk_many_from_list( - object_list: list[Any], - fkmany: list[Column], - fkmany_class: builtins.type[BaseColumn | BaseMetric], - key_attr: str, - ) -> list[Column]: - """Update ORM one-to-many list from object list - - Used for syncing metrics and columns using the same code""" - - object_dict = {o.get(key_attr): o for o in object_list} - - # delete fks that have been removed - fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict] - - # sync existing fks - for fk in fkmany: - obj = object_dict.get(getattr(fk, key_attr)) - if obj: - for attr in fkmany_class.update_from_object_fields: - setattr(fk, attr, obj.get(attr)) - - # create new fks - new_fks = [] - orm_keys = [getattr(o, key_attr) for o in fkmany] - for obj in object_list: - key = obj.get(key_attr) - if key not in orm_keys: - del obj["id"] - orm_kwargs = {} - for k in obj: - if k in fkmany_class.update_from_object_fields and k in obj: - orm_kwargs[k] = obj[k] - new_obj = fkmany_class(**orm_kwargs) - new_fks.append(new_obj) - fkmany += new_fks - return fkmany - - def update_from_object(self, obj: dict[str, Any]) -> None: - """Update datasource from a data structure - - The UI's table editor crafts a complex data structure that - contains most of the datasource's properties as well as - an array of metrics and columns objects. This method - receives the object from the UI and syncs the datasource to - match it. Since the fields are different for the different - connectors, the implementation uses ``update_from_object_fields`` - which can be defined for each connector and - defines which fields should be synced""" - for attr in self.update_from_object_fields: - setattr(self, attr, obj.get(attr)) - - self.owners = obj.get("owners", []) - - # Syncing metrics - metrics = ( - self.get_fk_many_from_list( - obj["metrics"], self.metrics, self.metric_class, "metric_name" - ) - if self.metric_class and "metrics" in obj - else [] - ) - self.metrics = metrics - - # Syncing columns - self.columns = ( - self.get_fk_many_from_list( - obj["columns"], self.columns, self.column_class, "column_name" - ) - if self.column_class and "columns" in obj - else [] - ) - - def get_extra_cache_keys( - self, query_obj: QueryObjectDict # pylint: disable=unused-argument - ) -> list[Hashable]: - """If a datasource needs to provide additional keys for calculation of - cache keys, those can be provided via this method - - :param query_obj: The dict representation of a query object - :return: list of keys - """ - return [] - - def __hash__(self) -> int: - return hash(self.uid) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, BaseDatasource): - return NotImplemented - return self.uid == other.uid - - def raise_for_access(self) -> None: - """ - Raise an exception if the user cannot access the resource. - - :raises SupersetSecurityException: If the user cannot access the resource - """ - - security_manager.raise_for_access(datasource=self) - - @classmethod - def get_datasource_by_name( - cls, session: Session, datasource_name: str, schema: str, database_name: str - ) -> BaseDatasource | None: - raise NotImplementedError() - - -class BaseColumn(AuditMixinNullable, ImportExportMixin): - """Interface for column""" - - __tablename__: str | None = None # {connector_name}_column - - id = Column(Integer, primary_key=True) - column_name = Column(String(255), nullable=False) - verbose_name = Column(String(1024)) - is_active = Column(Boolean, default=True) - type = Column(Text) - advanced_data_type = Column(String(255)) - groupby = Column(Boolean, default=True) - filterable = Column(Boolean, default=True) - description = Column(MediumText()) - is_dttm = None - - # [optional] Set this to support import/export functionality - export_fields: list[Any] = [] - - def __repr__(self) -> str: - return str(self.column_name) - - bool_types = ("BOOL",) - num_types = ( - "DOUBLE", - "FLOAT", - "INT", - "BIGINT", - "NUMBER", - "LONG", - "REAL", - "NUMERIC", - "DECIMAL", - "MONEY", - ) - date_types = ("DATE", "TIME") - str_types = ("VARCHAR", "STRING", "CHAR") - - @property - def is_numeric(self) -> bool: - return self.type and any(map(lambda t: t in self.type.upper(), self.num_types)) - - @property - def is_temporal(self) -> bool: - return self.type and any(map(lambda t: t in self.type.upper(), self.date_types)) - - @property - def is_string(self) -> bool: - return self.type and any(map(lambda t: t in self.type.upper(), self.str_types)) - - @property - def is_boolean(self) -> bool: - return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types)) - - @property - def type_generic(self) -> utils.GenericDataType | None: - if self.is_string: - return utils.GenericDataType.STRING - if self.is_boolean: - return utils.GenericDataType.BOOLEAN - if self.is_numeric: - return utils.GenericDataType.NUMERIC - if self.is_temporal: - return utils.GenericDataType.TEMPORAL - return None - - @property - def expression(self) -> Column: - raise NotImplementedError() - - @property - def python_date_format(self) -> Column: - raise NotImplementedError() - - @property - def data(self) -> dict[str, Any]: - attrs = ( - "id", - "column_name", - "verbose_name", - "description", - "expression", - "filterable", - "groupby", - "is_dttm", - "type", - "advanced_data_type", - ) - return {s: getattr(self, s) for s in attrs if hasattr(self, s)} - - -class BaseMetric(AuditMixinNullable, ImportExportMixin): - """Interface for Metrics""" - - __tablename__: str | None = None # {connector_name}_metric - - id = Column(Integer, primary_key=True) - metric_name = Column(String(255), nullable=False) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - description = Column(MediumText()) - d3format = Column(String(128)) - currency = Column(String(128)) - warning_text = Column(Text) - - """ - The interface should also declare a datasource relationship pointing - to a derivative of BaseDatasource, along with a FK - - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - datasource = relationship( - # needs to be altered to point to {Connector}Datasource - 'BaseDatasource', - backref=backref('metrics', cascade='all, delete-orphan'), - enable_typechecks=False) - """ - - @property - def currency_json(self) -> dict[str, str | None] | None: - try: - return json.loads(self.currency or "{}") or None - except (TypeError, JSONDecodeError) as exc: - logger.error( - "Unable to load currency json: %r. Leaving empty.", exc, exc_info=True - ) - return None - - @property - def perm(self) -> str | None: - raise NotImplementedError() - - @property - def expression(self) -> Column: - raise NotImplementedError() - - @property - def data(self) -> dict[str, Any]: - attrs = ( - "id", - "metric_name", - "verbose_name", - "description", - "expression", - "warning_text", - "d3format", - "currency", - ) - return {s: getattr(self, s) for s in attrs} diff --git a/superset/connectors/base/views.py b/superset/connectors/base/views.py deleted file mode 100644 index ae5013ebbf4e9..0000000000000 --- a/superset/connectors/base/views.py +++ /dev/null @@ -1,48 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import Any - -from flask import Markup -from flask_appbuilder.fieldwidgets import BS3TextFieldWidget - -from superset.connectors.base.models import BaseDatasource -from superset.exceptions import SupersetException -from superset.views.base import SupersetModelView - - -class BS3TextFieldROWidget( # pylint: disable=too-few-public-methods - BS3TextFieldWidget -): - """ - Custom read only text field widget. - """ - - def __call__(self, field: Any, **kwargs: Any) -> Markup: - kwargs["readonly"] = "true" - return super().__call__(field, **kwargs) - - -class DatasourceModelView(SupersetModelView): - def pre_delete(self, item: BaseDatasource) -> None: - if item.slices: - raise SupersetException( - Markup( - "Cannot delete a datasource that has slices attached to it." - "Here's the list of associated charts: " - + "".join([i.slice_name for i in item.slices]) - ) - ) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 510ca54ae8ffa..3d1435dc7ba82 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -17,6 +17,7 @@ # pylint: disable=too-many-lines from __future__ import annotations +import builtins import dataclasses import json import logging @@ -25,6 +26,7 @@ from collections.abc import Hashable from dataclasses import dataclass, field from datetime import datetime, timedelta +from json.decoder import JSONDecodeError from typing import Any, Callable, cast import dateutil.parser @@ -34,7 +36,8 @@ import sqlparse from flask import escape, Markup from flask_appbuilder import Model -from flask_babel import lazy_gettext as _ +from flask_appbuilder.security.sqla.models import User +from flask_babel import gettext as __, lazy_gettext as _ from jinja2.exceptions import TemplateError from sqlalchemy import ( and_, @@ -52,9 +55,11 @@ update, ) from sqlalchemy.engine.base import Connection +from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( backref, + foreign, Mapped, Query, reconstructor, @@ -71,12 +76,13 @@ from superset import app, db, is_feature_enabled, security_manager from superset.common.db_query_status import QueryStatus -from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.sqla.utils import ( get_columns_description, get_physical_table_metadata, get_virtual_table_metadata, ) +from superset.constants import EMPTY_STRING, NULL_STRING +from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression from superset.exceptions import ( ColumnNotFoundException, @@ -97,19 +103,24 @@ AuditMixinNullable, CertificationMixin, ExploreMixin, + ImportExportMixin, QueryResult, QueryStringExtended, validate_adhoc_subquery, ) +from superset.models.slice import Slice from superset.sql_parse import ParsedQuery, sanitize_clause from superset.superset_typing import ( AdhocColumn, AdhocMetric, + FilterValue, + FilterValues, Metric, QueryObjectDict, ResultSetColumnType, ) from superset.utils import core as utils +from superset.utils.backports import StrEnum from superset.utils.core import GenericDataType, MediumText config = app.config @@ -134,6 +145,565 @@ class MetadataResult: modified: list[str] = field(default_factory=list) +logger = logging.getLogger(__name__) + +METRIC_FORM_DATA_PARAMS = [ + "metric", + "metric_2", + "metrics", + "metrics_b", + "percent_metrics", + "secondary_metric", + "size", + "timeseries_limit_metric", + "x", + "y", +] + +COLUMN_FORM_DATA_PARAMS = [ + "all_columns", + "all_columns_x", + "columns", + "entity", + "groupby", + "order_by_cols", + "series", +] + + +class DatasourceKind(StrEnum): + VIRTUAL = "virtual" + PHYSICAL = "physical" + + +class BaseDatasource( + AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods + """A common interface to objects that are queryable + (tables and datasources)""" + + # --------------------------------------------------------------- + # class attributes to define when deriving BaseDatasource + # --------------------------------------------------------------- + __tablename__: str | None = None # {connector_name}_datasource + baselink: str | None = None # url portion pointing to ModelView endpoint + + owner_class: User | None = None + + # Used to do code highlighting when displaying the query in the UI + query_language: str | None = None + + # Only some datasources support Row Level Security + is_rls_supported: bool = False + + @property + def name(self) -> str: + # can be a Column or a property pointing to one + raise NotImplementedError() + + # --------------------------------------------------------------- + + # Columns + id = Column(Integer, primary_key=True) + description = Column(Text) + default_endpoint = Column(Text) + is_featured = Column(Boolean, default=False) # TODO deprecating + filter_select_enabled = Column(Boolean, default=True) + offset = Column(Integer, default=0) + cache_timeout = Column(Integer) + params = Column(String(1000)) + perm = Column(String(1000)) + schema_perm = Column(String(1000)) + is_managed_externally = Column(Boolean, nullable=False, default=False) + external_url = Column(Text, nullable=True) + + sql: str | None = None + owners: list[User] + update_from_object_fields: list[str] + + extra_import_fields = ["is_managed_externally", "external_url"] + + @property + def kind(self) -> DatasourceKind: + return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL + + @property + def owners_data(self) -> list[dict[str, Any]]: + return [ + { + "first_name": o.first_name, + "last_name": o.last_name, + "username": o.username, + "id": o.id, + } + for o in self.owners + ] + + @property + def is_virtual(self) -> bool: + return self.kind == DatasourceKind.VIRTUAL + + @declared_attr + def slices(self) -> RelationshipProperty: + return relationship( + "Slice", + overlaps="table", + primaryjoin=lambda: and_( + foreign(Slice.datasource_id) == self.id, + foreign(Slice.datasource_type) == self.type, + ), + ) + + columns: list[TableColumn] = [] + metrics: list[SqlMetric] = [] + + @property + def type(self) -> str: + raise NotImplementedError() + + @property + def uid(self) -> str: + """Unique id across datasource types""" + return f"{self.id}__{self.type}" + + @property + def column_names(self) -> list[str]: + return sorted([c.column_name for c in self.columns], key=lambda x: x or "") + + @property + def columns_types(self) -> dict[str, str]: + return {c.column_name: c.type for c in self.columns} + + @property + def main_dttm_col(self) -> str: + return "timestamp" + + @property + def datasource_name(self) -> str: + raise NotImplementedError() + + @property + def connection(self) -> str | None: + """String representing the context of the Datasource""" + return None + + @property + def schema(self) -> str | None: + """String representing the schema of the Datasource (if it applies)""" + return None + + @property + def filterable_column_names(self) -> list[str]: + return sorted([c.column_name for c in self.columns if c.filterable]) + + @property + def dttm_cols(self) -> list[str]: + return [] + + @property + def url(self) -> str: + return f"/{self.baselink}/edit/{self.id}" + + @property + def explore_url(self) -> str: + if self.default_endpoint: + return self.default_endpoint + return f"/explore/?datasource_type={self.type}&datasource_id={self.id}" + + @property + def column_formats(self) -> dict[str, str | None]: + return {m.metric_name: m.d3format for m in self.metrics if m.d3format} + + @property + def currency_formats(self) -> dict[str, dict[str, str | None] | None]: + return {m.metric_name: m.currency_json for m in self.metrics if m.currency_json} + + def add_missing_metrics(self, metrics: list[SqlMetric]) -> None: + existing_metrics = {m.metric_name for m in self.metrics} + for metric in metrics: + if metric.metric_name not in existing_metrics: + metric.table_id = self.id + self.metrics.append(metric) + + @property + def short_data(self) -> dict[str, Any]: + """Data representation of the datasource sent to the frontend""" + return { + "edit_url": self.url, + "id": self.id, + "uid": self.uid, + "schema": self.schema, + "name": self.name, + "type": self.type, + "connection": self.connection, + "creator": str(self.created_by), + } + + @property + def select_star(self) -> str | None: + pass + + @property + def order_by_choices(self) -> list[tuple[str, str]]: + choices = [] + # self.column_names return sorted column_names + for column_name in self.column_names: + column_name = str(column_name or "") + choices.append( + (json.dumps([column_name, True]), f"{column_name} " + __("[asc]")) + ) + choices.append( + (json.dumps([column_name, False]), f"{column_name} " + __("[desc]")) + ) + return choices + + @property + def verbose_map(self) -> dict[str, str]: + verb_map = {"__timestamp": "Time"} + verb_map.update( + {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics} + ) + verb_map.update( + {o.column_name: o.verbose_name or o.column_name for o in self.columns} + ) + return verb_map + + @property + def data(self) -> dict[str, Any]: + """Data representation of the datasource sent to the frontend""" + return { + # simple fields + "id": self.id, + "uid": self.uid, + "column_formats": self.column_formats, + "currency_formats": self.currency_formats, + "description": self.description, + "database": self.database.data, # pylint: disable=no-member + "default_endpoint": self.default_endpoint, + "filter_select": self.filter_select_enabled, # TODO deprecate + "filter_select_enabled": self.filter_select_enabled, + "name": self.name, + "datasource_name": self.datasource_name, + "table_name": self.datasource_name, + "type": self.type, + "schema": self.schema, + "offset": self.offset, + "cache_timeout": self.cache_timeout, + "params": self.params, + "perm": self.perm, + "edit_url": self.url, + # sqla-specific + "sql": self.sql, + # one to many + "columns": [o.data for o in self.columns], + "metrics": [o.data for o in self.metrics], + # TODO deprecate, move logic to JS + "order_by_choices": self.order_by_choices, + "owners": [owner.id for owner in self.owners], + "verbose_map": self.verbose_map, + "select_star": self.select_star, + } + + def data_for_slices( # pylint: disable=too-many-locals + self, slices: list[Slice] + ) -> dict[str, Any]: + """ + The representation of the datasource containing only the required data + to render the provided slices. + + Used to reduce the payload when loading a dashboard. + """ + data = self.data + metric_names = set() + column_names = set() + for slc in slices: + form_data = slc.form_data + # pull out all required metrics from the form_data + for metric_param in METRIC_FORM_DATA_PARAMS: + for metric in utils.as_list(form_data.get(metric_param) or []): + metric_names.add(utils.get_metric_name(metric)) + if utils.is_adhoc_metric(metric): + column_ = metric.get("column") or {} + if column_name := column_.get("column_name"): + column_names.add(column_name) + + # Columns used in query filters + column_names.update( + filter_["subject"] + for filter_ in form_data.get("adhoc_filters") or [] + if filter_.get("clause") == "WHERE" and filter_.get("subject") + ) + + # columns used by Filter Box + column_names.update( + filter_config["column"] + for filter_config in form_data.get("filter_configs") or [] + if "column" in filter_config + ) + + # for legacy dashboard imports which have the wrong query_context in them + try: + query_context = slc.get_query_context() + except DatasetNotFoundError: + query_context = None + + # legacy charts don't have query_context charts + if query_context: + column_names.update( + [ + utils.get_column_name(column_) + for query in query_context.queries + for column_ in query.columns + ] + or [] + ) + else: + _columns = [ + utils.get_column_name(column_) + if utils.is_adhoc_column(column_) + else column_ + for column_param in COLUMN_FORM_DATA_PARAMS + for column_ in utils.as_list(form_data.get(column_param) or []) + ] + column_names.update(_columns) + + filtered_metrics = [ + metric + for metric in data["metrics"] + if metric["metric_name"] in metric_names + ] + + filtered_columns: list[Column] = [] + column_types: set[GenericDataType] = set() + for column_ in data["columns"]: + generic_type = column_.get("type_generic") + if generic_type is not None: + column_types.add(generic_type) + if column_["column_name"] in column_names: + filtered_columns.append(column_) + + data["column_types"] = list(column_types) + del data["description"] + data.update({"metrics": filtered_metrics}) + data.update({"columns": filtered_columns}) + verbose_map = {"__timestamp": "Time"} + verbose_map.update( + { + metric["metric_name"]: metric["verbose_name"] or metric["metric_name"] + for metric in filtered_metrics + } + ) + verbose_map.update( + { + column_["column_name"]: column_["verbose_name"] + or column_["column_name"] + for column_ in filtered_columns + } + ) + data["verbose_map"] = verbose_map + + return data + + @staticmethod + def filter_values_handler( # pylint: disable=too-many-arguments + values: FilterValues | None, + operator: str, + target_generic_type: GenericDataType, + target_native_type: str | None = None, + is_list_target: bool = False, + db_engine_spec: builtins.type[BaseEngineSpec] | None = None, + db_extra: dict[str, Any] | None = None, + ) -> FilterValues | None: + if values is None: + return None + + def handle_single_value(value: FilterValue | None) -> FilterValue | None: + if operator == utils.FilterOperator.TEMPORAL_RANGE: + return value + if ( + isinstance(value, (float, int)) + and target_generic_type == utils.GenericDataType.TEMPORAL + and target_native_type is not None + and db_engine_spec is not None + ): + value = db_engine_spec.convert_dttm( + target_type=target_native_type, + dttm=datetime.utcfromtimestamp(value / 1000), + db_extra=db_extra, + ) + value = literal_column(value) + if isinstance(value, str): + value = value.strip("\t\n") + + if ( + target_generic_type == utils.GenericDataType.NUMERIC + and operator + not in { + utils.FilterOperator.ILIKE, + utils.FilterOperator.LIKE, + } + ): + # For backwards compatibility and edge cases + # where a column data type might have changed + return utils.cast_to_num(value) + if value == NULL_STRING: + return None + if value == EMPTY_STRING: + return "" + if target_generic_type == utils.GenericDataType.BOOLEAN: + return utils.cast_to_boolean(value) + return value + + if isinstance(values, (list, tuple)): + values = [handle_single_value(v) for v in values] # type: ignore + else: + values = handle_single_value(values) + if is_list_target and not isinstance(values, (tuple, list)): + values = [values] # type: ignore + elif not is_list_target and isinstance(values, (tuple, list)): + values = values[0] if values else None + return values + + def external_metadata(self) -> list[ResultSetColumnType]: + """Returns column information from the external system""" + raise NotImplementedError() + + def get_query_str(self, query_obj: QueryObjectDict) -> str: + """Returns a query as a string + + This is used to be displayed to the user so that they can + understand what is taking place behind the scene""" + raise NotImplementedError() + + def query(self, query_obj: QueryObjectDict) -> QueryResult: + """Executes the query and returns a dataframe + + query_obj is a dictionary representing Superset's query interface. + Should return a ``superset.models.helpers.QueryResult`` + """ + raise NotImplementedError() + + @staticmethod + def default_query(qry: Query) -> Query: + return qry + + def get_column(self, column_name: str | None) -> TableColumn | None: + if not column_name: + return None + for col in self.columns: + if col.column_name == column_name: + return col + return None + + @staticmethod + def get_fk_many_from_list( + object_list: list[Any], + fkmany: list[Column], + fkmany_class: builtins.type[TableColumn | SqlMetric], + key_attr: str, + ) -> list[Column]: + """Update ORM one-to-many list from object list + + Used for syncing metrics and columns using the same code""" + + object_dict = {o.get(key_attr): o for o in object_list} + + # delete fks that have been removed + fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict] + + # sync existing fks + for fk in fkmany: + obj = object_dict.get(getattr(fk, key_attr)) + if obj: + for attr in fkmany_class.update_from_object_fields: + setattr(fk, attr, obj.get(attr)) + + # create new fks + new_fks = [] + orm_keys = [getattr(o, key_attr) for o in fkmany] + for obj in object_list: + key = obj.get(key_attr) + if key not in orm_keys: + del obj["id"] + orm_kwargs = {} + for k in obj: + if k in fkmany_class.update_from_object_fields and k in obj: + orm_kwargs[k] = obj[k] + new_obj = fkmany_class(**orm_kwargs) + new_fks.append(new_obj) + fkmany += new_fks + return fkmany + + def update_from_object(self, obj: dict[str, Any]) -> None: + """Update datasource from a data structure + + The UI's table editor crafts a complex data structure that + contains most of the datasource's properties as well as + an array of metrics and columns objects. This method + receives the object from the UI and syncs the datasource to + match it. Since the fields are different for the different + connectors, the implementation uses ``update_from_object_fields`` + which can be defined for each connector and + defines which fields should be synced""" + for attr in self.update_from_object_fields: + setattr(self, attr, obj.get(attr)) + + self.owners = obj.get("owners", []) + + # Syncing metrics + metrics = ( + self.get_fk_many_from_list( + obj["metrics"], self.metrics, SqlMetric, "metric_name" + ) + if "metrics" in obj + else [] + ) + self.metrics = metrics + + # Syncing columns + self.columns = ( + self.get_fk_many_from_list( + obj["columns"], self.columns, TableColumn, "column_name" + ) + if "columns" in obj + else [] + ) + + def get_extra_cache_keys( + self, query_obj: QueryObjectDict # pylint: disable=unused-argument + ) -> list[Hashable]: + """If a datasource needs to provide additional keys for calculation of + cache keys, those can be provided via this method + + :param query_obj: The dict representation of a query object + :return: list of keys + """ + return [] + + def __hash__(self) -> int: + return hash(self.uid) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BaseDatasource): + return NotImplemented + return self.uid == other.uid + + def raise_for_access(self) -> None: + """ + Raise an exception if the user cannot access the resource. + + :raises SupersetSecurityException: If the user cannot access the resource + """ + + security_manager.raise_for_access(datasource=self) + + @classmethod + def get_datasource_by_name( + cls, session: Session, datasource_name: str, schema: str, database_name: str + ) -> BaseDatasource | None: + raise NotImplementedError() + + class AnnotationDatasource(BaseDatasource): """Dummy object so we can query annotations using 'Viz' objects just like regular datasources. @@ -187,22 +757,33 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: raise NotImplementedError() -class TableColumn(Model, BaseColumn, CertificationMixin): +class TableColumn(Model, AuditMixinNullable, ImportExportMixin, CertificationMixin): """ORM object for table columns, each table can have multiple columns""" __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"),) + + id = Column(Integer, primary_key=True) + column_name = Column(String(255), nullable=False) + verbose_name = Column(String(1024)) + is_active = Column(Boolean, default=True) + type = Column(Text) + advanced_data_type = Column(String(255)) + groupby = Column(Boolean, default=True) + filterable = Column(Boolean, default=True) + description = Column(MediumText()) table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE")) - table: Mapped[SqlaTable] = relationship( - "SqlaTable", - back_populates="columns", - ) is_dttm = Column(Boolean, default=False) expression = Column(MediumText()) python_date_format = Column(String(255)) extra = Column(Text) + table: Mapped[SqlaTable] = relationship( + "SqlaTable", + back_populates="columns", + ) + export_fields = [ "table_id", "column_name", @@ -246,6 +827,9 @@ def init_on_load(self) -> None: self._database = None + def __repr__(self) -> str: + return str(self.column_name) + @property def is_boolean(self) -> bool: """ @@ -284,7 +868,7 @@ def database(self) -> Database: return self.table.database if self.table else self._database @property - def db_engine_spec(self) -> type[BaseEngineSpec]: + def db_engine_spec(self) -> builtins.type[BaseEngineSpec]: return self.database.db_engine_spec @property @@ -366,44 +950,50 @@ def get_timestamp_expression( @property def data(self) -> dict[str, Any]: attrs = ( - "id", + "advanced_data_type", + "certification_details", + "certified_by", "column_name", - "verbose_name", "description", "expression", "filterable", "groupby", + "id", + "is_certified", "is_dttm", + "python_date_format", "type", "type_generic", - "advanced_data_type", - "python_date_format", - "is_certified", - "certified_by", - "certification_details", + "verbose_name", "warning_markdown", ) - attr_dict = {s: getattr(self, s) for s in attrs if hasattr(self, s)} - - attr_dict.update(super().data) + return {s: getattr(self, s) for s in attrs if hasattr(self, s)} - return attr_dict - -class SqlMetric(Model, BaseMetric, CertificationMixin): +class SqlMetric(Model, AuditMixinNullable, ImportExportMixin, CertificationMixin): """ORM object for metrics, each table can have multiple metrics""" __tablename__ = "sql_metrics" __table_args__ = (UniqueConstraint("table_id", "metric_name"),) + + id = Column(Integer, primary_key=True) + metric_name = Column(String(255), nullable=False) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + description = Column(MediumText()) + d3format = Column(String(128)) + currency = Column(String(128)) + warning_text = Column(Text) table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE")) + expression = Column(MediumText(), nullable=False) + extra = Column(Text) + table: Mapped[SqlaTable] = relationship( "SqlaTable", back_populates="metrics", ) - expression = Column(MediumText(), nullable=False) - extra = Column(Text) export_fields = [ "metric_name", @@ -449,18 +1039,34 @@ def perm(self) -> str | None: def get_perm(self) -> str | None: return self.perm + @property + def currency_json(self) -> dict[str, str | None] | None: + try: + return json.loads(self.currency or "{}") or None + except (TypeError, JSONDecodeError) as exc: + logger.error( + "Unable to load currency json: %r. Leaving empty.", exc, exc_info=True + ) + return None + @property def data(self) -> dict[str, Any]: attrs = ( - "is_certified", - "certified_by", "certification_details", + "certified_by", + "currency", + "d3format", + "description", + "expression", + "id", + "is_certified", + "metric_name", "warning_markdown", + "warning_text", + "verbose_name", ) - attr_dict = {s: getattr(self, s) for s in attrs} - attr_dict.update(super().data) - return attr_dict + return {s: getattr(self, s) for s in attrs} sqlatable_user = Table( diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 1ba10f18b216a..36eebcb3f7e16 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -28,7 +28,6 @@ from wtforms.validators import DataRequired, Regexp from superset import db -from superset.connectors.base.views import DatasourceModelView from superset.connectors.sqla import models from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.superset_typing import FlaskResponse @@ -282,7 +281,7 @@ def list(self) -> FlaskResponse: class TableModelView( # pylint: disable=too-many-ancestors - DatasourceModelView, DeleteMixin, YamlExportMixin + SupersetModelView, DeleteMixin, YamlExportMixin ): datamodel = SQLAInterface(models.SqlaTable) class_permission_name = "Dataset" diff --git a/superset/daos/chart.py b/superset/daos/chart.py index 7eae38cb0ecad..eb8b3e809e492 100644 --- a/superset/daos/chart.py +++ b/superset/daos/chart.py @@ -28,7 +28,7 @@ from superset.utils.core import get_user_id if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource logger = logging.getLogger(__name__) diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index a34d9be1acafb..06472630853c5 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -26,8 +26,12 @@ from superset import db from superset.commands.base import BaseCommand from superset.commands.importers.exceptions import IncorrectVersionError -from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric -from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.connectors.sqla.models import ( + BaseDatasource, + SqlaTable, + SqlMetric, + TableColumn, +) from superset.databases.commands.exceptions import DatabaseNotFoundError from superset.datasets.commands.exceptions import DatasetInvalidError from superset.models.core import Database @@ -102,14 +106,8 @@ def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric: ) -def import_metric(session: Session, metric: BaseMetric) -> BaseMetric: - if isinstance(metric, SqlMetric): - lookup_metric = lookup_sqla_metric - else: - raise Exception( # pylint: disable=broad-exception-raised - f"Invalid metric type: {metric}" - ) - return import_simple_obj(session, metric, lookup_metric) +def import_metric(session: Session, metric: SqlMetric) -> SqlMetric: + return import_simple_obj(session, metric, lookup_sqla_metric) def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn: @@ -123,14 +121,8 @@ def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn: ) -def import_column(session: Session, column: BaseColumn) -> BaseColumn: - if isinstance(column, TableColumn): - lookup_column = lookup_sqla_column - else: - raise Exception( # pylint: disable=broad-exception-raised - f"Invalid column type: {column}" - ) - return import_simple_obj(session, column, lookup_column) +def import_column(session: Session, column: TableColumn) -> TableColumn: + return import_simple_obj(session, column, lookup_sqla_column) def import_datasource( # pylint: disable=too-many-arguments diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 31d956f5fde82..78dca8e2b2983 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -24,14 +24,8 @@ import superset.utils.database from superset import app, db -from superset.connectors.sqla.models import SqlMetric -from superset.models.dashboard import Dashboard -from superset.models.slice import Slice -from superset.utils import core as utils -from superset.utils.core import DatasourceType - -from ..connectors.base.models import BaseDatasource -from .helpers import ( +from superset.connectors.sqla.models import BaseDatasource, SqlMetric +from superset.examples.helpers import ( get_example_url, get_examples_folder, get_slice_json, @@ -40,6 +34,10 @@ misc_dash_slices, update_slice_ids, ) +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.utils import core as utils +from superset.utils.core import DatasourceType def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-statements diff --git a/superset/explore/commands/get.py b/superset/explore/commands/get.py index d348b16251b97..1994e7ad43876 100644 --- a/superset/explore/commands/get.py +++ b/superset/explore/commands/get.py @@ -26,8 +26,7 @@ from superset import db from superset.commands.base import BaseCommand -from superset.connectors.base.models import BaseDatasource -from superset.connectors.sqla.models import SqlaTable +from superset.connectors.sqla.models import BaseDatasource, SqlaTable from superset.daos.datasource import DatasourceDAO from superset.daos.exceptions import DatasourceNotFound from superset.exceptions import SupersetException diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 18c94aa179cee..919c832ab5b4f 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -47,8 +47,12 @@ from sqlalchemy.sql.elements import BinaryExpression from superset import app, db, is_feature_enabled, security_manager -from superset.connectors.base.models import BaseDatasource -from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.connectors.sqla.models import ( + BaseDatasource, + SqlaTable, + SqlMetric, + TableColumn, +) from superset.daos.datasource import DatasourceDAO from superset.extensions import cache_manager from superset.models.filter_set import FilterSet diff --git a/superset/models/slice.py b/superset/models/slice.py index 248f4ee947e7d..b41bb72a85496 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -51,7 +51,7 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.common.query_context_factory import QueryContextFactory - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource metadata = Model.metadata # pylint: disable=no-member slice_user = Table( diff --git a/superset/security/manager.py b/superset/security/manager.py index c8d2c236ab95f..d2f144f6fe841 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -78,8 +78,11 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext - from superset.connectors.base.models import BaseDatasource - from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable + from superset.connectors.sqla.models import ( + BaseDatasource, + RowLevelSecurityFilter, + SqlaTable, + ) from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.sql_lab import Query diff --git a/superset/utils/core.py b/superset/utils/core.py index 67edabe626d6f..b9c24076a4e12 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -105,7 +105,7 @@ from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str if TYPE_CHECKING: - from superset.connectors.base.models import BaseColumn, BaseDatasource + from superset.connectors.sqla.models import BaseDatasource, TableColumn from superset.models.sql_lab import Query logging.getLogger("MARKDOWN").setLevel(logging.INFO) @@ -1628,7 +1628,7 @@ def extract_dataframe_dtypes( return generic_types -def extract_column_dtype(col: BaseColumn) -> GenericDataType: +def extract_column_dtype(col: TableColumn) -> GenericDataType: if col.is_temporal: return GenericDataType.TEMPORAL if col.is_numeric: diff --git a/superset/views/core.py b/superset/views/core.py index 2f9b99eba0e61..bb273eb53c9a7 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -47,8 +47,7 @@ from superset.charts.commands.exceptions import ChartNotFoundError from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType -from superset.connectors.base.models import BaseDatasource -from superset.connectors.sqla.models import SqlaTable +from superset.connectors.sqla.models import BaseDatasource, SqlaTable from superset.daos.chart import ChartDAO from superset.daos.datasource import DatasourceDAO from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand diff --git a/superset/viz.py b/superset/viz.py index 2e697a77becf8..8ba785ddcf39e 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -84,7 +84,7 @@ if TYPE_CHECKING: from superset.common.query_context_factory import QueryContextFactory - from superset.connectors.base.models import BaseDatasource + from superset.connectors.sqla.models import BaseDatasource config = app.config stats_logger = config["STATS_LOGGER"] diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 7f7c543d8b04a..0040ec60f68b3 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -36,8 +36,7 @@ from tests.integration_tests.test_app import app, login from superset.sql_parse import CtasMethod from superset import db, security_manager -from superset.connectors.base.models import BaseDatasource -from superset.connectors.sqla.models import SqlaTable +from superset.connectors.sqla.models import BaseDatasource, SqlaTable from superset.models import core as models from superset.models.slice import Slice from superset.models.core import Database diff --git a/tests/unit_tests/common/test_get_aggregated_join_column.py b/tests/unit_tests/common/test_get_aggregated_join_column.py index 8effacf2494cb..de0b6b92b2850 100644 --- a/tests/unit_tests/common/test_get_aggregated_join_column.py +++ b/tests/unit_tests/common/test_get_aggregated_join_column.py @@ -24,7 +24,7 @@ AGGREGATED_JOIN_COLUMN, QueryContextProcessor, ) -from superset.connectors.base.models import BaseDatasource +from superset.connectors.sqla.models import BaseDatasource from superset.constants import TimeGrain query_context_processor = QueryContextProcessor(