Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor configuration and remove global provider state #71

Merged
merged 2 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions src/uc_migration_toolkit/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from databricks.sdk.core import Config
from pydantic import RootModel
from pydantic.dataclasses import dataclass

Expand Down Expand Up @@ -31,33 +32,34 @@ def __post_init__(self):


@dataclass
class WorkspaceAuthConfig:
token: str | None = None
class InventoryConfig:
table: InventoryTable


@dataclass
class ConnectConfig:
# Keep all the fields in sync with databricks.sdk.core.Config
host: str | None = None
account_id: str | None = None
token: str | None = None
client_id: str | None = None
client_secret: str | None = None
azure_client_id: str | None = None
azure_tenant_id: str | None = None
azure_client_secret: str | None = None
azure_environment: str | None = None
cluster_id: str | None = None
profile: str | None = None


@dataclass
class AuthConfig:
workspace: WorkspaceAuthConfig | None = None

class Config:
frozen = True


@dataclass
class InventoryConfig:
table: InventoryTable
debug_headers: bool = False
rate_limit: int | None = None
nfx marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class MigrationConfig:
inventory: InventoryConfig
with_table_acls: bool
groups: GroupsConfig
auth: AuthConfig | None = None
connect: ConnectConfig | None = None
num_threads: int | None = 4
log_level: str | None = "INFO"

Expand All @@ -66,5 +68,26 @@ def __post_init__(self):
msg = "Table ACLS are not yet implemented"
raise NotImplementedError(msg)

def to_databricks_config(self) -> Config:
    """Build a databricks SDK ``Config`` from the ``connect`` section.

    When ``self.connect`` is ``None``, a default-constructed
    ``ConnectConfig()`` supplies the values instead.
    """
    settings = ConnectConfig() if self.connect is None else self.connect
    # Forward every connection setting under the same keyword name
    # that the SDK ``Config`` constructor is called with.
    sdk_kwargs = {
        "host": settings.host,
        "account_id": settings.account_id,
        "token": settings.token,
        "client_id": settings.client_id,
        "client_secret": settings.client_secret,
        "azure_client_id": settings.azure_client_id,
        "azure_tenant_id": settings.azure_tenant_id,
        "azure_client_secret": settings.azure_client_secret,
        "azure_environment": settings.azure_environment,
        "cluster_id": settings.cluster_id,
        "profile": settings.profile,
        "debug_headers": settings.debug_headers,
        "rate_limit": settings.rate_limit,
    }
    return Config(**sdk_kwargs)

def to_json(self) -> str:
    """Serialize this configuration to a JSON string indented by 4 spaces."""
    wrapped = RootModel[MigrationConfig](self)
    return wrapped.model_dump_json(indent=4)
27 changes: 13 additions & 14 deletions src/uc_migration_toolkit/managers/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from databricks.sdk.service.iam import Group

from uc_migration_toolkit.config import GroupsConfig
from uc_migration_toolkit.generic import StrEnum
from uc_migration_toolkit.providers.client import provider
from uc_migration_toolkit.providers.config import provider as config_provider
from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient
from uc_migration_toolkit.providers.groups_info import (
MigrationGroupInfo,
MigrationGroupsProvider,
Expand All @@ -22,17 +22,17 @@ class GroupLevel(StrEnum):
class GroupManager:
SYSTEM_GROUPS: typing.ClassVar[list[str]] = ["users", "admins", "account users"]

def __init__(self):
self.config = config_provider.config.groups
def __init__(self, ws: ImprovedWorkspaceClient, groups: GroupsConfig):
self._ws = ws
self.config = groups
self._migration_groups_provider: MigrationGroupsProvider = MigrationGroupsProvider()

# please keep the internal methods below this line

@staticmethod
def _find_eligible_groups() -> list[str]:
def _find_eligible_groups(self) -> list[str]:
logger.info("Finding eligible groups automatically")
_display_name_filter = " and ".join([f'displayName ne "{group}"' for group in GroupManager.SYSTEM_GROUPS])
ws_groups = list(provider.ws.groups.list(attributes="displayName,meta", filter=_display_name_filter))
ws_groups = list(self._ws.groups.list(attributes="displayName,meta", filter=_display_name_filter))
eligible_groups = [g for g in ws_groups if g.meta.resource_type == "WorkspaceGroup"]
logger.info(f"Found {len(eligible_groups)} eligible groups")
return [g.display_name for g in eligible_groups]
Expand All @@ -55,9 +55,8 @@ def _get_clean_group_info(group: Group, cleanup_keys: list[str] | None = None) -

return group_info

@staticmethod
def _get_group(group_name, level: GroupLevel) -> Group | None:
method = provider.ws.groups.list if level == GroupLevel.WORKSPACE else provider.ws.list_account_level_groups
def _get_group(self, group_name, level: GroupLevel) -> Group | None:
method = self._ws.groups.list if level == GroupLevel.WORKSPACE else self._ws.list_account_level_groups
query_filter = f"displayName eq '{group_name}'"
attributes = ",".join(["id", "displayName", "meta", "entitlements", "roles", "members"])

Expand All @@ -78,7 +77,7 @@ def _get_or_create_backup_group(self, source_group_name: str, source_group: Grou
logger.info(f"Creating backup group {backup_group_name}")
new_group_payload = self._get_clean_group_info(source_group)
new_group_payload["displayName"] = backup_group_name
backup_group = provider.ws.groups.create(request=Group.from_dict(new_group_payload))
backup_group = self._ws.groups.create(request=Group.from_dict(new_group_payload))
logger.info(f"Backup group {backup_group_name} successfully created")

return backup_group
Expand Down Expand Up @@ -106,12 +105,12 @@ def _replace_group(self, migration_info: MigrationGroupInfo):

if self._get_group(ws_group.display_name, GroupLevel.WORKSPACE):
logger.info(f"Deleting the workspace-level group {ws_group.display_name} with id {ws_group.id}")
provider.ws.groups.delete(ws_group.id)
self._ws.groups.delete(ws_group.id)
logger.info(f"Workspace-level group {ws_group.display_name} with id {ws_group.id} was deleted")
else:
logger.warning(f"Workspace-level group {ws_group.display_name} does not exist, skipping")

provider.ws.reflect_account_group_to_workspace(acc_group)
self._ws.reflect_account_group_to_workspace(acc_group)

# please keep the public methods below this line

Expand Down Expand Up @@ -152,7 +151,7 @@ def delete_backup_groups(self):

for migration_info in self.migration_groups_provider.groups:
try:
provider.ws.groups.delete(id=migration_info.backup.id)
self._ws.groups.delete(id=migration_info.backup.id)
except Exception as e:
logger.warning(
f"Failed to delete backup group {migration_info.backup.display_name} "
Expand Down
Loading