Skip to content

Commit

Permalink
CLI Phase 1 - add UpgradeableTo field to update functionality (#3844)
Browse files Browse the repository at this point in the history
* CLI Phase 1 - add UpgradeableTo field to update functionality

* fix upgradeableTo parameter

* apply suggestions from code review

* add unit tests for upgradeableTo

* fix unit tests

* Disallow refresh_cluster_credentials together with upgradeable_to

* apply code review suggestions

* fix python lint issue
  • Loading branch information
slawande2 authored Oct 2, 2024
1 parent 53565ba commit 9375d82
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 19 deletions.
4 changes: 4 additions & 0 deletions python/az/aro/azext_aro/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
validate_enable_managed_identity,
validate_platform_workload_identities,
validate_cluster_identity,
validate_upgradeable_to_format,
)
from azure.cli.core.commands.parameters import (
name_type,
Expand Down Expand Up @@ -160,6 +161,9 @@ def load_arguments(self, _):
options_list=['--assign-platform-workload-identity', '--assign-platform-wi'],
validator=validate_platform_workload_identities(isCreate=False),
action=AROPlatformWorkloadIdentityAddAction, nargs='+')
c.argument('upgradeable_to', arg_group='Identity', options_list=['--upgradeable-to'],
help='OpenShift version to upgrade to.', is_preview=True,
validator=validate_upgradeable_to_format)

with self.argument_context('aro get-admin-kubeconfig') as c:
c.argument('file',
Expand Down
14 changes: 13 additions & 1 deletion python/az/aro/azext_aro/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ def validate_client_id(namespace):
raise MutuallyExclusiveArgumentError('Must not specify --client-id when --enable-managed-identity is True') # pylint: disable=line-too-long
if namespace.platform_workload_identities is not None:
raise MutuallyExclusiveArgumentError('Must not specify --client-id when --assign-platform-workload-identity is used') # pylint: disable=line-too-long

try:
uuid.UUID(namespace.client_id)
except ValueError as e:
raise InvalidArgumentValueError(f"Invalid --client-id '{namespace.client_id}'.") from e # pylint: disable=line-too-long

if namespace.client_secret is None or not str(namespace.client_secret):
raise RequiredArgumentMissingError('Must specify --client-secret with --client-id.') # pylint: disable=line-too-long
if namespace.upgradeable_to is not None:
raise MutuallyExclusiveArgumentError('Must not specify --client-id when --upgradeable-to is used.') # pylint: disable=line-too-long


def validate_client_secret(isCreate):
Expand All @@ -65,6 +66,8 @@ def _validate_client_secret(namespace):
raise MutuallyExclusiveArgumentError('Must not specify --client-secret when --assign-platform-workload-identity is used') # pylint: disable=line-too-long
if isCreate and (namespace.client_id is None or not str(namespace.client_id)):
raise RequiredArgumentMissingError('Must specify --client-id with --client-secret.')
if namespace.upgradeable_to is not None:
raise MutuallyExclusiveArgumentError('Must not specify --client-secret when --upgradeable-to is used.') # pylint: disable=line-too-long

return _validate_client_secret

Expand Down Expand Up @@ -281,13 +284,22 @@ def validate_refresh_cluster_credentials(namespace):
return
if namespace.client_secret is not None or namespace.client_id is not None:
raise RequiredArgumentMissingError('--client-id and --client-secret must be not set with --refresh-credentials.') # pylint: disable=line-too-long
if namespace.upgradeable_to is not None:
raise MutuallyExclusiveArgumentError('Must not specify --refresh-credentials when --upgradeable-to is used.') # pylint: disable=line-too-long


def validate_version_format(namespace):
if namespace.version is not None and not re.match(r'^[4-9]{1}\.[0-9]{1,2}\.[0-9]{1,2}$', namespace.version):
raise InvalidArgumentValueError('--version is invalid')


def validate_upgradeable_to_format(namespace):
if not namespace.upgradeable_to:
return
if not re.match(r'^[4-9]{1}\.(1[4-9]|[1-9][0-9])\.[0-9]{1,2}$', namespace.upgradeable_to):
raise InvalidArgumentValueError('--upgradeable-to format is invalid')


def validate_load_balancer_managed_outbound_ip_count(namespace):
if namespace.load_balancer_managed_outbound_ip_count is None:
return
Expand Down
29 changes: 17 additions & 12 deletions python/az/aro/azext_aro/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def aro_update(cmd,
client_secret=None,
platform_workload_identities=None,
load_balancer_managed_outbound_ip_count=None,
upgradeable_to=None,
no_wait=False):
# if we can't read cluster spec, we will not be able to do much. Fail.
oc = client.open_shift_clusters.get(resource_group_name, resource_name)
Expand All @@ -479,20 +480,24 @@ def aro_update(cmd,
if client_id is not None:
oc_update.service_principal_profile.client_id = client_id

if platform_workload_identities is not None:
pwis = {}
for i in oc.platform_workload_identity_profile.platform_workload_identities:
pwis[i.operator_name] = openshiftcluster.PlatformWorkloadIdentity(
operator_name=i.operator_name,
resource_id=i.resource_id
)
if oc.platform_workload_identity_profile is not None:
if platform_workload_identities is not None or upgradeable_to is not None:
oc_update.platform_workload_identity_profile = openshiftcluster.PlatformWorkloadIdentityProfile()

for i in platform_workload_identities:
pwis[i.operator_name] = i
if platform_workload_identities is not None:
pwis = {}
for i in oc.platform_workload_identity_profile.platform_workload_identities:
pwis[i.operator_name] = openshiftcluster.PlatformWorkloadIdentity(
operator_name=i.operator_name,
resource_id=i.resource_id
)

oc_update.platform_workload_identity_profile = openshiftcluster.PlatformWorkloadIdentityProfile(
platform_workload_identities=list(pwis.values())
)
for i in platform_workload_identities:
pwis[i.operator_name] = i

oc_update.platform_workload_identity_profile.platform_workload_identities = list(pwis.values())

oc_update.platform_workload_identity_profile.upgradeable_to = upgradeable_to

if load_balancer_managed_outbound_ip_count is not None:
oc_update.network_profile = openshiftcluster.NetworkProfile()
Expand Down
67 changes: 61 additions & 6 deletions python/az/aro/azext_aro/tests/latest/unit/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
validate_load_balancer_managed_outbound_ip_count,
validate_enable_managed_identity,
validate_platform_workload_identities,
validate_cluster_identity
validate_cluster_identity,
validate_upgradeable_to_format
)
from azure.cli.core.azclierror import (
InvalidArgumentValueError, RequiredArgumentMissingError,
Expand Down Expand Up @@ -112,7 +113,7 @@ def test_validate_cidr(test_description, dummyclass, attribute_to_get_from_objec
),
(
"should not raise any exception when namespace.client_id is a valid input for creating a UUID and namespace.client_secret has a valid str representation",
Mock(client_id="12345678123456781234567812345678", platform_workload_identities=None, client_secret="12345"),
Mock(upgradeable_to=None, client_id="12345678123456781234567812345678", platform_workload_identities=None, client_secret="12345"),
None
)
]
Expand Down Expand Up @@ -171,15 +172,27 @@ def test_validate_client_id(test_description, namespace, expected_exception):
(
"should not raise any exception when isCreate is true and all arguments valid",
True,
Mock(client_id="12345678123456781234567812345678", client_secret="123", platform_workload_identities=None),
Mock(upgradeable_to=None, client_id="12345678123456781234567812345678", client_secret="123", platform_workload_identities=None),
None
),
(
"should not raise any exception when isCreate is false and all arguments valid",
False,
Mock(client_secret="123", platform_workload_identities=None),
Mock(upgradeable_to=None, client_secret="123", platform_workload_identities=None),
None
),
(
"should raise MutuallyExclusiveArgumentError exception when isCreate is true and upgradeable_to, client_id and client_secret are present",
True,
Mock(upgradeable_to="4.14.2", client_id="12345678123456781234567812345678", client_secret="123", platform_workload_identities=None),
MutuallyExclusiveArgumentError
),
(
"should raise MutuallyExclusiveArgumentError exception when isCreate is false and upgradeable_to, client_id and client_secret are present",
False,
Mock(upgradeable_to="4.14.2", client_id="12345678123456781234567812345678", client_secret="123", platform_workload_identities=None),
MutuallyExclusiveArgumentError
),
]


