diff --git a/python/az/aro/azext_aro/_params.py b/python/az/aro/azext_aro/_params.py index a3e6a95a8c6..fcb56e536a0 100644 --- a/python/az/aro/azext_aro/_params.py +++ b/python/az/aro/azext_aro/_params.py @@ -16,6 +16,7 @@ from azext_aro._validators import validate_worker_vm_disk_size_gb from azext_aro._validators import validate_refresh_cluster_credentials from azext_aro._validators import validate_version_format +from azext_aro._validators import validate_outbound_type from azure.cli.core.commands.parameters import name_type from azure.cli.core.commands.parameters import get_enum_type, get_three_state_flag from azure.cli.core.commands.parameters import resource_group_name_type @@ -64,7 +65,9 @@ def load_arguments(self, _): c.argument('service_cidr', help='CIDR of service network. Must be a minimum of /18 or larger.', validator=validate_cidr('service_cidr')) - + c.argument('outbound_type', + help='Outbound type of cluster. Must be "Loadbalancer" (default) or "UserDefinedRouting".', + validator=validate_outbound_type) c.argument('disk_encryption_set', help='ResourceID of the DiskEncryptionSet to be used for master and worker VMs.', validator=validate_disk_encryption_set) diff --git a/python/az/aro/azext_aro/_validators.py b/python/az/aro/azext_aro/_validators.py index 5f18ed30e0c..b3a9dff6613 100644 --- a/python/az/aro/azext_aro/_validators.py +++ b/python/az/aro/azext_aro/_validators.py @@ -118,6 +118,13 @@ def validate_pull_secret(namespace): raise InvalidArgumentValueError("Invalid --pull-secret.") from e +def validate_outbound_type(namespace): + outbound_type = getattr(namespace, 'outbound_type') + if outbound_type in {'UserDefinedRouting', 'Loadbalancer', None}: + return + raise InvalidArgumentValueError('Outbound type must be "UserDefinedRouting" or "Loadbalancer"') + + def validate_subnet(key): def _validate_subnet(cmd, namespace): subnet = getattr(namespace, key) diff --git a/python/az/aro/azext_aro/custom.py b/python/az/aro/azext_aro/custom.py index 6c2ec5520cc..6d4291d68c9 100644 --- a/python/az/aro/azext_aro/custom.py +++ b/python/az/aro/azext_aro/custom.py @@ -53,6 +53,7 @@ def aro_create(cmd, # pylint: disable=too-many-locals client_secret=None, pod_cidr=None, service_cidr=None, + outbound_type=None, disk_encryption_set=None, master_encryption_at_host=False, master_vm_size=None, @@ -139,6 +140,11 @@ def aro_create(cmd, # pylint: disable=too-many-locals network_profile=openshiftcluster.NetworkProfile( pod_cidr=pod_cidr or '10.128.0.0/14', service_cidr=service_cidr or '172.30.0.0/16', +<<<<<<< HEAD +======= + outbound_type=outbound_type or 'Loadbalancer', + software_defined_network=software_defined_network or 'OpenShiftSDN' +>>>>>>> 381c20683 (add outbound-type param to az cli extension) ), master_profile=openshiftcluster.MasterProfile( vm_size=master_vm_size or 'Standard_D8s_v3', diff --git a/python/az/aro/azext_aro/tests/latest/unit/test_validators.py b/python/az/aro/azext_aro/tests/latest/unit/test_validators.py index 4e3ce912416..3bcf4991f89 100644 --- a/python/az/aro/azext_aro/tests/latest/unit/test_validators.py +++ b/python/az/aro/azext_aro/tests/latest/unit/test_validators.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch from azext_aro._validators import ( - validate_cidr, validate_client_id, validate_client_secret, validate_cluster_resource_group, + validate_cidr, validate_client_id, validate_client_secret, validate_cluster_resource_group, validate_outbound_type, validate_disk_encryption_set, validate_domain, validate_pull_secret, validate_subnet, validate_subnets, validate_visibility, validate_vnet_resource_group_name, validate_worker_count, validate_worker_vm_disk_size_gb, validate_refresh_cluster_credentials ) @@ -775,3 +775,40 @@ def test_validate_refresh_cluster_credentials(test_description, namespace, expec else: with pytest.raises(expected_exception): validate_refresh_cluster_credentials(namespace) + + +test_validate_outbound_type_data = [ + ( + "Should not raise exception when key is Loadbalancer.", + Mock(outbound_type='Loadbalancer'), + None + ), + ( + "Should not raise exception when key is UserDefinedRouting.", + Mock(outbound_type='UserDefinedRouting'), + None + ), + ( + "Should not raise exception when key is empty.", + Mock(outbound_type=None), + None + ), + ( + "Should raise exception when key is a different value.", + Mock(outbound_type='testFail'), + InvalidArgumentValueError + ), +] + + +@pytest.mark.parametrize( + "test_description, namespace, expected_exception", + test_validate_outbound_type_data, + ids=[i[0] for i in test_validate_outbound_type_data] +) +def test_validate_outbound_type(test_description, namespace, expected_exception): + if expected_exception is None: + validate_outbound_type(namespace) + else: + with pytest.raises(expected_exception): + validate_outbound_type(namespace)