diff --git a/notebooks/toolkit.py b/notebooks/toolkit.py index b6926b3b6b..8b44f28d99 100644 --- a/notebooks/toolkit.py +++ b/notebooks/toolkit.py @@ -17,8 +17,8 @@ MigrationConfig, TaclConfig, ) -from databricks.labs.ucx.toolkits.group_migration import GroupMigrationToolkit -from databricks.labs.ucx.toolkits.table_acls import TaclToolkit +from databricks.labs.ucx.workspace_access import GroupMigrationToolkit +from databricks.labs.ucx.hive_metastore import TaclToolkit # COMMAND ---------- diff --git a/pyproject.toml b/pyproject.toml index 2958a88ba1..0536b51f3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,14 +163,14 @@ ban-relative-imports = "all" "ARG001" # tests may not use the provided fixtures ] -"src/databricks/labs/ucx/providers/mixins/redash.py" = ["A002", "A003", "N815"] +"src/databricks/labs/ucx/mixins/redash.py" = ["A002", "A003", "N815"] [tool.coverage.run] branch = true parallel = true [tool.coverage.report] -omit = ["src/databricks/labs/ucx/providers/mixins/*"] +omit = ["src/databricks/labs/ucx/mixins/*"] exclude_lines = [ "no cov", "if __name__ == .__main__.:", diff --git a/src/databricks/labs/ucx/__init__.py b/src/databricks/labs/ucx/__init__.py index 98b7d21354..613ee9d172 100644 --- a/src/databricks/labs/ucx/__init__.py +++ b/src/databricks/labs/ucx/__init__.py @@ -1,3 +1,3 @@ -from databricks.labs.ucx.logger import _install +from databricks.labs.ucx.framework.logger import _install _install() diff --git a/src/databricks/labs/ucx/toolkits/assessment.py b/src/databricks/labs/ucx/assessment/assessment.py similarity index 97% rename from src/databricks/labs/ucx/toolkits/assessment.py rename to src/databricks/labs/ucx/assessment/assessment.py index de1b5c99ca..2a14793a52 100644 --- a/src/databricks/labs/ucx/toolkits/assessment.py +++ b/src/databricks/labs/ucx/assessment/assessment.py @@ -6,7 +6,7 @@ from databricks.sdk.service.compute import Language from databricks.labs.ucx.assessment import commands -from databricks.labs.ucx.providers.mixins.compute import CommandExecutor +from databricks.labs.ucx.mixins.compute import CommandExecutor logger = logging.getLogger(__name__) diff --git a/src/databricks/labs/ucx/managers/__init__.py b/src/databricks/labs/ucx/framework/__init__.py similarity index 100% rename from src/databricks/labs/ucx/managers/__init__.py rename to src/databricks/labs/ucx/framework/__init__.py diff --git a/src/databricks/labs/ucx/tacl/_internal.py b/src/databricks/labs/ucx/framework/crawlers.py similarity index 99% rename from src/databricks/labs/ucx/tacl/_internal.py rename to src/databricks/labs/ucx/framework/crawlers.py index 6fb95ae799..3ea69c2286 100644 --- a/src/databricks/labs/ucx/tacl/_internal.py +++ b/src/databricks/labs/ucx/framework/crawlers.py @@ -6,7 +6,7 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.ucx.providers.mixins.sql import StatementExecutionExt +from databricks.labs.ucx.mixins.sql import StatementExecutionExt logger = logging.getLogger(__name__) diff --git a/src/databricks/labs/ucx/logger.py b/src/databricks/labs/ucx/framework/logger.py similarity index 100% rename from src/databricks/labs/ucx/logger.py rename to src/databricks/labs/ucx/framework/logger.py diff --git a/src/databricks/labs/ucx/utils.py b/src/databricks/labs/ucx/framework/parallel.py similarity index 86% rename from src/databricks/labs/ucx/utils.py rename to src/databricks/labs/ucx/framework/parallel.py index 030411d353..8b3c99bcee 100644 --- a/src/databricks/labs/ucx/utils.py +++ b/src/databricks/labs/ucx/framework/parallel.py @@ -5,8 +5,6 @@ from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor from typing import Generic, TypeVar -from databricks.labs.ucx.generic import StrEnum - ExecutableResult = TypeVar("ExecutableResult") ExecutableFunction = Callable[..., ExecutableResult] logger = logging.getLogger(__name__) @@ -66,20 +64,5 @@ def run(self) -> list[ExecutableResult]: return collected -class Request: - def __init__(self, req: dict): - self.request = req - - def as_dict(self) -> dict: - return self.request - - -class WorkspaceLevelEntitlement(StrEnum): - WORKSPACE_ACCESS = "workspace-access" - DATABRICKS_SQL_ACCESS = "databricks-sql-access" - ALLOW_CLUSTER_CREATE = "allow-cluster-create" - ALLOW_INSTANCE_POOL_CREATE = "allow-instance-pool-create" - - def noop(): pass diff --git a/src/databricks/labs/ucx/tasks.py b/src/databricks/labs/ucx/framework/tasks.py similarity index 97% rename from src/databricks/labs/ucx/tasks.py rename to src/databricks/labs/ucx/framework/tasks.py index 56a0d2f45e..d3adbbad57 100644 --- a/src/databricks/labs/ucx/tasks.py +++ b/src/databricks/labs/ucx/framework/tasks.py @@ -5,7 +5,7 @@ from pathlib import Path from databricks.labs.ucx.config import MigrationConfig -from databricks.labs.ucx.logger import _install +from databricks.labs.ucx.framework.logger import _install _TASKS: dict[str, "Task"] = {} diff --git a/src/databricks/labs/ucx/generic.py b/src/databricks/labs/ucx/generic.py deleted file mode 100644 index c2b3f3a095..0000000000 --- a/src/databricks/labs/ucx/generic.py +++ /dev/null @@ -1,15 +0,0 @@ -import enum - - -class StrEnum(str, enum.Enum): # re-exported for compatability with older python versions - def __new__(cls, value, *args, **kwargs): - if not isinstance(value, str | enum.auto): - msg = f"Values of StrEnums must be strings: {value!r} is a {type(value)}" - raise TypeError(msg) - return super().__new__(cls, value, *args, **kwargs) - - def __str__(self): - return str(self.value) - - def _generate_next_value_(name, *_): # noqa: N805 - return name diff --git a/src/databricks/labs/ucx/hive_metastore/__init__.py b/src/databricks/labs/ucx/hive_metastore/__init__.py new file mode 100644 index 0000000000..7f2a4dcdd2 --- /dev/null +++ b/src/databricks/labs/ucx/hive_metastore/__init__.py @@ -0,0 +1,3 @@ +from databricks.labs.ucx.hive_metastore.table_acls import TaclToolkit + +__all__ = ["TaclToolkit"] diff --git a/src/databricks/labs/ucx/tacl/grants.py b/src/databricks/labs/ucx/hive_metastore/grants.py similarity index 98% rename from src/databricks/labs/ucx/tacl/grants.py rename to src/databricks/labs/ucx/hive_metastore/grants.py index 8d8639bcf3..dcacb62b3d 100644 --- a/src/databricks/labs/ucx/tacl/grants.py +++ b/src/databricks/labs/ucx/hive_metastore/grants.py @@ -2,9 +2,9 @@ from dataclasses import dataclass from functools import partial -from databricks.labs.ucx.tacl._internal import CrawlerBase -from databricks.labs.ucx.tacl.tables import TablesCrawler -from databricks.labs.ucx.utils import ThreadedExecution +from databricks.labs.ucx.framework.crawlers import CrawlerBase +from databricks.labs.ucx.framework.parallel import ThreadedExecution +from databricks.labs.ucx.hive_metastore.tables import TablesCrawler @dataclass(frozen=True) diff --git a/src/databricks/labs/ucx/toolkits/table_acls.py b/src/databricks/labs/ucx/hive_metastore/table_acls.py similarity index 87% rename from src/databricks/labs/ucx/toolkits/table_acls.py rename to src/databricks/labs/ucx/hive_metastore/table_acls.py index 097c17313a..37cd88eb81 100644 --- a/src/databricks/labs/ucx/toolkits/table_acls.py +++ b/src/databricks/labs/ucx/hive_metastore/table_acls.py @@ -2,13 +2,13 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.ucx.tacl._internal import ( +from databricks.labs.ucx.framework.crawlers import ( RuntimeBackend, SqlBackend, StatementExecutionBackend, ) -from databricks.labs.ucx.tacl.grants import GrantsCrawler -from databricks.labs.ucx.tacl.tables import TablesCrawler +from databricks.labs.ucx.hive_metastore.grants import GrantsCrawler +from databricks.labs.ucx.hive_metastore.tables import TablesCrawler logger = logging.getLogger(__name__) diff --git a/src/databricks/labs/ucx/tacl/tables.py b/src/databricks/labs/ucx/hive_metastore/tables.py similarity index 94% rename from src/databricks/labs/ucx/tacl/tables.py rename to src/databricks/labs/ucx/hive_metastore/tables.py index d7ac40929c..76c98240dd 100644 --- a/src/databricks/labs/ucx/tacl/tables.py +++ b/src/databricks/labs/ucx/hive_metastore/tables.py @@ -3,9 +3,9 @@ from dataclasses import dataclass from functools import partial -from databricks.labs.ucx.providers.mixins.sql import Row -from databricks.labs.ucx.tacl._internal import CrawlerBase, SqlBackend -from databricks.labs.ucx.utils import ThreadedExecution +from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend +from databricks.labs.ucx.framework.parallel import ThreadedExecution +from databricks.labs.ucx.mixins.sql import Row logger = logging.getLogger(__name__) @@ -75,8 +75,7 @@ def __init__(self, backend: SqlBackend, catalog, schema): Initializes a TablesCrawler instance. Args: - ws (WorkspaceClient): The WorkspaceClient instance. - warehouse_id: The warehouse ID. + backend (SqlBackend): The SQL Execution Backend abstraction (either REST API or Spark) catalog (str): The catalog name for the inventory persistence. schema: The schema name for the inventory persistence. """ diff --git a/src/databricks/labs/ucx/install.py b/src/databricks/labs/ucx/install.py index 8d6ac54f91..70f3f32db1 100644 --- a/src/databricks/labs/ucx/install.py +++ b/src/databricks/labs/ucx/install.py @@ -16,8 +16,8 @@ from databricks.labs.ucx.__about__ import __version__ from databricks.labs.ucx.config import GroupsConfig, MigrationConfig, TaclConfig +from databricks.labs.ucx.framework.tasks import _TASKS from databricks.labs.ucx.runtime import main -from databricks.labs.ucx.tasks import _TASKS TAG_STEP = "step" TAG_APP = "App" diff --git a/src/databricks/labs/ucx/inventory/__init__.py b/src/databricks/labs/ucx/inventory/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/src/databricks/labs/ucx/inventory/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/databricks/labs/ucx/inventory/permissions.py b/src/databricks/labs/ucx/inventory/permissions.py deleted file mode 100644 index 257edd24e5..0000000000 --- a/src/databricks/labs/ucx/inventory/permissions.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging -from itertools import groupby -from typing import Literal - -from databricks.sdk import WorkspaceClient - -from databricks.labs.ucx.inventory.permissions_inventory import ( - PermissionsInventoryTable, -) -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.support.impl import SupportsProvider -from databricks.labs.ucx.utils import ThreadedExecution - -logger = logging.getLogger(__name__) - - -class PermissionManager: - def __init__( - self, ws: WorkspaceClient, permissions_inventory: PermissionsInventoryTable, supports_provider: SupportsProvider - ): - self._ws = ws - self._permissions_inventory = permissions_inventory - self._supports_provider = supports_provider - - def inventorize_permissions(self): - logger.info("Inventorizing the permissions") - crawler_tasks = list(self._supports_provider.get_crawler_tasks()) - logger.info(f"Total crawler tasks: {len(crawler_tasks)}") - logger.info("Starting the permissions inventorization") - results = ThreadedExecution.gather("crawl permissions", crawler_tasks) - items = [item for item in results if item is not None] - logger.info(f"Total inventorized items: {len(items)}") - self._permissions_inventory.save(items) - logger.info("Permissions were inventorized and saved") - - def apply_group_permissions(self, migration_state: GroupMigrationState, destination: Literal["backup", "account"]): - logger.info(f"Applying the permissions to {destination} groups") - logger.info(f"Total groups to apply permissions: {len(migration_state.groups)}") - # list shall be sorted prior to using group by - items = sorted(self._permissions_inventory.load_all(), key=lambda i: i.support) - logger.info(f"Total inventorized items: {len(items)}") - applier_tasks = [] - supports_to_items = { - support: list(items_subset) for support, items_subset in groupby(items, key=lambda i: i.support) - } - - # we first check that all supports are valid. - for support in supports_to_items: - if support not in self._supports_provider.supports: - msg = f"Could not find support for {support}. Please check the inventory table." - raise ValueError(msg) - - for support, items_subset in supports_to_items.items(): - relevant_support = self._supports_provider.supports[support] - tasks_for_support = [ - relevant_support.get_apply_task(item, migration_state, destination) for item in items_subset - ] - logger.info(f"Total tasks for {support}: {len(tasks_for_support)}") - applier_tasks.extend(tasks_for_support) - - logger.info(f"Total applier tasks: {len(applier_tasks)}") - logger.info("Starting the permissions application") - ThreadedExecution.gather("apply permissions", applier_tasks) - logger.info("Permissions were applied") diff --git a/src/databricks/labs/ucx/inventory/permissions_inventory.py b/src/databricks/labs/ucx/inventory/permissions_inventory.py deleted file mode 100644 index 52ff447595..0000000000 --- a/src/databricks/labs/ucx/inventory/permissions_inventory.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging - -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.tacl._internal import CrawlerBase, SqlBackend - -logger = logging.getLogger(__name__) - - -class PermissionsInventoryTable(CrawlerBase): - def __init__(self, backend: SqlBackend, inventory_database: str): - super().__init__(backend, "hive_metastore", inventory_database, "permissions") - - def cleanup(self): - logger.info(f"Cleaning up inventory table {self._full_name}") - self._exec(f"DROP TABLE IF EXISTS {self._full_name}") - logger.info("Inventory table cleanup complete") - - def save(self, items: list[PermissionsInventoryItem]): - # TODO: update instead of append - logger.info(f"Saving {len(items)} items to inventory table {self._full_name}") - self._append_records(PermissionsInventoryItem, items) - logger.info("Successfully saved the items to inventory table") - - def load_all(self) -> list[PermissionsInventoryItem]: - logger.info(f"Loading inventory table {self._full_name}") - return [ - PermissionsInventoryItem(object_id, support, raw_object_permissions) - for object_id, support, raw_object_permissions in self._fetch( - f"SELECT object_id, support, raw_object_permissions FROM {self._full_name}" - ) - ] diff --git a/src/databricks/labs/ucx/inventory/types.py b/src/databricks/labs/ucx/inventory/types.py deleted file mode 100644 index 276db209c3..0000000000 --- a/src/databricks/labs/ucx/inventory/types.py +++ /dev/null @@ -1,32 +0,0 @@ -from dataclasses import dataclass -from typing import Literal - -from databricks.labs.ucx.generic import StrEnum - -Destination = Literal["backup", "account"] - - -class RequestObjectType(StrEnum): - AUTHORIZATION = "authorization" # tokens and passwords are here too! - CLUSTERS = "clusters" - CLUSTER_POLICIES = "cluster-policies" - DIRECTORIES = "directories" - EXPERIMENTS = "experiments" - FILES = "files" - INSTANCE_POOLS = "instance-pools" - JOBS = "jobs" - NOTEBOOKS = "notebooks" - PIPELINES = "pipelines" - REGISTERED_MODELS = "registered-models" - REPOS = "repos" - SQL_WAREHOUSES = "sql/warehouses" # / is not a typo, it's the real object type - - def __repr__(self): - return self.value - - -@dataclass -class PermissionsInventoryItem: - object_id: str - support: str # shall be taken from CRAWLERS dict - raw_object_permissions: str diff --git a/src/databricks/labs/ucx/providers/mixins/README.md b/src/databricks/labs/ucx/mixins/README.md similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/README.md rename to src/databricks/labs/ucx/mixins/README.md diff --git a/src/databricks/labs/ucx/providers/__init__.py b/src/databricks/labs/ucx/mixins/__init__.py similarity index 100% rename from src/databricks/labs/ucx/providers/__init__.py rename to src/databricks/labs/ucx/mixins/__init__.py diff --git a/src/databricks/labs/ucx/providers/mixins/compute.py b/src/databricks/labs/ucx/mixins/compute.py similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/compute.py rename to src/databricks/labs/ucx/mixins/compute.py diff --git a/src/databricks/labs/ucx/providers/mixins/fixtures.py b/src/databricks/labs/ucx/mixins/fixtures.py similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/fixtures.py rename to src/databricks/labs/ucx/mixins/fixtures.py diff --git a/src/databricks/labs/ucx/providers/mixins/hardening.py b/src/databricks/labs/ucx/mixins/hardening.py similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/hardening.py rename to src/databricks/labs/ucx/mixins/hardening.py diff --git a/src/databricks/labs/ucx/providers/mixins/redash.py b/src/databricks/labs/ucx/mixins/redash.py similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/redash.py rename to src/databricks/labs/ucx/mixins/redash.py diff --git a/src/databricks/labs/ucx/providers/mixins/sql.py b/src/databricks/labs/ucx/mixins/sql.py similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/sql.py rename to src/databricks/labs/ucx/mixins/sql.py diff --git a/src/databricks/labs/ucx/providers/groups_info.py b/src/databricks/labs/ucx/providers/groups_info.py deleted file mode 100644 index 5f3dbb048f..0000000000 --- a/src/databricks/labs/ucx/providers/groups_info.py +++ /dev/null @@ -1,33 +0,0 @@ -from dataclasses import dataclass - -from databricks.sdk.service.iam import Group - - -@dataclass -class MigrationGroupInfo: - workspace: Group - backup: Group - account: Group - - -class GroupMigrationState: - """Holds migration state of workspace-to-account groups""" - - def __init__(self): - self.groups: list[MigrationGroupInfo] = [] - - def add(self, group: MigrationGroupInfo): - self.groups.append(group) - - def is_in_scope(self, attr: str, group: Group) -> bool: - for info in self.groups: - if getattr(info, attr).id == group.id: - return True - return False - - def get_by_workspace_group_name(self, workspace_group_name: str) -> MigrationGroupInfo | None: - found = [g for g in self.groups if g.workspace.display_name == workspace_group_name] - if len(found) == 0: - return None - else: - return found[0] diff --git a/src/databricks/labs/ucx/runtime.py b/src/databricks/labs/ucx/runtime.py index af5f5c502e..022c06fd36 100644 --- a/src/databricks/labs/ucx/runtime.py +++ b/src/databricks/labs/ucx/runtime.py @@ -5,9 +5,9 @@ from databricks.sdk import WorkspaceClient from databricks.labs.ucx.config import MigrationConfig -from databricks.labs.ucx.tasks import task, trigger -from databricks.labs.ucx.toolkits.group_migration import GroupMigrationToolkit -from databricks.labs.ucx.toolkits.table_acls import TaclToolkit +from databricks.labs.ucx.framework.tasks import task, trigger +from databricks.labs.ucx.hive_metastore import TaclToolkit +from databricks.labs.ucx.workspace_access import GroupMigrationToolkit logger = logging.getLogger(__name__) diff --git a/src/databricks/labs/ucx/support/impl.py b/src/databricks/labs/ucx/support/impl.py deleted file mode 100644 index d82302f543..0000000000 --- a/src/databricks/labs/ucx/support/impl.py +++ /dev/null @@ -1,87 +0,0 @@ -from collections.abc import Callable, Iterator - -from databricks.sdk import WorkspaceClient -from databricks.sdk.service import sql - -from databricks.labs.ucx.inventory.types import ( - PermissionsInventoryItem, - RequestObjectType, -) -from databricks.labs.ucx.support.base import BaseSupport -from databricks.labs.ucx.support.group_level import ScimSupport -from databricks.labs.ucx.support.listing import ( - authorization_listing, - experiments_listing, - models_listing, - workspace_listing, -) -from databricks.labs.ucx.support.permissions import ( - GenericPermissionsSupport, - listing_wrapper, -) -from databricks.labs.ucx.support.secrets import SecretScopesSupport -from databricks.labs.ucx.support.sql import SqlPermissionsSupport -from databricks.labs.ucx.support.sql import listing_wrapper as sql_listing_wrapper - - -class SupportsProvider: - def __init__(self, ws: WorkspaceClient, num_threads: int, workspace_start_path: str): - self._generic_support = GenericPermissionsSupport( - ws=ws, - listings=[ - listing_wrapper(ws.clusters.list, "cluster_id", RequestObjectType.CLUSTERS), - listing_wrapper(ws.cluster_policies.list, "policy_id", RequestObjectType.CLUSTER_POLICIES), - listing_wrapper(ws.instance_pools.list, "instance_pool_id", RequestObjectType.INSTANCE_POOLS), - listing_wrapper(ws.warehouses.list, "id", RequestObjectType.SQL_WAREHOUSES), - listing_wrapper(ws.jobs.list, "job_id", RequestObjectType.JOBS), - listing_wrapper(ws.pipelines.list_pipelines, "pipeline_id", RequestObjectType.PIPELINES), - listing_wrapper(experiments_listing(ws), "experiment_id", RequestObjectType.EXPERIMENTS), - listing_wrapper(models_listing(ws), "id", RequestObjectType.REGISTERED_MODELS), - workspace_listing(ws, num_threads=num_threads, start_path=workspace_start_path), - authorization_listing(), - ], - ) - self._secrets_support = SecretScopesSupport(ws=ws) - self._scim_support = ScimSupport(ws) - self._sql_support = SqlPermissionsSupport( - ws, - listings=[ - sql_listing_wrapper(ws.alerts.list, sql.ObjectTypePlural.ALERTS), - sql_listing_wrapper(ws.dashboards.list, sql.ObjectTypePlural.DASHBOARDS), - sql_listing_wrapper(ws.queries.list, sql.ObjectTypePlural.QUERIES), - ], - ) - - def get_crawler_tasks(self) -> Iterator[Callable[..., PermissionsInventoryItem | None]]: - for support in [self._generic_support, self._secrets_support, self._scim_support, self._sql_support]: - yield from support.get_crawler_tasks() - - @property - def supports(self) -> dict[str, BaseSupport]: - return { - # SCIM-based API - "entitlements": self._scim_support, - "roles": self._scim_support, - # generic API - "clusters": self._generic_support, - "cluster-policies": self._generic_support, - "instance-pools": self._generic_support, - "sql/warehouses": self._generic_support, - "jobs": self._generic_support, - "pipelines": self._generic_support, - "experiments": self._generic_support, - "registered-models": self._generic_support, - "tokens": self._generic_support, - "passwords": self._generic_support, - # workspace objects - "notebooks": self._generic_support, - "files": self._generic_support, - "directories": self._generic_support, - "repos": self._generic_support, - # SQL API - "alerts": self._sql_support, - "queries": self._sql_support, - "dashboards": self._sql_support, - # secrets API - "secrets": self._secrets_support, - } diff --git a/src/databricks/labs/ucx/support/permissions.py b/src/databricks/labs/ucx/support/permissions.py deleted file mode 100644 index f645b5250e..0000000000 --- a/src/databricks/labs/ucx/support/permissions.py +++ /dev/null @@ -1,159 +0,0 @@ -import json -from collections.abc import Callable, Iterator -from dataclasses import dataclass -from functools import partial - -from databricks.sdk import WorkspaceClient -from databricks.sdk.core import DatabricksError -from databricks.sdk.service import iam - -from databricks.labs.ucx.inventory.types import ( - Destination, - PermissionsInventoryItem, - RequestObjectType, -) -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.providers.mixins.hardening import rate_limited -from databricks.labs.ucx.support.base import BaseSupport, logger - - -@dataclass -class GenericPermissionsInfo: - object_id: str - request_type: RequestObjectType - - -class GenericPermissionsSupport(BaseSupport): - def __init__( - self, - listings: list[Callable[..., Iterator[GenericPermissionsInfo]]], - ws: WorkspaceClient, - ): - super().__init__(ws) - self._listings: list[Callable[..., Iterator[GenericPermissionsInfo]]] = listings - - def _safe_get_permissions( - self, ws: WorkspaceClient, request_object_type: RequestObjectType, object_id: str - ) -> iam.ObjectPermissions | None: - try: - permissions = ws.permissions.get(request_object_type, object_id) - return permissions - except DatabricksError as e: - if e.error_code in [ - "RESOURCE_DOES_NOT_EXIST", - "RESOURCE_NOT_FOUND", - "PERMISSION_DENIED", - "FEATURE_DISABLED", - ]: - logger.warning(f"Could not get permissions for {request_object_type} {object_id} due to {e.error_code}") - return None - else: - raise e - - def _prepare_new_acl( - self, permissions: iam.ObjectPermissions, migration_state: GroupMigrationState, destination: Destination - ) -> list[iam.AccessControlRequest]: - _acl = permissions.access_control_list - acl_requests = [] - - for _item in _acl: - # TODO: we have a double iteration over migration_state.groups - # (also by migration_state.get_by_workspace_group_name). - # Has to be be fixed by iterating just on .groups - if _item.group_name in [g.workspace.display_name for g in migration_state.groups]: - migration_info = migration_state.get_by_workspace_group_name(_item.group_name) - assert migration_info is not None, f"Group {_item.group_name} is not in the migration groups provider" - destination_group: iam.Group = getattr(migration_info, destination) - _item.group_name = destination_group.display_name - _reqs = [ - iam.AccessControlRequest( - group_name=_item.group_name, - service_principal_name=_item.service_principal_name, - user_name=_item.user_name, - permission_level=p.permission_level, - ) - for p in _item.all_permissions - if not p.inherited - ] - acl_requests.extend(_reqs) - - return acl_requests - - @rate_limited(max_requests=30) - def _applier_task( - self, ws: WorkspaceClient, object_id: str, acl: list[iam.AccessControlRequest], request_type: RequestObjectType - ): - ws.permissions.update(request_object_type=request_type, request_object_id=object_id, access_control_list=acl) - - @rate_limited(max_requests=100) - def _crawler_task( - self, - ws: WorkspaceClient, - object_id: str, - request_type: RequestObjectType, - ) -> PermissionsInventoryItem | None: - permissions = self._safe_get_permissions(ws, request_type, object_id) - - support = object_id if request_type == RequestObjectType.AUTHORIZATION else request_type.value - - if permissions: - return PermissionsInventoryItem( - object_id=object_id, - support=support, - raw_object_permissions=json.dumps(permissions.as_dict()), - ) - - def _get_apply_task( - self, item: PermissionsInventoryItem, migration_state: GroupMigrationState, destination: Destination - ) -> partial: - new_acl = self._prepare_new_acl( - iam.ObjectPermissions.from_dict(json.loads(item.raw_object_permissions)), migration_state, destination - ) - - request_type = ( - RequestObjectType.AUTHORIZATION - if item.support in ("passwords", "tokens") - else RequestObjectType(item.support) - ) - - return partial( - self._applier_task, - ws=self._ws, - request_type=request_type, - acl=new_acl, - object_id=item.object_id, - ) - - def is_item_relevant(self, item: PermissionsInventoryItem, migration_state: GroupMigrationState) -> bool: - # passwords and tokens are represented on the workspace-level - if item.object_id in ("tokens", "passwords"): - return True - else: - mentioned_groups = [ - acl.group_name - for acl in iam.ObjectPermissions.from_dict(json.loads(item.raw_object_permissions)).access_control_list - ] - return any(g in mentioned_groups for g in [info.workspace.display_name for info in migration_state.groups]) - - def get_crawler_tasks(self): - for listing in self._listings: - for info in listing(): - yield partial( - self._crawler_task, - ws=self._ws, - object_id=info.object_id, - request_type=info.request_type, - ) - - -def listing_wrapper( - func: Callable[..., list], id_attribute: str, object_type: RequestObjectType -) -> Callable[..., Iterator[GenericPermissionsInfo]]: - def wrapper() -> Iterator[GenericPermissionsInfo]: - for item in func(): - yield GenericPermissionsInfo( - object_id=getattr(item, id_attribute), - request_type=object_type, - ) - - return wrapper diff --git a/src/databricks/labs/ucx/toolkits/group_migration.py b/src/databricks/labs/ucx/toolkits/group_migration.py deleted file mode 100644 index 2796d213a5..0000000000 --- a/src/databricks/labs/ucx/toolkits/group_migration.py +++ /dev/null @@ -1,91 +0,0 @@ -import logging - -from databricks.sdk import WorkspaceClient - -from databricks.labs.ucx.config import MigrationConfig -from databricks.labs.ucx.inventory.permissions import PermissionManager -from databricks.labs.ucx.inventory.permissions_inventory import ( - PermissionsInventoryTable, -) -from databricks.labs.ucx.inventory.verification import VerificationManager -from databricks.labs.ucx.managers.group import GroupManager -from databricks.labs.ucx.support.impl import SupportsProvider -from databricks.labs.ucx.tacl._internal import ( - RuntimeBackend, - SqlBackend, - StatementExecutionBackend, -) - - -class GroupMigrationToolkit: - def __init__(self, config: MigrationConfig, *, warehouse_id=None): - self._num_threads = config.num_threads - self._workspace_start_path = config.workspace_start_path - - databricks_config = config.to_databricks_config() - self._configure_logger(config.log_level) - - # integrate with connection pool settings properly - # https://github.com/databricks/databricks-sdk-py/pull/276 - self._ws = WorkspaceClient(config=databricks_config) - self._ws.api_client._session.adapters["https://"].max_retries.total = 20 - self._verify_ws_client(self._ws) - - self._group_manager = GroupManager(self._ws, config.groups) - sql_backend = self._backend(self._ws, warehouse_id) - self._permissions_inventory = PermissionsInventoryTable(sql_backend, config.inventory_database) - self._supports_provider = SupportsProvider(self._ws, self._num_threads, self._workspace_start_path) - self._permissions_manager = PermissionManager( - self._ws, self._permissions_inventory, supports_provider=self._supports_provider - ) - self._verification_manager = VerificationManager(self._ws, self._supports_provider.supports["secrets"]) - - @staticmethod - def _backend(ws: WorkspaceClient, warehouse_id: str | None = None) -> SqlBackend: - if warehouse_id is None: - return RuntimeBackend() - return StatementExecutionBackend(ws, warehouse_id) - - @staticmethod - def _verify_ws_client(w: WorkspaceClient): - _me = w.current_user.me() - is_workspace_admin = any(g.display == "admins" for g in _me.groups) - if not is_workspace_admin: - msg = "Current user is not a workspace admin" - raise RuntimeError(msg) - - @staticmethod - def _configure_logger(level: str): - ucx_logger = logging.getLogger("databricks.labs.ucx") - ucx_logger.setLevel(level) - - def prepare_environment(self): - self._group_manager.prepare_groups_in_environment() - - def cleanup_inventory_table(self): - self._permissions_inventory.cleanup() - - def inventorize_permissions(self): - self._permissions_manager.inventorize_permissions() - - def apply_permissions_to_backup_groups(self): - self._permissions_manager.apply_group_permissions( - self._group_manager.migration_groups_provider, destination="backup" - ) - - def verify_permissions_on_backup_groups(self, to_verify): - self._verification_manager.verify(self._group_manager.migration_groups_provider, "backup", to_verify) - - def replace_workspace_groups_with_account_groups(self): - self._group_manager.replace_workspace_groups_with_account_groups() - - def apply_permissions_to_account_groups(self): - self._permissions_manager.apply_group_permissions( - self._group_manager.migration_groups_provider, destination="account" - ) - - def verify_permissions_on_account_groups(self, to_verify): - self._verification_manager.verify(self._group_manager.migration_groups_provider, "account", to_verify) - - def delete_backup_groups(self): - self._group_manager.delete_backup_groups() diff --git a/src/databricks/labs/ucx/workspace_access/__init__.py b/src/databricks/labs/ucx/workspace_access/__init__.py new file mode 100644 index 0000000000..5b554bbd89 --- /dev/null +++ b/src/databricks/labs/ucx/workspace_access/__init__.py @@ -0,0 +1,3 @@ +from databricks.labs.ucx.workspace_access.migration import GroupMigrationToolkit + +__all__ = ["GroupMigrationToolkit"] diff --git a/src/databricks/labs/ucx/support/base.py b/src/databricks/labs/ucx/workspace_access/base.py similarity index 52% rename from src/databricks/labs/ucx/support/base.py rename to src/databricks/labs/ucx/workspace_access/base.py index ab9fd75561..59fc64922f 100644 --- a/src/databricks/labs/ucx/support/base.py +++ b/src/databricks/labs/ucx/workspace_access/base.py @@ -1,20 +1,28 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Callable, Iterator +from dataclasses import dataclass from functools import partial from logging import Logger +from typing import Literal -from databricks.sdk import WorkspaceClient - -from databricks.labs.ucx.inventory.types import Destination, PermissionsInventoryItem -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.utils import noop +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState logger = Logger(__name__) +@dataclass +class Permissions: + object_id: str + object_type: str + raw: str + + +Destination = Literal["backup", "account"] + + class Crawler: @abstractmethod - def get_crawler_tasks(self) -> Iterator[Callable[..., PermissionsInventoryItem | None]]: + def get_crawler_tasks(self) -> Iterator[Callable[..., Permissions | None]]: """ This method should return a list of crawler tasks (e.g. partials or just any callables) :return: @@ -23,35 +31,28 @@ def get_crawler_tasks(self) -> Iterator[Callable[..., PermissionsInventoryItem | class Applier: @abstractmethod - def is_item_relevant(self, item: PermissionsInventoryItem, migration_state: GroupMigrationState) -> bool: + def is_item_relevant(self, item: Permissions, migration_state: GroupMigrationState) -> bool: """ This method verifies that the given item is relevant for the given migration state. """ @abstractmethod def _get_apply_task( - self, item: PermissionsInventoryItem, migration_state: GroupMigrationState, destination: Destination + self, item: Permissions, migration_state: GroupMigrationState, destination: Destination ) -> partial: """ This method should return an instance of ApplierTask. """ def get_apply_task( - self, item: PermissionsInventoryItem, migration_state: GroupMigrationState, destination: Destination + self, item: Permissions, migration_state: GroupMigrationState, destination: Destination ) -> partial: # we explicitly put the relevance check here to avoid "forgotten implementation" in child classes if self.is_item_relevant(item, migration_state): return self._get_apply_task(item, migration_state, destination) else: - return partial(noop) + def noop(): + pass -class BaseSupport(ABC, Crawler, Applier): - """ - Base class for all support classes. - Child classes must implement all abstract methods. - """ - - def __init__(self, ws: WorkspaceClient): - # workspace client is required in all implementations - self._ws = ws + return partial(noop) diff --git a/src/databricks/labs/ucx/workspace_access/generic.py b/src/databricks/labs/ucx/workspace_access/generic.py new file mode 100644 index 0000000000..deb4ede41c --- /dev/null +++ b/src/databricks/labs/ucx/workspace_access/generic.py @@ -0,0 +1,196 @@ +import json +import logging +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from functools import partial + +from databricks.sdk import WorkspaceClient +from databricks.sdk.core import DatabricksError +from databricks.sdk.service import iam, ml, workspace + +from databricks.labs.ucx.mixins.hardening import rate_limited +from databricks.labs.ucx.workspace_access.base import ( + Applier, + Crawler, + Destination, + Permissions, +) +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState + +logger = logging.getLogger(__name__) + + +@dataclass +class GenericPermissionsInfo: + object_id: str + request_type: str + + +class GenericPermissionsSupport(Crawler, Applier): + def __init__(self, ws: WorkspaceClient, listings: list[Callable[..., Iterator[GenericPermissionsInfo]]]): + self._ws = ws + self._listings = listings + + def get_crawler_tasks(self): + for listing in self._listings: + for info in listing(): + yield partial(self._crawler_task, info.request_type, info.object_id) + + def is_item_relevant(self, item: Permissions, migration_state: GroupMigrationState) -> bool: + # passwords and tokens are represented on the workspace-level + if item.object_id in ("tokens", "passwords"): + return True + mentioned_groups = [ + acl.group_name for acl in iam.ObjectPermissions.from_dict(json.loads(item.raw)).access_control_list + ] + return any(g in mentioned_groups for g in [info.workspace.display_name for info in migration_state.groups]) + + def _get_apply_task( + self, item: Permissions, migration_state: GroupMigrationState, destination: Destination + ) -> partial: + new_acl = self._prepare_new_acl( + iam.ObjectPermissions.from_dict(json.loads(item.raw)), migration_state, destination + ) + return partial(self._applier_task, item.object_type, item.object_id, new_acl) + + @rate_limited(max_requests=30) + def _applier_task(self, object_type: str, object_id: str, acl: list[iam.AccessControlRequest]): + self._ws.permissions.update(object_type, object_id, access_control_list=acl) + + @rate_limited(max_requests=100) + def _crawler_task(self, object_type: str, object_id: str) -> Permissions | None: + permissions = self._safe_get_permissions(object_type, object_id) + if not permissions: + return None + return Permissions( + object_id=object_id, + object_type=object_type, + raw=json.dumps(permissions.as_dict()), + ) + + def _safe_get_permissions(self, object_type: str, object_id: str) -> iam.ObjectPermissions | None: + try: + return self._ws.permissions.get(object_type, object_id) + except DatabricksError as e: + if e.error_code in [ + "RESOURCE_DOES_NOT_EXIST", + "RESOURCE_NOT_FOUND", + "PERMISSION_DENIED", + "FEATURE_DISABLED", + ]: + logger.warning(f"Could not get permissions for {object_type} {object_id} due to {e.error_code}") + return None + else: + raise e + + def _prepare_new_acl( + self, permissions: iam.ObjectPermissions, migration_state: GroupMigrationState, destination: Destination + ) -> list[iam.AccessControlRequest]: + _acl = permissions.access_control_list + acl_requests = [] + + for _item in _acl: + # TODO: we have a double iteration over migration_state.groups + # (also by migration_state.get_by_workspace_group_name). + # Has to be be fixed by iterating just on .groups + if _item.group_name in [g.workspace.display_name for g in migration_state.groups]: + migration_info = migration_state.get_by_workspace_group_name(_item.group_name) + assert migration_info is not None, f"Group {_item.group_name} is not in the migration groups provider" + destination_group: iam.Group = getattr(migration_info, destination) + _item.group_name = destination_group.display_name + _reqs = [ + iam.AccessControlRequest( + group_name=_item.group_name, + service_principal_name=_item.service_principal_name, + user_name=_item.user_name, + permission_level=p.permission_level, + ) + for p in _item.all_permissions + if not p.inherited + ] + acl_requests.extend(_reqs) + + return acl_requests + + +def listing_wrapper( + func: Callable[..., list], id_attribute: str, object_type: str +) -> Callable[..., Iterator[GenericPermissionsInfo]]: + def wrapper() -> Iterator[GenericPermissionsInfo]: + for item in func(): + yield GenericPermissionsInfo( + object_id=getattr(item, id_attribute), + request_type=object_type, + ) + + return wrapper + + +def workspace_listing(ws: WorkspaceClient, num_threads=20, start_path: str | None = "/"): + def _convert_object_type_to_request_type(_object: workspace.ObjectInfo) -> str | None: + match _object.object_type: + case workspace.ObjectType.NOTEBOOK: + return "notebooks" + case workspace.ObjectType.DIRECTORY: + return "directories" + case workspace.ObjectType.LIBRARY: + return None + case workspace.ObjectType.REPO: + return "repos" + case workspace.ObjectType.FILE: + return "files" + # silent handler for experiments - they'll be inventorized by the experiments manager + case None: + return None + + def inner(): + from databricks.labs.ucx.workspace_access.listing import WorkspaceListing + + ws_listing = WorkspaceListing(ws, num_threads=num_threads, with_directories=False) + for _object in ws_listing.walk(start_path): + request_type = _convert_object_type_to_request_type(_object) + if request_type: + yield GenericPermissionsInfo(object_id=str(_object.object_id), request_type=request_type) + + return inner + + +def models_listing(ws: WorkspaceClient): + def inner() -> Iterator[ml.ModelDatabricks]: + for model in ws.model_registry.list_models(): + model_with_id = ws.model_registry.get_model(model.name).registered_model_databricks + yield model_with_id + + return inner + + +def experiments_listing(ws: WorkspaceClient): + def inner() -> Iterator[ml.Experiment]: + for experiment in ws.experiments.list_experiments(): + """ + We filter-out notebook-based experiments, because they are covered by notebooks listing + """ + # workspace-based notebook experiment + if experiment.tags: + nb_tag = [t for t in experiment.tags if t.key == "mlflow.experimentType" and t.value == "NOTEBOOK"] + # repo-based notebook experiment + repo_nb_tag = [ + t for t in experiment.tags if t.key == "mlflow.experiment.sourceType" and t.value == "REPO_NOTEBOOK" + ] + if nb_tag or repo_nb_tag: + continue + + yield experiment + + return inner + + +def authorization_listing(): + def inner(): + for _value in ["passwords", "tokens"]: + yield GenericPermissionsInfo( + object_id=_value, + request_type="authorization", + ) + + return inner diff --git a/src/databricks/labs/ucx/managers/group.py b/src/databricks/labs/ucx/workspace_access/groups.py similarity index 65% rename from src/databricks/labs/ucx/managers/group.py rename to src/databricks/labs/ucx/workspace_access/groups.py index cda484c2e3..de4143a942 100644 --- a/src/databricks/labs/ucx/managers/group.py +++ b/src/databricks/labs/ucx/workspace_access/groups.py @@ -1,30 +1,55 @@ import json import logging import typing +from dataclasses import dataclass from functools import partial from databricks.sdk import WorkspaceClient from databricks.sdk.service import iam +from databricks.sdk.service.iam import Group from databricks.labs.ucx.config import GroupsConfig -from databricks.labs.ucx.generic import StrEnum -from databricks.labs.ucx.providers.groups_info import ( - GroupMigrationState, - MigrationGroupInfo, -) -from databricks.labs.ucx.providers.mixins.hardening import rate_limited -from databricks.labs.ucx.utils import ThreadedExecution +from databricks.labs.ucx.framework.parallel import ThreadedExecution +from databricks.labs.ucx.mixins.hardening import rate_limited logger = logging.getLogger(__name__) +GroupLevel = typing.Literal["workspace", "account"] -class GroupLevel(StrEnum): - WORKSPACE = "workspace" - ACCOUNT = "account" + +@dataclass +class MigrationGroupInfo: + workspace: Group + backup: Group + account: Group + + +class GroupMigrationState: + """Holds migration state of workspace-to-account groups""" + + def __init__(self): + self.groups: list[MigrationGroupInfo] = [] + + def add(self, group: MigrationGroupInfo): + self.groups.append(group) + + def is_in_scope(self, attr: str, group: Group) -> bool: + for info in self.groups: + if getattr(info, attr).id == group.id: + return True + return False + + def get_by_workspace_group_name(self, workspace_group_name: str) -> MigrationGroupInfo | None: + found = [g for g in self.groups if g.workspace.display_name == workspace_group_name] + if len(found) == 0: + return None + else: + return found[0] class GroupManager: SYSTEM_GROUPS: typing.ClassVar[list[str]] = ["users", "admins", "account users"] + SCIM_ATTRIBUTES = "id,displayName,meta,members" def __init__(self, ws: WorkspaceClient, groups: GroupsConfig): self._ws = ws @@ -37,11 +62,11 @@ def _list_workspace_groups(self) -> list[iam.Group]: logger.debug("Listing workspace groups...") workspace_groups = [ g - for g in self._ws.groups.list(attributes="id,displayName,meta") + for g in self._ws.groups.list(attributes=self.SCIM_ATTRIBUTES) if g.meta.resource_type == "WorkspaceGroup" and g.display_name not in self.SYSTEM_GROUPS ] logger.debug(f"Found {len(workspace_groups)} workspace groups") - return workspace_groups + return sorted(workspace_groups, key=lambda _: _.display_name) def _list_account_groups(self) -> list[iam.Group]: # TODO: we should avoid using this method, as it's not documented @@ -52,53 +77,52 @@ def _list_account_groups(self) -> list[iam.Group]: for r in self._ws.api_client.do( "get", "/api/2.0/account/scim/v2/Groups", - query={ - "attributes": "id,displayName,meta", - }, + query={"attributes": self.SCIM_ATTRIBUTES}, ).get("Resources", []) ] account_groups = [g for g in account_groups if g.display_name not in self.SYSTEM_GROUPS] logger.debug(f"Found {len(account_groups)} account groups") - return account_groups + return sorted(account_groups, key=lambda _: _.display_name) def _get_group(self, group_name, level: GroupLevel) -> iam.Group | None: - relevant_level_groups = self._workspace_groups if level == GroupLevel.WORKSPACE else self._account_groups + relevant_level_groups = self._workspace_groups if level == "workspace" else self._account_groups for group in relevant_level_groups: if group.display_name == group_name: return group def _get_or_create_backup_group(self, source_group_name: str, source_group: iam.Group) -> iam.Group: backup_group_name = f"{self.config.backup_group_prefix}{source_group_name}" - backup_group = self._get_group(backup_group_name, GroupLevel.WORKSPACE) + backup_group = self._get_group(backup_group_name, "workspace") if backup_group: logger.info(f"Backup group {backup_group_name} already exists, no action required") - else: - logger.info(f"Creating backup group {backup_group_name}") - backup_group = self._ws.groups.create( - display_name=backup_group_name, - meta=source_group.meta, - entitlements=source_group.entitlements, - roles=source_group.roles, - members=source_group.members, - ) - self._workspace_groups.append(backup_group) - logger.info(f"Backup group {backup_group_name} successfully created") + return backup_group + + logger.info(f"Creating backup group {backup_group_name}") + backup_group = self._ws.groups.create( + display_name=backup_group_name, + meta=source_group.meta, + entitlements=source_group.entitlements, + roles=source_group.roles, + members=source_group.members, + ) + self._workspace_groups.append(backup_group) + logger.info(f"Backup group {backup_group_name} successfully created") return backup_group def _set_migration_groups(self, groups_names: list[str]): def get_group_info(name: str): - ws_group = self._get_group(name, GroupLevel.WORKSPACE) + ws_group = self._get_group(name, "workspace") assert ws_group, f"Group {name} not found on the workspace level" - acc_group = self._get_group(name, GroupLevel.ACCOUNT) + acc_group = self._get_group(name, "account") assert acc_group, f"Group {name} not found on the account level" backup_group = self._get_or_create_backup_group(source_group_name=name, source_group=ws_group) return MigrationGroupInfo(workspace=ws_group, backup=backup_group, account=acc_group) - executables = [partial(get_group_info, group_name) for group_name in groups_names] - - collected_groups = ThreadedExecution[MigrationGroupInfo](executables).run() + collected_groups = ThreadedExecution.gather( + "get group info", [partial(get_group_info, group_name) for group_name in groups_names] + ) for g in collected_groups: self._migration_state.add(g) @@ -132,16 +156,17 @@ def _reflect_account_group_to_workspace(self, acc_group: iam.Group) -> None: # please keep the public methods below this line def prepare_groups_in_environment(self): - logger.info("Preparing groups in the current environment") - logger.info("At this step we'll verify that all groups exist and are of the correct type") - logger.info("If some temporary groups are missing, they'll be created") + logger.info( + "Preparing groups in the current environment. At this step we'll verify that all groups " + "exist and are of the correct type. If some temporary groups are missing, they'll be created" + ) if self.config.selected: logger.info("Using the provided group listing") for g in self.config.selected: assert g not in self.SYSTEM_GROUPS, f"Cannot migrate system group {g}" - assert self._get_group(g, GroupLevel.WORKSPACE), f"Group {g} not found on the workspace level" - assert self._get_group(g, GroupLevel.ACCOUNT), f"Group {g} not found on the account level" + assert self._get_group(g, "workspace"), f"Group {g} not found on the workspace level" + assert self._get_group(g, "account"), f"Group {g} not found on the account level" self._set_migration_groups(self.config.selected) else: @@ -157,17 +182,17 @@ def migration_groups_provider(self) -> GroupMigrationState: def replace_workspace_groups_with_account_groups(self): logger.info("Replacing the workspace groups with account-level groups") - logger.info(f"In total, {len(self.migration_groups_provider.groups)} group(s) to be replaced") - - executables = [ - partial(self._replace_group, migration_info) for migration_info in self.migration_groups_provider.groups - ] - ThreadedExecution(executables).run() + ThreadedExecution.gather( + "groups: workspace -> account", + [partial(self._replace_group, migration_info) for migration_info in self.migration_groups_provider.groups], + ) logger.info("Workspace groups were successfully replaced with account-level groups") def delete_backup_groups(self): - logger.info("Deleting the workspace-level backup groups") - logger.info(f"In total, {len(self.migration_groups_provider.groups)} group(s) to be deleted") + logger.info( + f"Deleting the workspace-level backup groups. " + f"In total, {len(self.migration_groups_provider.groups)} group(s) to be deleted" + ) for migration_info in self.migration_groups_provider.groups: try: diff --git a/src/databricks/labs/ucx/support/listing.py b/src/databricks/labs/ucx/workspace_access/listing.py similarity index 55% rename from src/databricks/labs/ucx/support/listing.py rename to src/databricks/labs/ucx/workspace_access/listing.py index eba98399d0..659edf54e3 100644 --- a/src/databricks/labs/ucx/support/listing.py +++ b/src/databricks/labs/ucx/workspace_access/listing.py @@ -5,12 +5,9 @@ from itertools import groupby from databricks.sdk import WorkspaceClient -from databricks.sdk.service import ml, workspace from databricks.sdk.service.workspace import ObjectInfo, ObjectType -from databricks.labs.ucx.inventory.types import RequestObjectType -from databricks.labs.ucx.providers.mixins.hardening import rate_limited -from databricks.labs.ucx.support.permissions import GenericPermissionsInfo +from databricks.labs.ucx.mixins.hardening import rate_limited logger = logging.getLogger(__name__) @@ -89,80 +86,9 @@ def walk(self, start_path="/"): new_futures[new_future] = directory futures_to_objects.update(new_futures) - logger.info(f"Recursive WorkspaceFS listing finished at {dt.datetime.now()}") - logger.info(f"Total time taken for workspace listing: {dt.datetime.now() - self.start_time}") + logger.info( + f"Recursive WorkspaceFS listing finished at {dt.datetime.now()}. " + f"Total time taken for workspace listing: {dt.datetime.now() - self.start_time}" + ) self._progress_report(None) return self.results - - -def models_listing(ws: WorkspaceClient): - def inner() -> Iterator[ml.ModelDatabricks]: - for model in ws.model_registry.list_models(): - model_with_id = ws.model_registry.get_model(model.name).registered_model_databricks - yield model_with_id - - return inner - - -def experiments_listing(ws: WorkspaceClient): - def inner() -> Iterator[ml.Experiment]: - for experiment in ws.experiments.list_experiments(): - """ - We filter-out notebook-based experiments, because they are covered by notebooks listing - """ - # workspace-based notebook experiment - if experiment.tags: - nb_tag = [t for t in experiment.tags if t.key == "mlflow.experimentType" and t.value == "NOTEBOOK"] - # repo-based notebook experiment - repo_nb_tag = [ - t for t in experiment.tags if t.key == "mlflow.experiment.sourceType" and t.value == "REPO_NOTEBOOK" - ] - if nb_tag or repo_nb_tag: - continue - - yield experiment - - return inner - - -def authorization_listing(): - def inner(): - for _value in ["passwords", "tokens"]: - yield GenericPermissionsInfo( - object_id=_value, - request_type=RequestObjectType.AUTHORIZATION, - ) - - return inner - - -def _convert_object_type_to_request_type(_object: workspace.ObjectInfo) -> RequestObjectType | None: - match _object.object_type: - case workspace.ObjectType.NOTEBOOK: - return RequestObjectType.NOTEBOOKS - case workspace.ObjectType.DIRECTORY: - return RequestObjectType.DIRECTORIES - case workspace.ObjectType.LIBRARY: - return None - case workspace.ObjectType.REPO: - return RequestObjectType.REPOS - case workspace.ObjectType.FILE: - return RequestObjectType.FILES - # silent handler for experiments - they'll be inventorized by the experiments manager - case None: - return None - - -def workspace_listing(ws: WorkspaceClient, num_threads=20, start_path: str | None = "/"): - def inner(): - ws_listing = WorkspaceListing( - ws, - num_threads=num_threads, - with_directories=False, - ) - for _object in ws_listing.walk(start_path): - request_type = _convert_object_type_to_request_type(_object) - if request_type: - yield GenericPermissionsInfo(object_id=str(_object.object_id), request_type=request_type) - - return inner diff --git a/src/databricks/labs/ucx/workspace_access/manager.py b/src/databricks/labs/ucx/workspace_access/manager.py new file mode 100644 index 0000000000..ca29e8dd33 --- /dev/null +++ b/src/databricks/labs/ucx/workspace_access/manager.py @@ -0,0 +1,90 @@ +import logging +from collections.abc import Callable, Iterator +from itertools import groupby +from typing import Literal + +from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend +from databricks.labs.ucx.framework.parallel import ThreadedExecution +from databricks.labs.ucx.workspace_access.base import Applier, Crawler, Permissions +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState + +logger = logging.getLogger(__name__) + + +class PermissionManager(CrawlerBase): + def __init__( + self, backend: SqlBackend, inventory_database: str, crawlers: list[Crawler], appliers: dict[str, Applier] + ): + super().__init__(backend, "hive_metastore", inventory_database, "permissions") + self._crawlers = crawlers + self._appliers = appliers + + def inventorize_permissions(self): + logger.debug("Crawling permissions") + crawler_tasks = list(self._get_crawler_tasks()) + logger.info(f"Starting to crawl permissions. Total tasks: {len(crawler_tasks)}") + results = ThreadedExecution.gather("crawl permissions", crawler_tasks) + items = [] + for item in results: + if item is None: + continue + if item.object_type not in self._appliers: + msg = f"unknown object_type: {item.object_type}" + raise KeyError(msg) + items.append(item) + logger.info(f"Total crawled permissions after filtering: {len(items)}") + self._save(items) + logger.info(f"Saved {len(items)} to {self._full_name}") + + def apply_group_permissions(self, migration_state: GroupMigrationState, destination: Literal["backup", "account"]): + # list shall be sorted prior to using group by + items = sorted(self._load_all(), key=lambda i: i.object_type) + logger.info( + f"Applying the permissions to {destination} groups. " + f"Total groups to apply permissions: {len(migration_state.groups)}. " + f"Total permissions found: {len(items)}" + ) + applier_tasks = [] + supports_to_items = { + support: list(items_subset) for support, items_subset in groupby(items, key=lambda i: i.object_type) + } + + # we first check that all supports are valid. + for object_type in supports_to_items: + if object_type not in self._appliers: + msg = f"Could not find support for {object_type}. Please check the inventory table." + raise ValueError(msg) + + for object_type, items_subset in supports_to_items.items(): + relevant_support = self._appliers[object_type] + tasks_for_support = [ + relevant_support.get_apply_task(item, migration_state, destination) for item in items_subset + ] + logger.info(f"Total tasks for {object_type}: {len(tasks_for_support)}") + applier_tasks.extend(tasks_for_support) + + logger.info(f"Starting to apply permissions on {destination} groups. Total tasks: {len(applier_tasks)}") + ThreadedExecution.gather(f"apply {destination} group permissions", applier_tasks) + logger.info("Permissions were applied") + + def cleanup(self): + logger.info(f"Cleaning up inventory table {self._full_name}") + self._exec(f"DROP TABLE IF EXISTS {self._full_name}") + logger.info("Inventory table cleanup complete") + + def _save(self, items: list[Permissions]): + # TODO: update instead of append + logger.info(f"Saving {len(items)} items to {self._full_name}") + self._append_records(Permissions, items) + logger.info("Successfully saved the items to inventory table") + + def _load_all(self) -> list[Permissions]: + logger.info(f"Loading inventory table {self._full_name}") + return [ + Permissions(object_id, object_type, raw) + for object_id, object_type, raw in self._fetch(f"SELECT object_id, object_type, raw FROM {self._full_name}") + ] + + def _get_crawler_tasks(self) -> Iterator[Callable[..., Permissions | None]]: + for support in self._crawlers: + yield from support.get_crawler_tasks() diff --git a/src/databricks/labs/ucx/workspace_access/migration.py b/src/databricks/labs/ucx/workspace_access/migration.py new file mode 100644 index 0000000000..ff2a5c13ec --- /dev/null +++ b/src/databricks/labs/ucx/workspace_access/migration.py @@ -0,0 +1,146 @@ +import logging + +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import sql + +from databricks.labs.ucx.config import MigrationConfig +from databricks.labs.ucx.framework.crawlers import ( + RuntimeBackend, + SqlBackend, + StatementExecutionBackend, +) +from databricks.labs.ucx.workspace_access.generic import ( + GenericPermissionsSupport, + authorization_listing, + experiments_listing, + listing_wrapper, + models_listing, + workspace_listing, +) +from databricks.labs.ucx.workspace_access.groups import GroupManager +from databricks.labs.ucx.workspace_access.manager import PermissionManager +from databricks.labs.ucx.workspace_access.redash import ( + SqlPermissionsSupport, + redash_listing_wrapper, +) +from databricks.labs.ucx.workspace_access.scim import ScimSupport +from databricks.labs.ucx.workspace_access.secrets import SecretScopesSupport +from databricks.labs.ucx.workspace_access.verification import VerificationManager + + +class GroupMigrationToolkit: + def __init__(self, config: MigrationConfig, *, warehouse_id=None): + self._configure_logger(config.log_level) + + ws = WorkspaceClient(config=config.to_databricks_config()) + ws.api_client._session.adapters["https://"].max_retries.total = 20 + self._verify_ws_client(ws) + self._ws = ws # TODO: remove this once notebooks/toolkit.py is removed + + generic_acl_listing = [ + listing_wrapper(ws.clusters.list, "cluster_id", "clusters"), + listing_wrapper(ws.cluster_policies.list, "policy_id", "cluster-policies"), + listing_wrapper(ws.instance_pools.list, "instance_pool_id", "instance-pools"), + listing_wrapper(ws.warehouses.list, "id", "sql/warehouses"), + listing_wrapper(ws.jobs.list, "job_id", "jobs"), + listing_wrapper(ws.pipelines.list_pipelines, "pipeline_id", "pipelines"), + listing_wrapper(experiments_listing(ws), "experiment_id", "experiments"), + listing_wrapper(models_listing(ws), "id", "registered-models"), + workspace_listing(ws, num_threads=config.num_threads, start_path=config.workspace_start_path), + authorization_listing(), + ] + redash_acl_listing = [ + redash_listing_wrapper(ws.alerts.list, sql.ObjectTypePlural.ALERTS), + redash_listing_wrapper(ws.dashboards.list, sql.ObjectTypePlural.DASHBOARDS), + redash_listing_wrapper(ws.queries.list, sql.ObjectTypePlural.QUERIES), + ] + generic_support = GenericPermissionsSupport(ws, generic_acl_listing) + sql_support = SqlPermissionsSupport(ws, redash_acl_listing) + secrets_support = SecretScopesSupport(ws) + scim_support = ScimSupport(ws) + self._permissions_manager = PermissionManager( + self._backend(ws, warehouse_id), + config.inventory_database, + [generic_support, sql_support, secrets_support, scim_support], + self._object_type_appliers(generic_support, sql_support, secrets_support, scim_support), + ) + self._group_manager = GroupManager(ws, config.groups) + self._verification_manager = VerificationManager(ws, secrets_support) + + @staticmethod + def _object_type_appliers(generic_support, sql_support, secrets_support, scim_support): + return { + # SCIM-based API + "entitlements": scim_support, + "roles": scim_support, + # Generic Permissions API + "authorization": generic_support, + "clusters": generic_support, + "cluster-policies": generic_support, + "instance-pools": generic_support, + "sql/warehouses": generic_support, + "jobs": generic_support, + "pipelines": generic_support, + "experiments": generic_support, + "registered-models": generic_support, + "notebooks": generic_support, + "files": generic_support, + "directories": generic_support, + "repos": generic_support, + # Redash equivalent of Generic Permissions API + "alerts": sql_support, + "queries": sql_support, + "dashboards": sql_support, + # Secret Scope ACL API + "secrets": secrets_support, + } + + @staticmethod + def _backend(ws: WorkspaceClient, warehouse_id: str | None = None) -> SqlBackend: + if warehouse_id is None: + return RuntimeBackend() + return StatementExecutionBackend(ws, warehouse_id) + + @staticmethod + def _verify_ws_client(w: WorkspaceClient): + _me = w.current_user.me() + is_workspace_admin = any(g.display == "admins" for g in _me.groups) + if not is_workspace_admin: + msg = "Current user is not a workspace admin" + raise RuntimeError(msg) + + @staticmethod + def _configure_logger(level: str): + ucx_logger = logging.getLogger("databricks.labs.ucx") + ucx_logger.setLevel(level) + + def prepare_environment(self): + self._group_manager.prepare_groups_in_environment() + + def cleanup_inventory_table(self): + self._permissions_manager.cleanup() + + def inventorize_permissions(self): + self._permissions_manager.inventorize_permissions() + + def apply_permissions_to_backup_groups(self): + self._permissions_manager.apply_group_permissions( + self._group_manager.migration_groups_provider, destination="backup" + ) + + def verify_permissions_on_backup_groups(self, to_verify): + self._verification_manager.verify(self._group_manager.migration_groups_provider, "backup", to_verify) + + def replace_workspace_groups_with_account_groups(self): + self._group_manager.replace_workspace_groups_with_account_groups() + + def apply_permissions_to_account_groups(self): + self._permissions_manager.apply_group_permissions( + self._group_manager.migration_groups_provider, destination="account" + ) + + def verify_permissions_on_account_groups(self, to_verify): + self._verification_manager.verify(self._group_manager.migration_groups_provider, "account", to_verify) + + def delete_backup_groups(self): + self._group_manager.delete_backup_groups() diff --git a/src/databricks/labs/ucx/support/sql.py b/src/databricks/labs/ucx/workspace_access/redash.py similarity index 72% rename from src/databricks/labs/ucx/support/sql.py rename to src/databricks/labs/ucx/workspace_access/redash.py index 8a2a04c520..51bd1c2c19 100644 --- a/src/databricks/labs/ucx/support/sql.py +++ b/src/databricks/labs/ucx/workspace_access/redash.py @@ -8,10 +8,15 @@ from databricks.sdk.core import DatabricksError from databricks.sdk.service import iam, sql -from databricks.labs.ucx.inventory.types import Destination, PermissionsInventoryItem -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.providers.mixins.hardening import rate_limited -from databricks.labs.ucx.support.base import BaseSupport, logger +from databricks.labs.ucx.mixins.hardening import rate_limited +from databricks.labs.ucx.workspace_access.base import ( + Applier, + Crawler, + Destination, + Permissions, + logger, +) +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState @dataclass @@ -20,21 +25,37 @@ class SqlPermissionsInfo: request_type: sql.ObjectTypePlural -class SqlPermissionsSupport(BaseSupport): - def is_item_relevant(self, item: PermissionsInventoryItem, migration_state: GroupMigrationState) -> bool: +# This module is called redash to disambiguate from databricks.sdk.service.sql + + +class SqlPermissionsSupport(Crawler, Applier): + def __init__(self, ws: WorkspaceClient, listings: list[Callable[..., list[SqlPermissionsInfo]]]): + self._ws = ws + self._listings = listings + + def is_item_relevant(self, item: Permissions, migration_state: GroupMigrationState) -> bool: mentioned_groups = [ - acl.group_name - for acl in sql.GetResponse.from_dict(json.loads(item.raw_object_permissions)).access_control_list + acl.group_name for acl in sql.GetResponse.from_dict(json.loads(item.raw)).access_control_list ] return any(g in mentioned_groups for g in [info.workspace.display_name for info in migration_state.groups]) - def __init__( - self, - ws: WorkspaceClient, - listings: list[Callable[..., list[SqlPermissionsInfo]]], - ): - super().__init__(ws) - self._listings = listings + def get_crawler_tasks(self): + for listing in self._listings: + for item in listing(): + yield partial(self._crawler_task, item.object_id, item.request_type) + + def _get_apply_task(self, item: Permissions, migration_state: GroupMigrationState, destination: Destination): + new_acl = self._prepare_new_acl( + sql.GetResponse.from_dict(json.loads(item.raw)).access_control_list, + migration_state, + destination, + ) + return partial( + self._applier_task, + object_type=sql.ObjectTypePlural(item.object_type), + object_id=item.object_id, + acl=new_acl, + ) def _safe_get_dbsql_permissions(self, object_type: sql.ObjectTypePlural, object_id: str) -> sql.GetResponse | None: try: @@ -47,13 +68,13 @@ def _safe_get_dbsql_permissions(self, object_type: sql.ObjectTypePlural, object_ raise e @rate_limited(max_requests=100) - def _crawler_task(self, object_id: str, object_type: sql.ObjectTypePlural) -> PermissionsInventoryItem | None: + def _crawler_task(self, object_id: str, object_type: sql.ObjectTypePlural) -> Permissions | None: permissions = self._safe_get_dbsql_permissions(object_type=object_type, object_id=object_id) if permissions: - return PermissionsInventoryItem( + return Permissions( object_id=object_id, - support=object_type.value, - raw_object_permissions=json.dumps(permissions.as_dict()), + object_type=object_type.value, + raw=json.dumps(permissions.as_dict()), ) @rate_limited(max_requests=30) @@ -64,11 +85,6 @@ def _applier_task(self, object_type: sql.ObjectTypePlural, object_id: str, acl: """ self._ws.dbsql_permissions.set(object_type=object_type, object_id=object_id, access_control_list=acl) - def get_crawler_tasks(self): - for listing in self._listings: - for item in listing(): - yield partial(self._crawler_task, item.object_id, item.request_type) - def _prepare_new_acl( self, acl: list[sql.AccessControl], migration_state: GroupMigrationState, destination: Destination ) -> list[sql.AccessControl]: @@ -92,20 +108,8 @@ def _prepare_new_acl( return acl_requests - def _get_apply_task( - self, item: PermissionsInventoryItem, migration_state: GroupMigrationState, destination: Destination - ): - new_acl = self._prepare_new_acl( - sql.GetResponse.from_dict(json.loads(item.raw_object_permissions)).access_control_list, - migration_state, - destination, - ) - return partial( - self._applier_task, object_type=sql.ObjectTypePlural(item.support), object_id=item.object_id, acl=new_acl - ) - -def listing_wrapper( +def redash_listing_wrapper( func: Callable[..., list], object_type: sql.ObjectTypePlural ) -> Callable[..., list[SqlPermissionsInfo]]: def wrapper() -> list[SqlPermissionsInfo]: diff --git a/src/databricks/labs/ucx/support/group_level.py b/src/databricks/labs/ucx/workspace_access/scim.py similarity index 65% rename from src/databricks/labs/ucx/support/group_level.py rename to src/databricks/labs/ucx/workspace_access/scim.py index 00a538100e..2830704765 100644 --- a/src/databricks/labs/ucx/support/group_level.py +++ b/src/databricks/labs/ucx/workspace_access/scim.py @@ -1,29 +1,24 @@ import json from functools import partial +from databricks.sdk import WorkspaceClient from databricks.sdk.service import iam -from databricks.labs.ucx.inventory.types import Destination, PermissionsInventoryItem -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.providers.mixins.hardening import rate_limited -from databricks.labs.ucx.support.base import BaseSupport +from databricks.labs.ucx.mixins.hardening import rate_limited +from databricks.labs.ucx.workspace_access.base import ( + Applier, + Crawler, + Destination, + Permissions, +) +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState -class ScimSupport(BaseSupport): - def _crawler_task(self, group: iam.Group, property_name: str): - return PermissionsInventoryItem( - object_id=group.id, - support=property_name, - raw_object_permissions=json.dumps([e.as_dict() for e in getattr(group, property_name)]), - ) - - @rate_limited(max_requests=10) - def _applier_task(self, group_id: str, value: list[iam.ComplexValue], property_name: str): - operations = [iam.Patch(op=iam.PatchOp.ADD, path=property_name, value=[e.as_dict() for e in value])] - schemas = [iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP] - self._ws.groups.patch(id=group_id, operations=operations, schemas=schemas) +class ScimSupport(Crawler, Applier): + def __init__(self, ws: WorkspaceClient): + self._ws = ws - def is_item_relevant(self, item: PermissionsInventoryItem, migration_state: GroupMigrationState) -> bool: + def is_item_relevant(self, item: Permissions, migration_state: GroupMigrationState) -> bool: return any(g.workspace.id == item.object_id for g in migration_state.groups) def get_crawler_tasks(self): @@ -35,14 +30,25 @@ def get_crawler_tasks(self): for g in with_entitlements: yield partial(self._crawler_task, g, "entitlements") - def _get_apply_task( - self, item: PermissionsInventoryItem, migration_state: GroupMigrationState, destination: Destination - ): - value = [iam.ComplexValue.from_dict(e) for e in json.loads(item.raw_object_permissions)] + def _get_apply_task(self, item: Permissions, migration_state: GroupMigrationState, destination: Destination): + value = [iam.ComplexValue.from_dict(e) for e in json.loads(item.raw)] target_info = [g for g in migration_state.groups if g.workspace.id == item.object_id] if len(target_info) == 0: msg = f"Could not find group with ID {item.object_id}" raise ValueError(msg) else: target_group_id = getattr(target_info[0], destination).id - return partial(self._applier_task, group_id=target_group_id, value=value, property_name=item.support) + return partial(self._applier_task, group_id=target_group_id, value=value, property_name=item.object_type) + + def _crawler_task(self, group: iam.Group, property_name: str): + return Permissions( + object_id=group.id, + object_type=property_name, + raw=json.dumps([e.as_dict() for e in getattr(group, property_name)]), + ) + + @rate_limited(max_requests=10) + def _applier_task(self, group_id: str, value: list[iam.ComplexValue], property_name: str): + operations = [iam.Patch(op=iam.PatchOp.ADD, path=property_name, value=[e.as_dict() for e in value])] + schemas = [iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP] + self._ws.groups.patch(id=group_id, operations=operations, schemas=schemas) diff --git a/src/databricks/labs/ucx/support/secrets.py b/src/databricks/labs/ucx/workspace_access/secrets.py similarity index 79% rename from src/databricks/labs/ucx/support/secrets.py rename to src/databricks/labs/ucx/workspace_access/secrets.py index 24c9b436d4..01d1d9cbca 100644 --- a/src/databricks/labs/ucx/support/secrets.py +++ b/src/databricks/labs/ucx/workspace_access/secrets.py @@ -6,32 +6,36 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.service import iam, workspace -from databricks.labs.ucx.inventory.types import Destination, PermissionsInventoryItem -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.providers.mixins.hardening import rate_limited -from databricks.labs.ucx.support.base import BaseSupport +from databricks.labs.ucx.mixins.hardening import rate_limited +from databricks.labs.ucx.workspace_access.base import ( + Applier, + Crawler, + Destination, + Permissions, +) +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState -class SecretScopesSupport(BaseSupport): +class SecretScopesSupport(Crawler, Applier): def __init__(self, ws: WorkspaceClient): - super().__init__(ws=ws) + self._ws = ws def get_crawler_tasks(self): scopes = self._ws.secrets.list_scopes() def _crawler_task(scope: workspace.SecretScope): acl_items = self._ws.secrets.list_acls(scope.name) - return PermissionsInventoryItem( + return Permissions( object_id=scope.name, - support="secrets", - raw_object_permissions=json.dumps([item.as_dict() for item in acl_items]), + object_type="secrets", + raw=json.dumps([item.as_dict() for item in acl_items]), ) for scope in scopes: yield partial(_crawler_task, scope) - def is_item_relevant(self, item: PermissionsInventoryItem, migration_state: GroupMigrationState) -> bool: - acls = [workspace.AclItem.from_dict(acl) for acl in json.loads(item.raw_object_permissions)] + def is_item_relevant(self, item: Permissions, migration_state: GroupMigrationState) -> bool: + acls = [workspace.AclItem.from_dict(acl) for acl in json.loads(item.raw)] mentioned_groups = [acl.principal for acl in acls] return any(g in mentioned_groups for g in [info.workspace.display_name for info in migration_state.groups]) @@ -72,9 +76,9 @@ def _rate_limited_put_acl(self, object_id: str, principal: str, permission: work self._inflight_check(scope_name=object_id, group_name=principal, expected_permission=permission) def _get_apply_task( - self, item: PermissionsInventoryItem, migration_state: GroupMigrationState, destination: Destination + self, item: Permissions, migration_state: GroupMigrationState, destination: Destination ) -> partial: - acls = [workspace.AclItem.from_dict(acl) for acl in json.loads(item.raw_object_permissions)] + acls = [workspace.AclItem.from_dict(acl) for acl in json.loads(item.raw)] new_acls = [] for acl in acls: diff --git a/src/databricks/labs/ucx/inventory/verification.py b/src/databricks/labs/ucx/workspace_access/verification.py similarity index 95% rename from src/databricks/labs/ucx/inventory/verification.py rename to src/databricks/labs/ucx/workspace_access/verification.py index 00b7be758c..f8fb379dfa 100644 --- a/src/databricks/labs/ucx/inventory/verification.py +++ b/src/databricks/labs/ucx/workspace_access/verification.py @@ -2,8 +2,8 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.support.secrets import SecretScopesSupport +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState +from databricks.labs.ucx.workspace_access.secrets import SecretScopesSupport class VerificationManager: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d0e6da2c58..dfb6d38ae9 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -8,16 +8,14 @@ from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.core import Config -from databricks.labs.ucx.providers.mixins.fixtures import * # noqa: F403 -from databricks.labs.ucx.providers.mixins.sql import StatementExecutionExt -from databricks.labs.ucx.utils import ThreadedExecution +from databricks.labs.ucx.mixins.fixtures import * # noqa: F403 +from databricks.labs.ucx.mixins.sql import StatementExecutionExt logging.getLogger("tests").setLevel("DEBUG") logging.getLogger("databricks.labs.ucx").setLevel("DEBUG") logger = logging.getLogger(__name__) -Threader = partial(ThreadedExecution, num_threads=20) load_debug_env_if_runs_from_ide("ucws") # noqa: F405 diff --git a/tests/integration/test_assessment.py b/tests/integration/test_assessment.py index 327cd2cf50..f21f98b290 100644 --- a/tests/integration/test_assessment.py +++ b/tests/integration/test_assessment.py @@ -1,6 +1,6 @@ import pytest -from databricks.labs.ucx.toolkits.assessment import AssessmentToolkit +from databricks.labs.ucx.assessment.assessment import AssessmentToolkit def test_table_inventory(ws, make_catalog, make_schema): diff --git a/tests/integration/test_dashboards.py b/tests/integration/test_dashboards.py index 7e6dce7b27..9e12767faf 100644 --- a/tests/integration/test_dashboards.py +++ b/tests/integration/test_dashboards.py @@ -4,7 +4,7 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.service.sql import AccessControl, ObjectTypePlural, PermissionLevel -from databricks.labs.ucx.providers.mixins.redash import ( +from databricks.labs.ucx.mixins.redash import ( DashboardWidgetsAPI, QueryVisualizationsExt, VizColumn, diff --git a/tests/integration/test_fixtures.py b/tests/integration/test_fixtures.py index 167c54c55b..cefa9b389b 100644 --- a/tests/integration/test_fixtures.py +++ b/tests/integration/test_fixtures.py @@ -3,7 +3,7 @@ from databricks.sdk.service.workspace import AclPermission -from databricks.labs.ucx.providers.mixins.fixtures import * # noqa: F403 +from databricks.labs.ucx.mixins.fixtures import * # noqa: F403 load_debug_env_if_runs_from_ide("ucws") # noqa: F405 diff --git a/tests/integration/test_group.py b/tests/integration/test_groups.py similarity index 51% rename from tests/integration/test_group.py rename to tests/integration/test_groups.py index 529aec427f..d6c987ac49 100644 --- a/tests/integration/test_group.py +++ b/tests/integration/test_groups.py @@ -1,7 +1,22 @@ from databricks.sdk import WorkspaceClient from databricks.labs.ucx.config import GroupsConfig -from databricks.labs.ucx.managers.group import GroupLevel, GroupManager +from databricks.labs.ucx.workspace_access.groups import GroupManager + + +def test_prepare_environment(ws, make_ucx_group): + ws_group, acc_group = make_ucx_group() + + group_manager = GroupManager(ws, GroupsConfig(selected=[ws_group.display_name])) + group_manager.prepare_groups_in_environment() + + group_migration_state = group_manager.migration_groups_provider + for _info in group_migration_state.groups: + _ws = ws.groups.get(id=_info.workspace.id) + _backup = ws.groups.get(id=_info.backup.id) + _ws_members = sorted([m.value for m in _ws.members]) + _backup_members = sorted([m.value for m in _backup.members]) + assert _ws_members == _backup_members def test_group_listing(ws: WorkspaceClient, make_ucx_group): @@ -14,5 +29,5 @@ def test_group_listing(ws: WorkspaceClient, make_ucx_group): def test_id_validity(ws: WorkspaceClient, make_ucx_group): ws_group, acc_group = make_ucx_group() manager = GroupManager(ws, GroupsConfig(selected=[ws_group.display_name])) - assert ws_group.id == manager._get_group(ws_group.display_name, GroupLevel.WORKSPACE).id - assert acc_group.id == manager._get_group(acc_group.display_name, GroupLevel.ACCOUNT).id + assert ws_group.id == manager._get_group(ws_group.display_name, "workspace").id + assert acc_group.id == manager._get_group(acc_group.display_name, "account").id diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index 31137b003f..2ea3bda57a 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -13,11 +13,10 @@ from databricks.sdk.service.workspace import ImportFormat from databricks.labs.ucx.config import GroupsConfig, MigrationConfig, TaclConfig +from databricks.labs.ucx.hive_metastore.grants import Grant +from databricks.labs.ucx.hive_metastore.tables import Table from databricks.labs.ucx.install import Installer -from databricks.labs.ucx.inventory.types import RequestObjectType -from databricks.labs.ucx.providers.mixins.compute import CommandExecutor -from databricks.labs.ucx.tacl.grants import Grant -from databricks.labs.ucx.tacl.tables import Table +from databricks.labs.ucx.mixins.compute import CommandExecutor logger = logging.getLogger(__name__) @@ -91,7 +90,7 @@ def test_sql_backend_works(ws, wsfs_wheel): commands.install_notebook_library(f"/Workspace{wsfs_wheel}") database_names = commands.run( """ - from databricks.labs.ucx.tacl._internal import RuntimeBackend + from databricks.labs.ucx.framework.crawlers import RuntimeBackend backend = RuntimeBackend() return backend.fetch("SHOW DATABASES") """ @@ -215,7 +214,7 @@ def test_toolkit_notebook( permission_level=random.choice([PermissionLevel.CAN_USE]), group_name=ws_group_a.display_name, ) - cpp_src = ws.permissions.get(RequestObjectType.CLUSTER_POLICIES, cluster_policy.policy_id) + cpp_src = ws.permissions.get("cluster-policies", cluster_policy.policy_id) cluster_policy_src_permissions = sorted( [_ for _ in cpp_src.access_control_list if _.group_name == ws_group_a.display_name], key=lambda p: p.group_name, @@ -228,7 +227,7 @@ def test_toolkit_notebook( ), group_name=ws_group_b.display_name, ) - jp_src = ws.permissions.get(RequestObjectType.JOBS, job.job_id) + jp_src = ws.permissions.get("jobs", job.job_id) job_src_permissions = sorted( [_ for _ in jp_src.access_control_list if _.group_name == ws_group_b.display_name], key=lambda p: p.group_name, @@ -334,7 +333,7 @@ def test_toolkit_notebook( logger.info("validating permissions") - cp_dst = ws.permissions.get(RequestObjectType.CLUSTER_POLICIES, cluster_policy.policy_id) + cp_dst = ws.permissions.get("cluster-policies", cluster_policy.policy_id) cluster_policy_dst_permissions = sorted( [_ for _ in cp_dst.access_control_list if _.group_name == ws_group_a.display_name], key=lambda p: p.group_name, @@ -346,17 +345,17 @@ def test_toolkit_notebook( s.all_permissions for s in cluster_policy_src_permissions ], "Target permissions were not applied correctly for cluster policies" - jp_dst = ws.permissions.get(RequestObjectType.JOBS, job.job_id) + jp_dst = ws.permissions.get("jobs", job.job_id) job_dst_permissions = sorted( [_ for _ in jp_dst.access_control_list if _.group_name == ws_group_b.display_name], key=lambda p: p.group_name, ) assert len(job_dst_permissions) == len( job_src_permissions - ), f"Target permissions were not applied correctly for {RequestObjectType.JOBS}/{job.job_id}" + ), f"Target permissions were not applied correctly for jobs/{job.job_id}" assert [t.all_permissions for t in job_dst_permissions] == [ s.all_permissions for s in job_src_permissions - ], f"Target permissions were not applied correctly for {RequestObjectType.JOBS}/{job.job_id}" + ], f"Target permissions were not applied correctly for jobs/{job.job_id}" logger.info("validating tacl") diff --git a/tests/integration/test_permissions.py b/tests/integration/test_permissions.py deleted file mode 100644 index 8d1692b49f..0000000000 --- a/tests/integration/test_permissions.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from databricks.labs.ucx.inventory.permissions_inventory import ( - PermissionsInventoryTable, -) -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.tacl._internal import StatementExecutionBackend - - -def test_permissions_save_and_load(ws, make_schema): - schema = make_schema().split(".")[-1] - backend = StatementExecutionBackend(ws, os.environ["TEST_DEFAULT_WAREHOUSE_ID"]) - pi = PermissionsInventoryTable(backend, schema) - - saved = [ - PermissionsInventoryItem(object_id="abc", support="bcd", raw_object_permissions="def"), - PermissionsInventoryItem(object_id="efg", support="fgh", raw_object_permissions="ghi"), - ] - - pi.save(saved) - loaded = pi.load_all() - - assert saved == loaded diff --git a/tests/integration/test_permissions_manager.py b/tests/integration/test_permissions_manager.py new file mode 100644 index 0000000000..272b6d6e46 --- /dev/null +++ b/tests/integration/test_permissions_manager.py @@ -0,0 +1,21 @@ +import os + +from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend +from databricks.labs.ucx.workspace_access.base import Permissions +from databricks.labs.ucx.workspace_access.manager import PermissionManager + + +def test_permissions_save_and_load(ws, make_schema): + schema = make_schema().split(".")[-1] + backend = StatementExecutionBackend(ws, os.environ["TEST_DEFAULT_WAREHOUSE_ID"]) + pi = PermissionManager(backend, schema, [], {}) + + saved = [ + Permissions(object_id="abc", object_type="bcd", raw="def"), + Permissions(object_id="efg", object_type="fgh", raw="ghi"), + ] + + pi._save(saved) + loaded = pi._load_all() + + assert saved == loaded diff --git a/tests/integration/test_setup.py b/tests/integration/test_setup.py index fee02461cc..70c1d671cd 100644 --- a/tests/integration/test_setup.py +++ b/tests/integration/test_setup.py @@ -5,7 +5,7 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.service.iam import ComplexValue -from databricks.labs.ucx.utils import ThreadedExecution +from databricks.labs.ucx.framework.parallel import ThreadedExecution logger = logging.getLogger(__name__) Threader = partial(ThreadedExecution, num_threads=40) diff --git a/tests/integration/test_tacls.py b/tests/integration/test_tacls.py index 0422677d1b..2bc4da4939 100644 --- a/tests/integration/test_tacls.py +++ b/tests/integration/test_tacls.py @@ -3,7 +3,7 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.ucx.toolkits.table_acls import TaclToolkit +from databricks.labs.ucx.hive_metastore.table_acls import TaclToolkit logger = logging.getLogger(__name__) diff --git a/tests/integration/test_e2e.py b/tests/integration/test_workspace_access.py similarity index 86% rename from tests/integration/test_e2e.py rename to tests/integration/test_workspace_access.py index 0cc1a0e74b..fd8ac367bc 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_workspace_access.py @@ -12,13 +12,12 @@ MigrationConfig, TaclConfig, ) -from databricks.labs.ucx.inventory.types import RequestObjectType -from databricks.labs.ucx.toolkits.group_migration import GroupMigrationToolkit +from databricks.labs.ucx.workspace_access import GroupMigrationToolkit logger = logging.getLogger(__name__) -def test_e2e( +def test_workspace_access_e2e( ws: WorkspaceClient, make_schema, make_ucx_group, @@ -56,7 +55,7 @@ def test_e2e( permission_level=random.choice([PermissionLevel.CAN_ATTACH_TO, PermissionLevel.CAN_MANAGE]), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.INSTANCE_POOLS, pool.instance_pool_id)) + to_verify.add(("instance-pools", pool.instance_pool_id)) cluster = make_cluster(instance_pool_id=os.environ["TEST_INSTANCE_POOL_ID"], single_node=True) make_cluster_permissions( @@ -66,7 +65,7 @@ def test_e2e( ), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.CLUSTERS, cluster.cluster_id)) + to_verify.add(("clusters", cluster.cluster_id)) cluster_policy = make_cluster_policy() make_cluster_policy_permissions( @@ -74,7 +73,7 @@ def test_e2e( permission_level=random.choice([PermissionLevel.CAN_USE]), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.CLUSTER_POLICIES, cluster_policy.policy_id)) + to_verify.add(("cluster-policies", cluster_policy.policy_id)) model = make_model() make_registered_model_permissions( @@ -89,7 +88,7 @@ def test_e2e( ), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.REGISTERED_MODELS, model.id)) + to_verify.add(("registered-models", model.id)) experiment = make_experiment() make_experiment_permissions( @@ -99,7 +98,7 @@ def test_e2e( ), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.EXPERIMENTS, experiment.experiment_id)) + to_verify.add(("experiments", experiment.experiment_id)) directory = make_directory() make_directory_permissions( @@ -109,7 +108,7 @@ def test_e2e( ), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.DIRECTORIES, ws.workspace.get_status(directory).object_id)) + to_verify.add(("directories", ws.workspace.get_status(directory).object_id)) notebook = make_notebook(path=f"{directory}/sample.py") make_notebook_permissions( @@ -119,7 +118,7 @@ def test_e2e( ), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.NOTEBOOKS, ws.workspace.get_status(notebook).object_id)) + to_verify.add(("notebooks", ws.workspace.get_status(notebook).object_id)) job = make_job() make_job_permissions( @@ -129,7 +128,7 @@ def test_e2e( ), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.JOBS, job.job_id)) + to_verify.add(("jobs", job.job_id)) pipeline = make_pipeline() make_pipeline_permissions( @@ -137,7 +136,7 @@ def test_e2e( permission_level=random.choice([PermissionLevel.CAN_VIEW, PermissionLevel.CAN_RUN, PermissionLevel.CAN_MANAGE]), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.PIPELINES, pipeline.pipeline_id)) + to_verify.add(("pipelines", pipeline.pipeline_id)) scope = make_secret_scope() make_secret_scope_acl(scope=scope, principal=ws_group.display_name, permission=workspace.AclPermission.WRITE) @@ -148,7 +147,7 @@ def test_e2e( permission_level=PermissionLevel.CAN_USE, group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.AUTHORIZATION, "tokens")) + to_verify.add(("authorization", "tokens")) warehouse = make_warehouse() make_warehouse_permissions( @@ -156,7 +155,7 @@ def test_e2e( permission_level=random.choice([PermissionLevel.CAN_USE, PermissionLevel.CAN_MANAGE]), group_name=ws_group.display_name, ) - to_verify.add((RequestObjectType.SQL_WAREHOUSES, warehouse.id)) + to_verify.add(("sql/warehouses", warehouse.id)) config = MigrationConfig( connect=ConnectConfig.from_databricks_config(ws.config), @@ -165,6 +164,7 @@ def test_e2e( workspace_start_path=directory, tacl=TaclConfig(auto=True), log_level="DEBUG", + num_threads=8, ) warehouse_id = os.environ["TEST_DEFAULT_WAREHOUSE_ID"] diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index e581930fef..ca49282964 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,6 +1,6 @@ import logging -from databricks.labs.ucx.logger import _install +from databricks.labs.ucx.framework.logger import _install _install() diff --git a/src/databricks/labs/ucx/providers/mixins/__init__.py b/tests/unit/assessment/__init__.py similarity index 100% rename from src/databricks/labs/ucx/providers/mixins/__init__.py rename to tests/unit/assessment/__init__.py diff --git a/tests/unit/test_assessment.py b/tests/unit/assessment/test_assessment.py similarity index 84% rename from tests/unit/test_assessment.py rename to tests/unit/assessment/test_assessment.py index 47978e867a..69defba277 100644 --- a/tests/unit/test_assessment.py +++ b/tests/unit/assessment/test_assessment.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from databricks.labs.ucx.toolkits.assessment import AssessmentToolkit +from databricks.labs.ucx.assessment.assessment import AssessmentToolkit def test_get_command(): diff --git a/src/databricks/labs/ucx/support/__init__.py b/tests/unit/framework/__init__.py similarity index 100% rename from src/databricks/labs/ucx/support/__init__.py rename to tests/unit/framework/__init__.py diff --git a/tests/unit/mocks.py b/tests/unit/framework/mocks.py similarity index 95% rename from tests/unit/mocks.py rename to tests/unit/framework/mocks.py index e0f054347b..645372a4a8 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/framework/mocks.py @@ -2,7 +2,7 @@ import re from collections.abc import Iterator -from databricks.labs.ucx.tacl._internal import SqlBackend +from databricks.labs.ucx.framework.crawlers import SqlBackend logger = logging.getLogger(__name__) diff --git a/tests/unit/test_crawler_base.py b/tests/unit/framework/test_crawlers.py similarity index 96% rename from tests/unit/test_crawler_base.py rename to tests/unit/framework/test_crawlers.py index a73920677b..8afb4c9599 100644 --- a/tests/unit/test_crawler_base.py +++ b/tests/unit/framework/test_crawlers.py @@ -2,9 +2,9 @@ import pytest -from databricks.labs.ucx.tacl._internal import CrawlerBase +from databricks.labs.ucx.framework.crawlers import CrawlerBase -from .mocks import MockBackend +from ..framework.mocks import MockBackend @dataclass diff --git a/tests/unit/test_logger.py b/tests/unit/framework/test_logger.py similarity index 100% rename from tests/unit/test_logger.py rename to tests/unit/framework/test_logger.py diff --git a/src/databricks/labs/ucx/tacl/__init__.py b/tests/unit/hive_metastore/__init__.py similarity index 100% rename from src/databricks/labs/ucx/tacl/__init__.py rename to tests/unit/hive_metastore/__init__.py diff --git a/tests/unit/test_grants.py b/tests/unit/hive_metastore/test_grants.py similarity index 96% rename from tests/unit/test_grants.py rename to tests/unit/hive_metastore/test_grants.py index 26ec7c4e07..6757e73b80 100644 --- a/tests/unit/test_grants.py +++ b/tests/unit/hive_metastore/test_grants.py @@ -1,10 +1,10 @@ import pytest -from databricks.labs.ucx.providers.mixins.sql import Row -from databricks.labs.ucx.tacl.grants import Grant, GrantsCrawler -from databricks.labs.ucx.tacl.tables import TablesCrawler +from databricks.labs.ucx.hive_metastore.grants import Grant, GrantsCrawler +from databricks.labs.ucx.hive_metastore.tables import TablesCrawler +from databricks.labs.ucx.mixins.sql import Row -from .mocks import MockBackend +from ..framework.mocks import MockBackend def test_type_and_key_table(): diff --git a/tests/unit/test_tables.py b/tests/unit/hive_metastore/test_tables.py similarity index 93% rename from tests/unit/test_tables.py rename to tests/unit/hive_metastore/test_tables.py index 92103aa590..4cbd8d6003 100644 --- a/tests/unit/test_tables.py +++ b/tests/unit/hive_metastore/test_tables.py @@ -1,7 +1,8 @@ import pytest -from databricks.labs.ucx.tacl._internal import SqlBackend -from databricks.labs.ucx.tacl.tables import Table, TablesCrawler +from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler + +from ..framework.mocks import MockBackend def test_is_delta_true(): @@ -84,5 +85,5 @@ def test_uc_sql(table, query): def test_tables_crawler_inventory_table(): - tc = TablesCrawler(SqlBackend, "main", "default") + tc = TablesCrawler(MockBackend(), "main", "default") assert tc._table == "tables" diff --git a/src/databricks/labs/ucx/toolkits/__init__.py b/tests/unit/mixins/__init__.py similarity index 100% rename from src/databricks/labs/ucx/toolkits/__init__.py rename to tests/unit/mixins/__init__.py diff --git a/tests/unit/test_ratelimit.py b/tests/unit/mixins/test_ratelimit.py similarity index 85% rename from tests/unit/test_ratelimit.py rename to tests/unit/mixins/test_ratelimit.py index 806180e553..aedb42b4dd 100644 --- a/tests/unit/test_ratelimit.py +++ b/tests/unit/mixins/test_ratelimit.py @@ -1,6 +1,6 @@ import pytest -from databricks.labs.ucx.providers.mixins.hardening import rate_limited +from databricks.labs.ucx.mixins.hardening import rate_limited def test_ratelimiting(mocker): diff --git a/tests/unit/support/test_group_level.py b/tests/unit/support/test_group_level.py deleted file mode 100644 index c8c4b11e3c..0000000000 --- a/tests/unit/support/test_group_level.py +++ /dev/null @@ -1,93 +0,0 @@ -import json -from unittest.mock import MagicMock - -import pytest -from databricks.sdk.service import iam - -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.support.group_level import ScimSupport - - -def test_scim_crawler(): - ws = MagicMock() - ws.groups.list.return_value = [ - iam.Group( - id="1", - display_name="group1", - roles=[], # verify that empty roles and entitlements are not returned - ), - iam.Group( - id="2", - display_name="group2", - roles=[iam.ComplexValue(value="role1")], - entitlements=[iam.ComplexValue(value="entitlement1")], - ), - iam.Group( - id="3", - display_name="group3", - roles=[iam.ComplexValue(value="role1"), iam.ComplexValue(value="role2")], - entitlements=[], - ), - ] - sup = ScimSupport(ws=ws) - tasks = list(sup.get_crawler_tasks()) - assert len(tasks) == 3 - ws.groups.list.assert_called_once() - for task in tasks: - item = task() - if item.object_id == "1": - assert item is None - else: - assert item.object_id in ["2", "3"] - assert item.support in ["roles", "entitlements"] - assert item.raw_object_permissions is not None - - -def test_scim_apply(migration_state): - ws = MagicMock() - sup = ScimSupport(ws=ws) - sample_permissions = [iam.ComplexValue(value="role1"), iam.ComplexValue(value="role2")] - item = PermissionsInventoryItem( - object_id="test-ws", - support="roles", - raw_object_permissions=json.dumps([p.as_dict() for p in sample_permissions]), - ) - - task = sup.get_apply_task(item, migration_state, "backup") - task() - ws.groups.patch.assert_called_once_with( - id="test-backup", - operations=[iam.Patch(op=iam.PatchOp.ADD, path="roles", value=[p.as_dict() for p in sample_permissions])], - schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], - ) - - -def test_no_group_in_migration_state(migration_state): - ws = MagicMock() - sup = ScimSupport(ws=ws) - sample_permissions = [iam.ComplexValue(value="role1"), iam.ComplexValue(value="role2")] - item = PermissionsInventoryItem( - object_id="test-non-existent", - support="roles", - raw_object_permissions=json.dumps([p.as_dict() for p in sample_permissions]), - ) - with pytest.raises(ValueError): - sup._get_apply_task(item, migration_state, "backup") - - -def test_non_relevant(migration_state): - ws = MagicMock() - sup = ScimSupport(ws=ws) - sample_permissions = [iam.ComplexValue(value="role1")] - relevant_item = PermissionsInventoryItem( - object_id="test-ws", - support="roles", - raw_object_permissions=json.dumps([p.as_dict() for p in sample_permissions]), - ) - irrelevant_item = PermissionsInventoryItem( - object_id="something-non-relevant", - support="roles", - raw_object_permissions=json.dumps([p.as_dict() for p in sample_permissions]), - ) - assert sup.is_item_relevant(relevant_item, migration_state) - assert not sup.is_item_relevant(irrelevant_item, migration_state) diff --git a/tests/unit/support/test_impl.py b/tests/unit/support/test_impl.py deleted file mode 100644 index f7e931bd76..0000000000 --- a/tests/unit/support/test_impl.py +++ /dev/null @@ -1,29 +0,0 @@ -from unittest.mock import MagicMock - -from databricks.labs.ucx.support.impl import SupportsProvider - - -def test_supports_provider(): - provider = SupportsProvider(ws=MagicMock(), num_threads=1, workspace_start_path="/") - assert provider.supports.keys() == { - "entitlements", - "roles", - "clusters", - "cluster-policies", - "instance-pools", - "sql/warehouses", - "jobs", - "pipelines", - "experiments", - "registered-models", - "tokens", - "passwords", - "notebooks", - "files", - "directories", - "repos", - "alerts", - "queries", - "dashboards", - "secrets", - } diff --git a/tests/unit/support/test_listing.py b/tests/unit/support/test_listing.py deleted file mode 100644 index 293506cac1..0000000000 --- a/tests/unit/support/test_listing.py +++ /dev/null @@ -1,85 +0,0 @@ -import datetime as dt -from unittest.mock import MagicMock, patch - -from databricks.sdk.service import ml, workspace - -from databricks.labs.ucx.inventory.types import RequestObjectType -from databricks.labs.ucx.support.listing import ( - WorkspaceListing, - experiments_listing, - logger, - models_listing, - workspace_listing, -) -from databricks.labs.ucx.support.permissions import listing_wrapper - - -def test_logging_calls(): - workspace_listing = WorkspaceListing(ws=MagicMock(), num_threads=1) - workspace_listing.start_time = dt.datetime.now() - workspace_listing._counter = 9 - with patch.object(logger, "info") as mock_info: - workspace_listing._progress_report(None) - mock_info.assert_called_once() - - -def test_models_listing(): - ws = MagicMock() - ws.model_registry.list_models.return_value = [ml.Model(name="test")] - ws.model_registry.get_model.return_value = ml.GetModelResponse( - registered_model_databricks=ml.ModelDatabricks( - id="some-id", - name="test", - ) - ) - - wrapped = listing_wrapper(models_listing(ws), id_attribute="id", object_type=RequestObjectType.REGISTERED_MODELS) - result = list(wrapped()) - assert len(result) == 1 - assert result[0].object_id == "some-id" - assert result[0].request_type == RequestObjectType.REGISTERED_MODELS - - -def test_experiment_listing(): - ws = MagicMock() - ws.experiments.list_experiments.return_value = [ - ml.Experiment(experiment_id="test"), - ml.Experiment(experiment_id="test2", tags=[ml.ExperimentTag(key="whatever", value="SOMETHING")]), - ml.Experiment(experiment_id="test3", tags=[ml.ExperimentTag(key="mlflow.experimentType", value="NOTEBOOK")]), - ml.Experiment( - experiment_id="test4", tags=[ml.ExperimentTag(key="mlflow.experiment.sourceType", value="REPO_NOTEBOOK")] - ), - ] - wrapped = listing_wrapper( - experiments_listing(ws), id_attribute="experiment_id", object_type=RequestObjectType.EXPERIMENTS - ) - results = list(wrapped()) - assert len(results) == 2 - for res in results: - assert res.request_type == RequestObjectType.EXPERIMENTS - assert res.object_id in ["test", "test2"] - - -def test_workspace_listing(): - listing = MagicMock(spec=WorkspaceListing) - listing.walk.return_value = [ - workspace.ObjectInfo(object_id=1, object_type=workspace.ObjectType.NOTEBOOK), - workspace.ObjectInfo(object_id=2, object_type=workspace.ObjectType.DIRECTORY), - workspace.ObjectInfo(object_id=3, object_type=workspace.ObjectType.LIBRARY), - workspace.ObjectInfo(object_id=4, object_type=workspace.ObjectType.REPO), - workspace.ObjectInfo(object_id=5, object_type=workspace.ObjectType.FILE), - workspace.ObjectInfo(object_id=6, object_type=None), # MLflow Experiment - ] - - with patch("databricks.labs.ucx.support.listing.WorkspaceListing", return_value=listing): - results = workspace_listing(ws=MagicMock())() - assert len(list(results)) == 4 - listing.walk.assert_called_once() - for res in results: - assert res.request_type in [ - RequestObjectType.NOTEBOOKS, - RequestObjectType.DIRECTORIES, - RequestObjectType.REPOS, - RequestObjectType.FILES, - ] - assert res.object_id in [1, 2, 4, 5] diff --git a/tests/unit/test_generic.py b/tests/unit/test_generic.py deleted file mode 100644 index 3f9cb12e9d..0000000000 --- a/tests/unit/test_generic.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -from databricks.labs.ucx.generic import StrEnum - - -def test_error(): - with pytest.raises(TypeError): - - class InvalidEnum(StrEnum): - A = 1 - - -def test_generate(): - class Sample(StrEnum): - A = "a" - B = "b" - - assert Sample._generate_next_value_("C", 3) == "C" diff --git a/tests/unit/test_permissions_inventory.py b/tests/unit/test_permissions_inventory.py deleted file mode 100644 index e5554d6ec2..0000000000 --- a/tests/unit/test_permissions_inventory.py +++ /dev/null @@ -1,55 +0,0 @@ -from databricks.labs.ucx.inventory.permissions_inventory import ( - PermissionsInventoryTable, -) -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.providers.mixins.sql import Row - -from .mocks import MockBackend - - -def test_inventory_table_manager_init(): - b = MockBackend() - pi = PermissionsInventoryTable(b, "test_database") - - assert pi._full_name == "hive_metastore.test_database.permissions" - - -def test_cleanup(): - b = MockBackend() - pi = PermissionsInventoryTable(b, "test_database") - - pi.cleanup() - - assert "DROP TABLE IF EXISTS hive_metastore.test_database.permissions" == b.queries[0] - - -def test_save(): - b = MockBackend() - pi = PermissionsInventoryTable(b, "test_database") - - pi.save([PermissionsInventoryItem("object1", "clusters", "test acl")]) - - assert ( - "INSERT INTO hive_metastore.test_database.permissions (object_id, support, " - "raw_object_permissions) VALUES ('object1', 'clusters', 'test acl')" - ) == b.queries[0] - - -def make_row(data, columns): - row = Row(data) - row.__columns__ = columns - return row - - -def test_load_all(): - b = MockBackend( - rows={ - "SELECT": [ - make_row(("object1", "clusters", "test acl"), ["object_id", "support", "raw_object_permissions"]), - ] - } - ) - pi = PermissionsInventoryTable(b, "test_database") - - output = pi.load_all() - assert output[0] == PermissionsInventoryItem("object1", support="clusters", raw_object_permissions="test acl") diff --git a/tests/unit/test_permissions_manager.py b/tests/unit/test_permissions_manager.py deleted file mode 100644 index 8d0bf9700f..0000000000 --- a/tests/unit/test_permissions_manager.py +++ /dev/null @@ -1,82 +0,0 @@ -import json -from unittest import mock -from unittest.mock import MagicMock - -import pytest -from databricks.sdk.service import iam - -from databricks.labs.ucx.inventory.permissions import PermissionManager -from databricks.labs.ucx.inventory.permissions_inventory import ( - PermissionsInventoryTable, -) -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.support.impl import SupportsProvider - - -def test_manager_inventorize(): - sup = SupportsProvider(ws=MagicMock(), num_threads=1, workspace_start_path="/") - pm = PermissionManager( - ws=MagicMock(), permissions_inventory=PermissionsInventoryTable(MagicMock(), "test"), supports_provider=sup - ) - - with mock.patch("databricks.labs.ucx.inventory.permissions.ThreadedExecution.run", MagicMock()) as run_mock: - pm.inventorize_permissions() - run_mock.assert_called_once() - - -def test_manager_apply(): - sup = SupportsProvider(ws=MagicMock(), num_threads=1, workspace_start_path="/") - inventory = MagicMock(spec=PermissionsInventoryTable) - inventory.load_all.return_value = [ - PermissionsInventoryItem( - object_id="test", - support="clusters", - raw_object_permissions=json.dumps( - iam.ObjectPermissions( - object_id="test", - object_type="clusters", - access_control_list=[ - iam.AccessControlResponse( - group_name="test", - all_permissions=[ - iam.Permission(inherited=False, permission_level=iam.PermissionLevel.CAN_USE) - ], - ) - ], - ).as_dict() - ), - ), - PermissionsInventoryItem( - object_id="test2", - support="cluster-policies", - raw_object_permissions=json.dumps( - iam.ObjectPermissions( - object_id="test", - object_type="cluster-policies", - access_control_list=[ - iam.AccessControlResponse( - group_name="test", - all_permissions=[ - iam.Permission(inherited=False, permission_level=iam.PermissionLevel.CAN_USE) - ], - ) - ], - ).as_dict() - ), - ), - ] - pm = PermissionManager(ws=MagicMock(), permissions_inventory=inventory, supports_provider=sup) - with mock.patch("databricks.labs.ucx.inventory.permissions.ThreadedExecution.run", MagicMock()) as run_mock: - pm.apply_group_permissions(migration_state=MagicMock(), destination="backup") - run_mock.assert_called_once() - - -def test_unregistered_support(): - sup = SupportsProvider(ws=MagicMock(), num_threads=1, workspace_start_path="/") - inventory = MagicMock(spec=PermissionsInventoryTable) - inventory.load_all.return_value = [ - PermissionsInventoryItem(object_id="test", support="SOME_NON_EXISTENT", raw_object_permissions="") - ] - pm = PermissionManager(ws=MagicMock(), permissions_inventory=inventory, supports_provider=sup) - with pytest.raises(ValueError): - pm.apply_group_permissions(migration_state=MagicMock(), destination="backup") diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py deleted file mode 100644 index e9423cafd7..0000000000 --- a/tests/unit/test_types.py +++ /dev/null @@ -1,7 +0,0 @@ -from databricks.labs.ucx.inventory.types import RequestObjectType - - -def test_request_object_type(): - typed = RequestObjectType.AUTHORIZATION - assert typed == "authorization" - assert typed.__repr__() == "authorization" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py deleted file mode 100644 index a47f264ce8..0000000000 --- a/tests/unit/test_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -from databricks.labs.ucx.utils import Request, noop - - -def test_req(): - req = Request({"test": "test"}) - assert req.as_dict() == {"test": "test"} - - -def test_noop(): - noop() - assert True diff --git a/tests/unit/support/__init__.py b/tests/unit/workspace_access/__init__.py similarity index 100% rename from tests/unit/support/__init__.py rename to tests/unit/workspace_access/__init__.py diff --git a/tests/unit/support/conftest.py b/tests/unit/workspace_access/conftest.py similarity index 89% rename from tests/unit/support/conftest.py rename to tests/unit/workspace_access/conftest.py index 5375cce750..fbe96332a2 100644 --- a/tests/unit/support/conftest.py +++ b/tests/unit/workspace_access/conftest.py @@ -1,7 +1,7 @@ import pytest from databricks.sdk.service import iam -from databricks.labs.ucx.providers.groups_info import ( +from databricks.labs.ucx.workspace_access.groups import ( GroupMigrationState, MigrationGroupInfo, ) diff --git a/tests/unit/support/test_base.py b/tests/unit/workspace_access/test_base.py similarity index 63% rename from tests/unit/support/test_base.py rename to tests/unit/workspace_access/test_base.py index 665e565243..bf9d108672 100644 --- a/tests/unit/support/test_base.py +++ b/tests/unit/workspace_access/test_base.py @@ -2,18 +2,16 @@ from databricks.sdk.service import iam -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.providers.groups_info import ( +from databricks.labs.ucx.workspace_access.base import Applier, Permissions +from databricks.labs.ucx.workspace_access.groups import ( GroupMigrationState, MigrationGroupInfo, ) -from databricks.labs.ucx.support.base import Applier -from databricks.labs.ucx.utils import noop def test_applier(): class SampleApplier(Applier): - def is_item_relevant(self, item: PermissionsInventoryItem, migration_state: GroupMigrationState) -> bool: + def is_item_relevant(self, item: Permissions, migration_state: GroupMigrationState) -> bool: workspace_groups = [info.workspace.display_name for info in migration_state.groups] return item.object_id in workspace_groups @@ -24,7 +22,7 @@ def test_task(): return partial(test_task) applier = SampleApplier() - positive_item = PermissionsInventoryItem(object_id="test", support="test", raw_object_permissions="test") + positive_item = Permissions(object_id="test", object_type="test", raw="test") migration_state = GroupMigrationState() migration_state.add( group=MigrationGroupInfo( @@ -37,6 +35,6 @@ def test_task(): task = applier.get_apply_task(positive_item, migration_state, "backup") assert task.func.__name__ == "test_task" - negative_item = PermissionsInventoryItem(object_id="not-here", support="test", raw_object_permissions="test") + negative_item = Permissions(object_id="not-here", object_type="test", raw="test") new_task = applier.get_apply_task(negative_item, migration_state, "backup") - assert new_task.func == noop + new_task.func() diff --git a/tests/unit/support/test_permissions.py b/tests/unit/workspace_access/test_generic.py similarity index 59% rename from tests/unit/support/test_permissions.py rename to tests/unit/workspace_access/test_generic.py index f21c7aa520..b85f80088c 100644 --- a/tests/unit/support/test_permissions.py +++ b/tests/unit/workspace_access/test_generic.py @@ -3,16 +3,15 @@ import pytest from databricks.sdk.core import DatabricksError -from databricks.sdk.service import compute, iam +from databricks.sdk.service import compute, iam, ml -from databricks.labs.ucx.inventory.types import ( - PermissionsInventoryItem, - RequestObjectType, -) -from databricks.labs.ucx.support.listing import authorization_listing -from databricks.labs.ucx.support.permissions import ( +from databricks.labs.ucx.workspace_access.generic import ( GenericPermissionsSupport, + Permissions, + authorization_listing, + experiments_listing, listing_wrapper, + models_listing, ) @@ -26,7 +25,7 @@ def test_crawler(): sample_permission = iam.ObjectPermissions( object_id="test", - object_type=str(RequestObjectType.CLUSTERS), + object_type="clusters", access_control_list=[ iam.AccessControlResponse( group_name="test", @@ -40,7 +39,7 @@ def test_crawler(): sup = GenericPermissionsSupport( ws=ws, listings=[ - listing_wrapper(ws.clusters.list, "cluster_id", RequestObjectType.CLUSTERS), + listing_wrapper(ws.clusters.list, "cluster_id", "clusters"), ], ) @@ -51,21 +50,21 @@ def test_crawler(): item = _task() ws.permissions.get.assert_called_once() assert item.object_id == "test" - assert item.support == "clusters" - assert json.loads(item.raw_object_permissions) == sample_permission.as_dict() + assert item.object_type == "clusters" + assert json.loads(item.raw) == sample_permission.as_dict() def test_apply(migration_state): ws = MagicMock() sup = GenericPermissionsSupport(ws=ws, listings=[]) # no listings since only apply is tested - item = PermissionsInventoryItem( + item = Permissions( object_id="test", - support="clusters", - raw_object_permissions=json.dumps( + object_type="clusters", + raw=json.dumps( iam.ObjectPermissions( object_id="test", - object_type=str(RequestObjectType.CLUSTERS), + object_type="clusters", access_control_list=[ iam.AccessControlResponse( group_name="test", @@ -93,17 +92,13 @@ def test_apply(migration_state): ) ] - ws.permissions.update.assert_called_with( - request_object_type=RequestObjectType.CLUSTERS, - request_object_id="test", - access_control_list=expected_acl_payload, - ) + ws.permissions.update.assert_called_with("clusters", "test", access_control_list=expected_acl_payload) def test_relevance(): sup = GenericPermissionsSupport(ws=MagicMock(), listings=[]) # no listings since only apply is tested result = sup.is_item_relevant( - item=PermissionsInventoryItem(object_id="passwords", support="passwords", raw_object_permissions="some-stuff"), + item=Permissions(object_id="passwords", object_type="passwords", raw="some-stuff"), migration_state=MagicMock(), ) assert result is True @@ -113,12 +108,12 @@ def test_safe_get(): ws = MagicMock() ws.permissions.get.side_effect = DatabricksError(error_code="RESOURCE_DOES_NOT_EXIST") sup = GenericPermissionsSupport(ws=ws, listings=[]) - result = sup._safe_get_permissions(ws, RequestObjectType.CLUSTERS, "test") + result = sup._safe_get_permissions("clusters", "test") assert result is None ws.permissions.get.side_effect = DatabricksError(error_code="SOMETHING_UNEXPECTED") with pytest.raises(DatabricksError): - sup._safe_get_permissions(ws, RequestObjectType.CLUSTERS, "test") + sup._safe_get_permissions("clusters", "test") def test_no_permissions(): @@ -132,7 +127,7 @@ def test_no_permissions(): sup = GenericPermissionsSupport( ws=ws, listings=[ - listing_wrapper(ws.clusters.list, "cluster_id", RequestObjectType.CLUSTERS), + listing_wrapper(ws.clusters.list, "cluster_id", "clusters"), ], ) tasks = list(sup.get_crawler_tasks()) @@ -154,12 +149,8 @@ def test_passwords_tokens_crawler(migration_state): ] ws.permissions.get.side_effect = [ - iam.ObjectPermissions( - object_id="passwords", object_type=RequestObjectType.AUTHORIZATION, access_control_list=basic_acl - ), - iam.ObjectPermissions( - object_id="tokens", object_type=RequestObjectType.AUTHORIZATION, access_control_list=basic_acl - ), + iam.ObjectPermissions(object_id="passwords", object_type="authorization", access_control_list=basic_acl), + iam.ObjectPermissions(object_id="tokens", object_type="authorization", access_control_list=basic_acl), ] sup = GenericPermissionsSupport(ws=ws, listings=[authorization_listing()]) @@ -168,17 +159,49 @@ def test_passwords_tokens_crawler(migration_state): auth_items = [task() for task in tasks] for item in auth_items: assert item.object_id in ["tokens", "passwords"] - assert item.support in ["tokens", "passwords"] + assert item.object_type == "authorization" applier = sup.get_apply_task(item, migration_state, "backup") - new_acl = sup._prepare_new_acl( - permissions=iam.ObjectPermissions.from_dict(json.loads(item.raw_object_permissions)), - migration_state=migration_state, - destination="backup", - ) applier() ws.permissions.update.assert_called_once_with( - request_object_type=RequestObjectType.AUTHORIZATION, - request_object_id=item.object_id, - access_control_list=new_acl, + item.object_type, + item.object_id, + access_control_list=[ + iam.AccessControlRequest(group_name="db-temp-test", permission_level=iam.PermissionLevel.CAN_USE) + ], ) ws.permissions.update.reset_mock() + + +def test_models_listing(): + ws = MagicMock() + ws.model_registry.list_models.return_value = [ml.Model(name="test")] + ws.model_registry.get_model.return_value = ml.GetModelResponse( + registered_model_databricks=ml.ModelDatabricks( + id="some-id", + name="test", + ) + ) + + wrapped = listing_wrapper(models_listing(ws), id_attribute="id", object_type="registered-models") + result = list(wrapped()) + assert len(result) == 1 + assert result[0].object_id == "some-id" + assert result[0].request_type == "registered-models" + + +def test_experiment_listing(): + ws = MagicMock() + ws.experiments.list_experiments.return_value = [ + ml.Experiment(experiment_id="test"), + ml.Experiment(experiment_id="test2", tags=[ml.ExperimentTag(key="whatever", value="SOMETHING")]), + ml.Experiment(experiment_id="test3", tags=[ml.ExperimentTag(key="mlflow.experimentType", value="NOTEBOOK")]), + ml.Experiment( + experiment_id="test4", tags=[ml.ExperimentTag(key="mlflow.experiment.sourceType", value="REPO_NOTEBOOK")] + ), + ] + wrapped = listing_wrapper(experiments_listing(ws), id_attribute="experiment_id", object_type="experiments") + results = list(wrapped()) + assert len(results) == 2 + for res in results: + assert res.request_type == "experiments" + assert res.object_id in ["test", "test2"] diff --git a/tests/unit/test_group.py b/tests/unit/workspace_access/test_groups.py similarity index 77% rename from tests/unit/test_group.py rename to tests/unit/workspace_access/test_groups.py index 0dfb510b92..70ddc6d4f9 100644 --- a/tests/unit/test_group.py +++ b/tests/unit/workspace_access/test_groups.py @@ -1,11 +1,98 @@ -from unittest.mock import Mock +import json +from unittest.mock import MagicMock, Mock import pytest +from databricks.sdk.service import iam from databricks.sdk.service.iam import Group, ResourceMeta from databricks.labs.ucx.config import GroupsConfig -from databricks.labs.ucx.managers.group import GroupManager -from databricks.labs.ucx.providers.groups_info import MigrationGroupInfo +from databricks.labs.ucx.workspace_access.groups import GroupManager, MigrationGroupInfo +from databricks.labs.ucx.workspace_access.scim import Permissions, ScimSupport + + +def test_scim_crawler(): + ws = MagicMock() + ws.groups.list.return_value = [ + iam.Group( + id="1", + display_name="group1", + roles=[], # verify that empty roles and entitlements are not returned + ), + iam.Group( + id="2", + display_name="group2", + roles=[iam.ComplexValue(value="role1")], + entitlements=[iam.ComplexValue(value="entitlement1")], + ), + iam.Group( + id="3", + display_name="group3", + roles=[iam.ComplexValue(value="role1"), iam.ComplexValue(value="role2")], + entitlements=[], + ), + ] + sup = ScimSupport(ws=ws) + tasks = list(sup.get_crawler_tasks()) + assert len(tasks) == 3 + ws.groups.list.assert_called_once() + for task in tasks: + item = task() + if item.object_id == "1": + assert item is None + else: + assert item.object_id in ["2", "3"] + assert item.object_type in ["roles", "entitlements"] + assert item.raw is not None + + +def test_scim_apply(migration_state): + ws = MagicMock() + sup = ScimSupport(ws=ws) + sample_permissions = [iam.ComplexValue(value="role1"), iam.ComplexValue(value="role2")] + item = Permissions( + object_id="test-ws", + object_type="roles", + raw=json.dumps([p.as_dict() for p in sample_permissions]), + ) + + task = sup.get_apply_task(item, migration_state, "backup") + task() + ws.groups.patch.assert_called_once_with( + id="test-backup", + operations=[iam.Patch(op=iam.PatchOp.ADD, path="roles", value=[p.as_dict() for p in sample_permissions])], + schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], + ) + + +def test_no_group_in_migration_state(migration_state): + ws = MagicMock() + sup = ScimSupport(ws=ws) + sample_permissions = [iam.ComplexValue(value="role1"), iam.ComplexValue(value="role2")] + item = Permissions( + object_id="test-non-existent", + object_type="roles", + raw=json.dumps([p.as_dict() for p in sample_permissions]), + ) + with pytest.raises(ValueError): + sup._get_apply_task(item, migration_state, "backup") + + +def test_non_relevant(migration_state): + ws = MagicMock() + sup = ScimSupport(ws=ws) + sample_permissions = [iam.ComplexValue(value="role1")] + relevant_item = Permissions( + object_id="test-ws", + object_type="roles", + raw=json.dumps([p.as_dict() for p in sample_permissions]), + ) + irrelevant_item = Permissions( + object_id="something-non-relevant", + object_type="roles", + raw=json.dumps([p.as_dict() for p in sample_permissions]), + ) + assert sup.is_item_relevant(relevant_item, migration_state) + assert not sup.is_item_relevant(irrelevant_item, migration_state) def compare(s, t): diff --git a/tests/unit/test_listing.py b/tests/unit/workspace_access/test_listing.py similarity index 71% rename from tests/unit/test_listing.py rename to tests/unit/workspace_access/test_listing.py index cade7d979d..21d1d4e4cc 100644 --- a/tests/unit/test_listing.py +++ b/tests/unit/workspace_access/test_listing.py @@ -1,8 +1,46 @@ -from unittest.mock import Mock +import datetime as dt +from unittest.mock import MagicMock, Mock, patch +from databricks.sdk.service import workspace from databricks.sdk.service.workspace import ObjectInfo, ObjectType -from databricks.labs.ucx.support.listing import WorkspaceListing +from databricks.labs.ucx.workspace_access.generic import workspace_listing +from databricks.labs.ucx.workspace_access.listing import WorkspaceListing + + +def test_logging_calls(): + ws = MagicMock() + workspace_listing = WorkspaceListing(ws=ws, num_threads=1) + workspace_listing.start_time = dt.datetime.now() + workspace_listing._counter = 9 + # with patch.object(logger, "info") as mock_info: + # workspace_listing._progress_report(None) + # mock_info.assert_called_once() + + +def test_workspace_listing(): + listing = MagicMock(spec=WorkspaceListing) + listing.walk.return_value = [ + workspace.ObjectInfo(object_id=1, object_type=workspace.ObjectType.NOTEBOOK), + workspace.ObjectInfo(object_id=2, object_type=workspace.ObjectType.DIRECTORY), + workspace.ObjectInfo(object_id=3, object_type=workspace.ObjectType.LIBRARY), + workspace.ObjectInfo(object_id=4, object_type=workspace.ObjectType.REPO), + workspace.ObjectInfo(object_id=5, object_type=workspace.ObjectType.FILE), + workspace.ObjectInfo(object_id=6, object_type=None), # MLflow Experiment + ] + + with patch("databricks.labs.ucx.workspace_access.listing.WorkspaceListing", return_value=listing): + results = workspace_listing(ws=MagicMock())() + assert len(list(results)) == 4 + listing.walk.assert_called_once() + for res in results: + assert res.request_type in [ + "notebooks", + "directories", + "repos", + "files", + ] + assert res.object_id in [1, 2, 4, 5] # Helper to compare an unordered list of objects diff --git a/tests/unit/workspace_access/test_manager.py b/tests/unit/workspace_access/test_manager.py new file mode 100644 index 0000000000..db3e03408c --- /dev/null +++ b/tests/unit/workspace_access/test_manager.py @@ -0,0 +1,161 @@ +import json +from unittest.mock import MagicMock + +import pytest +from databricks.sdk.service import iam + +from databricks.labs.ucx.mixins.sql import Row +from databricks.labs.ucx.workspace_access.manager import PermissionManager, Permissions + +from ..framework.mocks import MockBackend + + +@pytest.fixture +def b(): + return MockBackend() + + +def test_inventory_table_manager_init(b): + pi = PermissionManager(b, "test_database", [], {}) + + assert pi._full_name == "hive_metastore.test_database.permissions" + + +def test_cleanup(b): + pi = PermissionManager(b, "test_database", [], {}) + + pi.cleanup() + + assert "DROP TABLE IF EXISTS hive_metastore.test_database.permissions" == b.queries[0] + + +def test_save(b): + pi = PermissionManager(b, "test_database", [], {}) + + pi._save([Permissions("object1", "clusters", "test acl")]) + + assert ( + "INSERT INTO hive_metastore.test_database.permissions (object_id, object_type, " + "raw) VALUES ('object1', 'clusters', 'test acl')" + ) == b.queries[0] + + +def permissions_row(*data): + row = Row(data) + row.__columns__ = ["object_id", "object_type", "raw"] + return row + + +def test_load_all(): + b = MockBackend( + rows={ + "SELECT": [ + permissions_row("object1", "clusters", "test acl"), + ] + } + ) + pi = PermissionManager(b, "test_database", [], {}) + + output = pi._load_all() + assert output[0] == Permissions("object1", "clusters", "test acl") + + +def test_manager_inventorize(b, mocker): + some_crawler = mocker.Mock() + some_crawler.get_crawler_tasks = lambda: [lambda: None, lambda: Permissions("a", "b", "c"), lambda: None] + pm = PermissionManager(b, "test_database", [some_crawler], {"b": mocker.Mock()}) + + pm.inventorize_permissions() + + assert ( + "INSERT INTO hive_metastore.test_database.permissions " + "(object_id, object_type, raw) VALUES ('a', 'b', 'c')" == b.queries[0] + ) + + +def test_manager_inventorize_unknown_object_type_raises_error(b, mocker): + some_crawler = mocker.Mock() + some_crawler.get_crawler_tasks = lambda: [lambda: None, lambda: Permissions("a", "b", "c"), lambda: None] + pm = PermissionManager(b, "test_database", [some_crawler], {}) + + with pytest.raises(KeyError): + pm.inventorize_permissions() + + +def test_manager_apply(mocker): + b = MockBackend( + rows={ + "SELECT": [ + permissions_row( + "test", + "clusters", + json.dumps( + iam.ObjectPermissions( + object_id="test", + object_type="clusters", + access_control_list=[ + iam.AccessControlResponse( + group_name="test", + all_permissions=[ + iam.Permission(inherited=False, permission_level=iam.PermissionLevel.CAN_USE) + ], + ) + ], + ).as_dict() + ), + ), + permissions_row( + "test2", + "cluster-policies", + json.dumps( + iam.ObjectPermissions( + object_id="test", + object_type="cluster-policies", + access_control_list=[ + iam.AccessControlResponse( + group_name="test", + all_permissions=[ + iam.Permission(inherited=False, permission_level=iam.PermissionLevel.CAN_USE) + ], + ) + ], + ).as_dict() + ), + ), + ] + } + ) + + # has to be set, as it's going to be appended through multiple threads + applied_items = set() + mock_applier = mocker.Mock() + # this emulates a real applier and call to an API + mock_applier.get_apply_task = lambda item, _, dst: lambda: applied_items.add( + f"{item.object_id} {item.object_id} {dst}" + ) + + pm = PermissionManager( + b, + "test_database", + [], + { + "clusters": mock_applier, + "cluster-policies": mock_applier, + }, + ) + pm.apply_group_permissions(MagicMock(), "backup") + + assert {"test2 test2 backup", "test test backup"} == applied_items + + +def test_unregistered_support(): + b = MockBackend( + rows={ + "SELECT": [ + permissions_row("test", "__unknown__", "{}"), + ] + } + ) + pm = PermissionManager(b, "test", [], {}) + with pytest.raises(ValueError): + pm.apply_group_permissions(migration_state=MagicMock(), destination="backup") diff --git a/tests/unit/support/test_sql.py b/tests/unit/workspace_access/test_redash.py similarity index 84% rename from tests/unit/support/test_sql.py rename to tests/unit/workspace_access/test_redash.py index 1e3e1c411b..84587db7d8 100644 --- a/tests/unit/support/test_sql.py +++ b/tests/unit/workspace_access/test_redash.py @@ -5,8 +5,11 @@ from databricks.sdk.core import DatabricksError from databricks.sdk.service import sql -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.support.sql import SqlPermissionsSupport, listing_wrapper +from databricks.labs.ucx.workspace_access.redash import ( + Permissions, + SqlPermissionsSupport, + redash_listing_wrapper, +) def test_crawlers(): @@ -39,9 +42,9 @@ def test_crawlers(): sup = SqlPermissionsSupport( ws=ws, listings=[ - listing_wrapper(ws.alerts.list, sql.ObjectTypePlural.ALERTS), - listing_wrapper(ws.dashboards.list, sql.ObjectTypePlural.DASHBOARDS), - listing_wrapper(ws.queries.list, sql.ObjectTypePlural.QUERIES), + redash_listing_wrapper(ws.alerts.list, sql.ObjectTypePlural.ALERTS), + redash_listing_wrapper(ws.dashboards.list, sql.ObjectTypePlural.DASHBOARDS), + redash_listing_wrapper(ws.queries.list, sql.ObjectTypePlural.QUERIES), ], ) @@ -53,17 +56,17 @@ def test_crawlers(): for task in tasks: item = task() assert item.object_id == "test" - assert item.support in ["alerts", "dashboards", "queries"] - assert item.raw_object_permissions is not None + assert item.object_type in ["alerts", "dashboards", "queries"] + assert item.raw is not None def test_apply(migration_state): ws = MagicMock() sup = SqlPermissionsSupport(ws=ws, listings=[]) - item = PermissionsInventoryItem( + item = Permissions( object_id="test", - support="alerts", - raw_object_permissions=json.dumps( + object_type="alerts", + raw=json.dumps( sql.GetResponse( object_type=sql.ObjectType.ALERT, object_id="test", diff --git a/tests/unit/support/test_secrets.py b/tests/unit/workspace_access/test_secrets.py similarity index 86% rename from tests/unit/support/test_secrets.py rename to tests/unit/workspace_access/test_secrets.py index 594919f96a..dbe1699530 100644 --- a/tests/unit/support/test_secrets.py +++ b/tests/unit/workspace_access/test_secrets.py @@ -4,9 +4,11 @@ import pytest from databricks.sdk.service import workspace -from databricks.labs.ucx.inventory.types import PermissionsInventoryItem -from databricks.labs.ucx.providers.groups_info import GroupMigrationState -from databricks.labs.ucx.support.secrets import SecretScopesSupport +from databricks.labs.ucx.workspace_access.groups import GroupMigrationState +from databricks.labs.ucx.workspace_access.secrets import ( + Permissions, + SecretScopesSupport, +) def test_secret_scopes_crawler(): @@ -32,17 +34,17 @@ def test_secret_scopes_crawler(): item = _task() assert item.object_id == "test" - assert item.support == "secrets" - assert item.raw_object_permissions == '[{"permission": "MANAGE", "principal": "test"}]' + assert item.object_type == "secrets" + assert item.raw == '[{"permission": "MANAGE", "principal": "test"}]' def test_secret_scopes_apply(migration_state: GroupMigrationState): ws = MagicMock() sup = SecretScopesSupport(ws=ws) - item = PermissionsInventoryItem( + item = Permissions( object_id="test", - support="secrets", - raw_object_permissions=json.dumps( + object_type="secrets", + raw=json.dumps( [ workspace.AclItem( principal="test",