Expand Down Expand Up @@ -804,9 +817,15 @@ def test_validate_worker_vm_disk_size_gb(test_description, namespace, expected_e
),
(
"should not raise any Exception because namespace.client_secret is None and namespace.client_id is None",
Mock(client_secret=None, client_id=None),
Mock(upgradeable_to=None, client_secret=None, client_id=None),
None
)
),
(
"should raise MutuallyExclusiveArgumentError exception because namespace.upgradeable_to is not None",
Mock(upgradeable_to="4.14.2", client_id=None, client_secret=None),
MutuallyExclusiveArgumentError
),

]


Expand Down Expand Up @@ -1276,3 +1295,39 @@ def test_validate_cluster_identity(test_description, namespace, expected_excepti

if expected_identity is not None:
assert (expected_identity == namespace.mi_user_assigned)


test_validate_upgradeable_to_data = [
(
"should not raise any Exception because namespace.upgradeable_to is empty",
Mock(upgradeable_to="", client_id=None, client_secret=None),
None, None
),

(
"should raise InvalidArgumentValueError Exception because upgradeable_to format is invalid",
Mock(upgradeable_to="a", client_id=None, client_secret=None),
InvalidArgumentValueError, "--upgradeable-to is invalid"
),

(
"Should raise InvalidArgumentValueError when --upgradeable-to < 4.14.0",
Mock(upgradeable_to="4.0.4",
client_id=None, client_secret=None),
InvalidArgumentValueError, 'Enabling managed identity requires --upgradeable-to >= 4.14.0'
),

]


@pytest.mark.parametrize(
"test_description, namespace, expected_exception, expected_exception_message",
test_validate_upgradeable_to_data,
ids=[i[0] for i in test_validate_upgradeable_to_data]
)
def test_validate_upgradeable_to_data(test_description, namespace, expected_exception, expected_exception_message):
if expected_exception is None:
validate_upgradeable_to_format(namespace)
else:
with pytest.raises(expected_exception):
validate_upgradeable_to_format(namespace)

0 comments on commit 9375d82

Please sign in to comment.