Skip to content

Commit

Permalink
Prefetch all account-level and workspace-level groups (#192)
Browse files Browse the repository at this point in the history
Addresses #190 and other SCIM-related issues. 

Conceptually, the following happens:

1. Existing method to list the workspace groups is inconsistent, and the
root cause of these issues is somewhere on SCIM API side. We've raised a
ticket, but the resolution timing is unclear.
2. To avoid being time-dependent on SCIM API, we re-introduce the
listing methods, but this time we intentionally don't use `filter`-based
conditions and explicitly split the API methods into two.
3. To avoid reaching out to the SCIM API too frequently, we list the
whole API only once and then work with an in-memory data.

This PR also adds 100% coverage for `GroupManager` and adds integration
tests for groups to verify it works as expected.
  • Loading branch information
renardeinside authored Sep 13, 2023
1 parent f4e5989 commit b654a3a
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 228 deletions.
96 changes: 52 additions & 44 deletions src/databricks/labs/ucx/managers/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.iam import Group
from databricks.sdk.service import iam
from ratelimit import limits, sleep_and_retry

from databricks.labs.ucx.config import GroupsConfig
Expand All @@ -30,40 +30,44 @@ def __init__(self, ws: WorkspaceClient, groups: GroupsConfig):
self._ws = ws
self.config = groups
self._migration_state: GroupMigrationState = GroupMigrationState()
self._account_groups = self._list_account_groups()
self._workspace_groups = self._list_workspace_groups()

def _list_workspace_groups(self) -> list[iam.Group]:
logger.debug("Listing workspace groups...")
workspace_groups = [
g
for g in self._ws.groups.list(attributes="id,displayName,meta")
if g.meta.resource_type == "WorkspaceGroup" and g.display_name not in self.SYSTEM_GROUPS
]
logger.debug(f"Found {len(workspace_groups)} workspace groups")
return workspace_groups

def _list_account_groups(self) -> list[iam.Group]:
# TODO: we should avoid using this method, as it's not documented
# unfortunately, there's no other way to consistently get the list of account groups
logger.debug("Listing account groups...")
account_groups = [
iam.Group.from_dict(r)
for r in self._ws.api_client.do(
"get",
"/api/2.0/account/scim/v2/Groups",
query={
"attributes": "id,displayName,meta",
},
).get("Resources", [])
]
account_groups = [g for g in account_groups if g.display_name not in self.SYSTEM_GROUPS]
logger.debug(f"Found {len(account_groups)} account groups")
return account_groups

# please keep the internal methods below this line

def _find_eligible_groups(self) -> list[str]:
logger.info("Finding eligible groups automatically")
_display_name_filter = " and ".join([f'displayName ne "{group}"' for group in GroupManager.SYSTEM_GROUPS])
ws_groups = list(self._ws.groups.list(attributes="displayName,meta", filter=_display_name_filter))
eligible_groups = [g for g in ws_groups if g.meta.resource_type == "WorkspaceGroup"]
logger.info(f"Found {len(eligible_groups)} eligible groups")
return [g.display_name for g in eligible_groups]
def _get_group(self, group_name, level: GroupLevel) -> iam.Group | None:
relevant_level_groups = self._workspace_groups if level == GroupLevel.WORKSPACE else self._account_groups
for group in relevant_level_groups:
if group.display_name == group_name:
return group

@sleep_and_retry
@limits(calls=100, period=1) # assumption
def _list_account_level_groups(
self, filter: str, attributes: str | None = None, excluded_attributes: str | None = None # noqa: A002
) -> list[Group]:
query = {"filter": filter, "attributes": attributes, "excludedAttributes": excluded_attributes}
response = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query=query)
return [Group.from_dict(v) for v in response.get("Resources", [])]

def _get_group(self, group_name, level: GroupLevel) -> Group | None:
# TODO: calling this can cause issues for SCIM backend, cache groups instead
method = self._ws.groups.list if level == GroupLevel.WORKSPACE else self._list_account_level_groups
query_filter = f"displayName eq '{group_name}'"
attributes = ",".join(["id", "displayName", "meta", "entitlements", "roles", "members"])

