Skip to content

Commit

Permalink
big reafctor
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh authored Jun 5, 2022
1 parent 5dff58f commit 95c5eb2
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 113 deletions.
6 changes: 2 additions & 4 deletions superset/cachekeys/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
from sqlalchemy.exc import SQLAlchemyError

from superset.cachekeys.schemas import CacheInvalidationRequestSchema
from superset.datasource.dao import DatasourceDAO
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import cache_manager, db, event_logger
from superset.models.cache import CacheKey
from superset.utils.core import DatasourceType
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,9 +83,8 @@ def invalidate(self) -> Response:
return self.response_400(message=str(error))
datasource_uids = set(datasources.get("datasource_uids", []))
for ds in datasources.get("datasources", []):
ds_obj = DatasourceDAO.get_datasource_by_name(
ds_obj = SqlaTable.get_datasource_by_name(
session=db.session,
datasource_type=DatasourceType(ds.get("datasource_type")),
datasource_name=ds.get("datasource_name"),
schema=ds.get("schema"),
database_name=ds.get("database_name"),
Expand Down
43 changes: 43 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
List,
NamedTuple,
Optional,
Set,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -1988,6 +1989,48 @@ def query_datasources_by_name(
query = query.filter_by(schema=schema)
return query.all()

@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: Database,
permissions: Set[str],
schema_perms: Set[str],
) -> List["SqlaTable"]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
.filter_by(database_id=database.id)
.filter(
or_(
SqlaTable.perm.in_(permissions),
SqlaTable.schema_perm.in_(schema_perms),
)
)
.all()
)

@classmethod
def get_eager_sqlatable_datasource(
cls, session: Session, datasource_id: int
) -> "SqlaTable":
"""Returns SqlaTable with columns and metrics."""
return (
session.query(cls)
.options(
sa.subqueryload(cls.columns),
sa.subqueryload(cls.metrics),
)
.filter_by(id=datasource_id)
.one()
)

@classmethod
def get_all_sqlatables_datasources(cls, session: Session) -> List["SqlaTable"]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()

@staticmethod
def default_query(qry: Query) -> Query:
return qry.filter_by(is_sqllab_view=False)
Expand Down
11 changes: 5 additions & 6 deletions superset/dashboards/commands/importers/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ def import_chart(
slc_to_import = slc_to_import.copy()
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
datasource = DatasourceDAO.get_datasource_by_name(
session,
DatasourceType(slc_to_import.datasource_type),
params["datasource_name"],
params["database_name"],
params["schema"],
datasource = SqlaTable.get_datasource_by_name(
session=session,
datasource_name=params["datasource_name"],
database_name=params["database_name"],
schema=params["schema"],
)
slc_to_import.datasource_id = datasource.id # type: ignore
if slc_to_override:
Expand Down
77 changes: 0 additions & 77 deletions superset/datasource/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,80 +59,3 @@ def get_datasource(
raise DatasourceNotFound()

return datasource

@classmethod
def get_all_sqlatables_datasources(cls, session: Session) -> List[SqlaTable]:
source_class = DatasourceDAO.sources[DatasourceType.TABLE]
qry = session.query(source_class)
qry = source_class.default_query(qry)
return qry.all()

@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
session: Session,
datasource_type: DatasourceType,
datasource_name: str,
database_name: str,
schema: str,
) -> Optional[Datasource]:
datasource_class = DatasourceDAO.sources[datasource_type]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
)

@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: Database,
permissions: Set[str],
schema_perms: Set[str],
) -> List[Datasource]:
# TODO(hughhhh): add unit test
datasource_class = DatasourceDAO.sources[DatasourceType(database.type)]
if not isinstance(datasource_class, SqlaTable):
return []

return (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
datasource_class.perm.in_(permissions),
datasource_class.schema_perm.in_(schema_perms),
)
)
.all()
)

@classmethod
def get_eager_sqlatable_datasource(
cls, session: Session, datasource_id: int
) -> SqlaTable:
"""Returns SqlaTable with columns and metrics."""
return (
session.query(SqlaTable)
.options(
subqueryload(SqlaTable.columns),
subqueryload(SqlaTable.metrics),
)
.filter_by(id=datasource_id)
.one()
)

@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
) -> List[Datasource]:
datasource_class = DatasourceDAO.sources[DatasourceType(database.type)]
if not isinstance(datasource_class, SqlaTable):
return []

return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
)
2 changes: 1 addition & 1 deletion superset/explore/form_data/commands/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing_extensions import TypedDict