group = next(
iter(method(filter=query_filter, attributes=attributes)),
None,
)

return group

def _get_or_create_backup_group(self, source_group_name: str, source_group: Group) -> Group:
def _get_or_create_backup_group(self, source_group_name: str, source_group: iam.Group) -> iam.Group:
backup_group_name = f"{self.config.backup_group_prefix}{source_group_name}"
backup_group = self._get_group(backup_group_name, GroupLevel.WORKSPACE)

Expand All @@ -78,6 +82,7 @@ def _get_or_create_backup_group(self, source_group_name: str, source_group: Grou
roles=source_group.roles,
members=source_group.members,
)
self._workspace_groups.append(backup_group)
logger.info(f"Backup group {backup_group_name} successfully created")

return backup_group
Expand All @@ -101,20 +106,20 @@ def get_group_info(name: str):

def _replace_group(self, migration_info: MigrationGroupInfo):
ws_group = migration_info.workspace
acc_group = migration_info.account

if self._get_group(ws_group.display_name, GroupLevel.WORKSPACE):
logger.info(f"Deleting the workspace-level group {ws_group.display_name} with id {ws_group.id}")
self._ws.groups.delete(ws_group.id)
logger.info(f"Workspace-level group {ws_group.display_name} with id {ws_group.id} was deleted")
else:
logger.warning(f"Workspace-level group {ws_group.display_name} does not exist, skipping")
logger.info(f"Deleting the workspace-level group {ws_group.display_name} with id {ws_group.id}")
self._ws.groups.delete(ws_group.id)

# delete ws_group from the list of workspace groups
self._workspace_groups = [g for g in self._workspace_groups if g.id != ws_group.id]

logger.info(f"Workspace-level group {ws_group.display_name} with id {ws_group.id} was deleted")

self._reflect_account_group_to_workspace(acc_group)
self._reflect_account_group_to_workspace(migration_info.account)

@sleep_and_retry
@limits(calls=5, period=1) # assumption
def _reflect_account_group_to_workspace(self, acc_group: Group) -> None:
def _reflect_account_group_to_workspace(self, acc_group: iam.Group) -> None:
logger.info(f"Reflecting group {acc_group.display_name} to workspace")

# TODO: add OpenAPI spec for it
Expand All @@ -136,11 +141,14 @@ def prepare_groups_in_environment(self):

for g in self.config.selected:
assert g not in self.SYSTEM_GROUPS, f"Cannot migrate system group {g}"
assert self._get_group(g, GroupLevel.WORKSPACE), f"Group {g} not found on the workspace level"
assert self._get_group(g, GroupLevel.ACCOUNT), f"Group {g} not found on the account level"

self._set_migration_groups(self.config.selected)
else:
logger.info("No group listing provided, finding eligible groups automatically")
self._set_migration_groups(groups_names=self._find_eligible_groups())
logger.info("No group listing provided, all available workspace-level groups will be used")
available_group_names = [g.display_name for g in self._workspace_groups]
self._set_migration_groups(groups_names=available_group_names)
logger.info("Environment prepared successfully")

@property
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/test_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.config import GroupsConfig
from databricks.labs.ucx.managers.group import GroupLevel, GroupManager


def test_group_listing(ws: WorkspaceClient, make_ucx_group):
ws_group, acc_group = make_ucx_group()
manager = GroupManager(ws, GroupsConfig(selected=[ws_group.display_name]))
assert ws_group.display_name in [g.display_name for g in manager._workspace_groups]
assert acc_group.display_name in [g.display_name for g in manager._account_groups]


def test_id_validity(ws: WorkspaceClient, make_ucx_group):
ws_group, acc_group = make_ucx_group()
manager = GroupManager(ws, GroupsConfig(selected=[ws_group.display_name]))
assert ws_group.id == manager._get_group(ws_group.display_name, GroupLevel.WORKSPACE).id
assert acc_group.id == manager._get_group(acc_group.display_name, GroupLevel.ACCOUNT).id
Loading

0 comments on commit b654a3a

Please sign in to comment.