from superset.core.utils import DatasourceType
from superset.utils.core import DatasourceType


class TemporaryExploreState(TypedDict):
Expand Down
4 changes: 2 additions & 2 deletions superset/models/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from superset import app, db, is_feature_enabled, security_manager
from superset.common.request_contexed_based import is_user_admin
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.datasource.dao import DatasourceDAO
from superset.extensions import cache_manager
from superset.models.filter_set import FilterSet
Expand Down Expand Up @@ -418,7 +418,7 @@ def export_dashboards( # pylint: disable=too-many-locals

eager_datasources = []
for datasource_id, _ in datasource_ids:
eager_datasource = DatasourceDAO.get_eager_sqlatable_datasource(
eager_datasource = SqlaTable.get_eager_sqlatable_datasource(
db.session, datasource_id
)
copied_datasource = eager_datasource.copy()
Expand Down
23 changes: 8 additions & 15 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,9 @@ def get_user_datasources(self) -> List["BaseDatasource"]:
)

# group all datasources by database
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO

all_datasources = DatasourceDAO.get_all_sqlatables_datasources(self.get_session)
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
set
)
session = self.get_session
all_datasources = SqlaTable.get_all_sqlatables_datasources(session)
datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)

Expand Down Expand Up @@ -604,6 +600,8 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name
:param schema: The fallback SQL schema if not present in the table name
:returns: The list of accessible SQL tables w/ schema
"""
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable

if self.can_access_database(database):
return datasource_names
Expand All @@ -615,10 +613,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name

user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO

user_datasources = DatasourceDAO.query_datasources_by_permissions(
user_datasources = SqlaTable.query_datasources_by_permissions(
self.get_session, database, user_perms, schema_perms
)
if schema:
Expand Down Expand Up @@ -668,6 +663,7 @@ def create_missing_perms(self) -> None:
"""

# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models

logger.info("Fetching a set of all perms to lookup which ones are missing")
Expand All @@ -682,10 +678,7 @@ def merge_pv(view_menu: str, perm: Optional[str]) -> None:
self.add_permission_view_menu(view_menu, perm)

logger.info("Creating missing datasource permissions.")
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO

datasources = DatasourceDAO.get_all_sqlatables_datasources(self.get_session)
datasources = SqlaTable.get_all_sqlatables_datasources(self.get_session)
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def override_role_permissions(self) -> FlaskResponse:
)
db_ds_names.add(fullname)

existing_datasources = DatasourceDAO.get_all_sqlatables_datasources(db.session)
existing_datasources = SqlaTable.get_all_sqlatables_datasources(db.session)
datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
role = security_manager.find_role(role_name)
# remove all permissions
Expand Down
4 changes: 2 additions & 2 deletions superset/views/datasource/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from superset import db, event_logger
from superset.commands.utils import populate_owners
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.utils import get_physical_table_metadata
from superset.datasets.commands.exceptions import (
DatasetForbiddenError,
Expand Down Expand Up @@ -156,9 +157,8 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse:
except ValidationError as err:
return json_error_response(str(err), status=400)

datasource = DatasourceDAO.get_datasource_by_name(
datasource = SqlaTable.get_datasource_by_name(
session=db.session,
datasource_type=DatasourceType(params["datasource_type"]),
database_name=params["database_name"],
schema=params["schema_name"],
datasource_name=params["table_name"],
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def test_get_user_datasources_admin(
mock_get_session.query.return_value.filter.return_value.all.return_value = []

with mock.patch.object(
DatasourceDAO, "get_all_sqlatables_datasources"
SqlaTable, "get_all_sqlatables_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def test_get_user_datasources_gamma(
mock_get_session.query.return_value.filter.return_value.all.return_value = []

with mock.patch.object(
DatasourceDAO, "get_all_sqlatables_datasources"
SqlaTable, "get_all_sqlatables_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Expand Down Expand Up @@ -1047,7 +1047,7 @@ def test_get_user_datasources_gamma_with_schema(
]

with mock.patch.object(
DatasourceDAO, "get_all_sqlatables_datasources"
SqlaTable, "get_all_sqlatables_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/datasource/dao_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_get_datasource_sl_dataset(
def test_get_all_sqlatables_datasources(
app_context: None, session_with_data: Session
) -> None:
from superset.datasource.dao import DatasourceDAO
from superset.connectors.sqla.models import SqlaTable

result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data)
result = SqlaTable.get_all_sqlatables_datasources(session=session_with_data)
assert len(result) == 1

0 comments on commit 95c5eb2

Please sign in to comment.