diff --git a/pkg/deploy/predeploy.go b/pkg/deploy/predeploy.go index 4a3097a783b..ff223f04fcf 100644 --- a/pkg/deploy/predeploy.go +++ b/pkg/deploy/predeploy.go @@ -10,15 +10,20 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" + "fmt" + "net/http" "path/filepath" "strings" "time" + mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute" azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault" mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/to" + "k8s.io/apimachinery/pkg/util/wait" + "github.com/Azure/ARO-RP/pkg/api" "github.com/Azure/ARO-RP/pkg/deploy/assets" "github.com/Azure/ARO-RP/pkg/deploy/generator" "github.com/Azure/ARO-RP/pkg/env" @@ -26,9 +31,13 @@ import ( "github.com/Azure/ARO-RP/pkg/util/keyvault" ) -// Rotate the secret on every deploy of the RP if the most recent -// secret is greater than 7 days old -const rotateSecretAfter = time.Hour * 168 +const ( + // Rotate the secret on every deploy of the RP if the most recent + // secret is greater than 7 days old + rotateSecretAfter = time.Hour * 24 * 7 + rpRestartScript = "systemctl restart aro-rp" + gatewayRestartScript = "systemctl restart aro-gateway" +) // PreDeploy deploys managed identity, NSGs and keyvaults, needed for main // deployment @@ -352,6 +361,7 @@ func (d *deployer) deployPreDeploy(ctx context.Context, resourceGroupName, deplo } func (d *deployer) configureServiceSecrets(ctx context.Context) error { + isRotated := false for _, s := range []struct { kv keyvault.Manager secretName string @@ -361,7 +371,8 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { {d.serviceKeyvault, env.FrontendEncryptionSecretV2Name, 64}, {d.portalKeyvault, env.PortalServerSessionKeySecretName, 32}, } { - err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len) + isNew, err := d.ensureAndRotateSecret(ctx, s.kv, s.secretName, s.len) + isRotated = isNew || isRotated if err != nil { return err } @@ -376,26 +387,43 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error { {d.serviceKeyvault, env.EncryptionSecretName, 32}, {d.serviceKeyvault, env.FrontendEncryptionSecretName, 32}, } { - err := d.ensureSecret(ctx, s.kv, s.secretName, s.len) + isNew, err := d.ensureSecret(ctx, s.kv, s.secretName, s.len) + isRotated = isNew || isRotated if err != nil { return err } } - return d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName) + isNew, err := d.ensureSecretKey(ctx, d.portalKeyvault, env.PortalServerSSHKeySecretName) + isRotated = isNew || isRotated + if err != nil { + return err + } + + if isRotated { + err = d.restartOldScalesets(ctx, d.config.GatewayResourceGroupName) + if err != nil { + return err + } + err = d.restartOldScalesets(ctx, d.config.RPResourceGroupName) + if err != nil { + return err + } + } + return nil } -func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { +func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { latestVersion, err := kv.GetSecret(ctx, secretName) if err != nil { - return err + return false, err } updatedTime := time.Unix(0, latestVersion.Attributes.Created.Duration().Nanoseconds()).Add(rotateSecretAfter) @@ -403,27 +431,27 @@ func (d *deployer) ensureAndRotateSecret(ctx context.Context, kv keyvault.Manage // do not create a secret if rotateSecretAfter time has // not elapsed since the secret version's creation timestamp if time.Now().Before(updatedTime) { - return nil + return false, nil } } } - return d.createSecret(ctx, kv, secretName, len) + return true, d.createSecret(ctx, kv, secretName, len) } -func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { +func (d *deployer) ensureSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { - return nil + return false, nil } } - return d.createSecret(ctx, kv, secretName, len) + return true, d.createSecret(ctx, kv, secretName, len) } func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secretName string, len int) error { @@ -439,25 +467,105 @@ func (d *deployer) createSecret(ctx context.Context, kv keyvault.Manager, secret }) } -func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) error { +func (d *deployer) ensureSecretKey(ctx context.Context, kv keyvault.Manager, secretName string) (isNew bool, err error) { existingSecrets, err := kv.GetSecrets(ctx) if err != nil { - return err + return false, err } for _, secret := range existingSecrets { if filepath.Base(*secret.ID) == secretName { - return nil + return false, nil } } key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return false, err } d.log.Infof("setting %s", secretName) - return kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{ + return true, kv.SetSecret(ctx, secretName, azkeyvault.SecretSetParameters{ Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))), }) } + +func (d *deployer) restartOldScalesets(ctx context.Context, resourceGroupName string) error { + scalesets, err := d.vmss.List(ctx, resourceGroupName) + if err != nil { + return err + } + + for _, vmss := range scalesets { + err = d.restartOldScaleset(ctx, *vmss.Name, resourceGroupName) + if err != nil { + return err + } + } + + return nil +} + +func (d *deployer) restartOldScaleset(ctx context.Context, vmssName string, resourceGroupName string) error { + var restartScript string + switch { + case strings.HasPrefix(vmssName, gatewayVMSSPrefix): + restartScript = gatewayRestartScript + case strings.HasPrefix(vmssName, rpVMSSPrefix): + restartScript = rpRestartScript + default: + return &api.CloudError{ + StatusCode: http.StatusBadRequest, + CloudErrorBody: &api.CloudErrorBody{ + Code: api.CloudErrorCodeInvalidResource, + Message: fmt.Sprintf("provided vmss %s does not match RP or gateway prefix", + vmssName, + ), + }, + } + } + + scalesetVMs, err := d.vmssvms.List(ctx, resourceGroupName, vmssName, "", "", "") + if err != nil { + return err + } + + for _, vm := range scalesetVMs { + d.log.Printf("waiting for restart script to complete on older vmss %s, instance %s", vmssName, *vm.InstanceID) + err = d.vmssvms.RunCommandAndWait(ctx, resourceGroupName, vmssName, *vm.InstanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{restartScript}, + }) + + if err != nil { + return err + } + + // wait for load balancer probe to change the vm health status + time.Sleep(30 * time.Second) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Hour) + defer cancel() + err = d.waitForReadiness(timeoutCtx, resourceGroupName, vmssName, *vm.InstanceID) + if err != nil { + return err + } + } + + return nil +} + +func (d *deployer) waitForReadiness(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) error { + return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { + return d.isVMInstanceHealthy(ctx, resourceGroupName, vmssName, vmInstanceID), nil + }, ctx.Done()) +} + +func (d *deployer) isVMInstanceHealthy(ctx context.Context, resourceGroupName string, vmssName string, vmInstanceID string) bool { + r, err := d.vmssvms.GetInstanceView(ctx, resourceGroupName, vmssName, vmInstanceID) + instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" + if err != nil || instanceUnhealthy { + d.log.Printf("instance %s is unhealthy", vmInstanceID) + return false + } + return true +} diff --git a/pkg/deploy/predeploy_test.go b/pkg/deploy/predeploy_test.go new file mode 100644 index 00000000000..d004b5f9b94 --- /dev/null +++ b/pkg/deploy/predeploy_test.go @@ -0,0 +1,1740 @@ +package deploy + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + mgmtcompute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2020-06-01/compute" + azkeyvault "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.0/keyvault" + mgmtmsi "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" + mgmtfeatures "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2019-07-01/features" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + gofrsuuid "github.com/gofrs/uuid" + "github.com/golang/mock/gomock" + "github.com/sirupsen/logrus" + + "github.com/Azure/ARO-RP/pkg/deploy/generator" + "github.com/Azure/ARO-RP/pkg/env" + mock_compute "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/compute" + mock_features "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/features" + mock_msi "github.com/Azure/ARO-RP/pkg/util/mocks/azureclient/mgmt/msi" + mock_keyvault "github.com/Azure/ARO-RP/pkg/util/mocks/keyvault" + utilerror "github.com/Azure/ARO-RP/test/util/error" +) + +var ( + instanceID = "testID" + rgName = "testRG" + location = "testLocation" + globalRGName = "testRG-global" + subscriptionRGName = "testRG-subscription" + notExistingFileName = "testFile" + existingFileName = generator.FileGatewayProductionPredeploy + existingFileDeploymentName = strings.TrimSuffix(existingFileName, ".json") + secretExists = "secretExists" + noSecretExists = "noSecretExists" + + errGeneric = errors.New("generic error") + deploymentFailedError = &azure.ServiceError{ + Code: "DeploymentFailed", + Details: []map[string]interface{}{{}}, + } + deploymentNotFoundError = autorest.DetailedError{ + Original: &azure.RequestError{ + ServiceError: &azure.ServiceError{ + Code: "DeploymentNotFound", + Details: []map[string]interface{}{{}}, + }, + }, + } + + healthyVMSS = mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{Code: to.StringPtr("HealthState/healthy")}, + }, + } + unhealthyVMSS = mgmtcompute.VirtualMachineScaleSetVMInstanceView{ + VMHealth: &mgmtcompute.VirtualMachineHealthStatus{ + Status: &mgmtcompute.InstanceViewStatus{ + Code: to.StringPtr("HealthState/unhealthy"), + }, + }, + } + + nowUnixTime = date.NewUnixTimeFromSeconds(float64(time.Now().Unix())) + newSecretBundle = azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{Created: &nowUnixTime}, + } + + secretItems = []azkeyvault.SecretItem{{ID: to.StringPtr("secretExists")}} + + vms = []mgmtcompute.VirtualMachineScaleSetVM{{InstanceID: to.StringPtr(instanceID)}} +) + +func TestPreDeploy(t *testing.T) { + ctx := context.Background() + rpRgName := "testRG-aro-rp" + gatewayRgName := "testRG-gwy" + overrideLocation := "overrideTestLocation" + vmssName := rpVMSSPrefix + "test" + + group := mgmtfeatures.ResourceGroup{ + Location: &location, + } + fakeMSIObjectId, _ := gofrsuuid.NewV4() + msi := mgmtmsi.Identity{ + UserAssignedIdentityProperties: &mgmtmsi.UserAssignedIdentityProperties{PrincipalID: &fakeMSIObjectId}, + } + deployment := mgmtfeatures.DeploymentExtended{} + vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: &vmssName}} + oneMissingSecrets := []string{env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} + oneMissingSecretItems := []azkeyvault.SecretItem{} + for _, secret := range oneMissingSecrets { + oneMissingSecretItems = append(oneMissingSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) + } + + type resourceGroups struct { + subscriptionRGName string + globalResourceGroup string + rpResourceGroupName string + gatewayResourceGroupName string + } + type testParams struct { + resourceGroups resourceGroups + location string + instanceID string + vmssName string + restartScript string + overrideLocation string + acrReplicaDisabled bool + } + type mock func(*mock_features.MockDeploymentsClient, *mock_features.MockResourceGroupsClient, *mock_msi.MockUserAssignedIdentitiesClient, *mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + createOrUpdateAtSubscriptionScopeAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, "rp-global-subscription-"+tp.location, gomock.Any()).Return(returnError) + } + } + createOrUpdateAndWaitMock := func(resourceGroup string, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, resourceGroup, gomock.Any(), gomock.Any()).Return(returnError) + } + } + createOrUpdateMock := func(resourceGroup string, returnResourceGroup mgmtfeatures.ResourceGroup, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + rg.EXPECT().CreateOrUpdate(ctx, resourceGroup, mgmtfeatures.ResourceGroup{Location: &tp.location}).Return(returnResourceGroup, returnError) + } + } + msiGetMock := func(resourceGroup string, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + m.EXPECT().Get(ctx, resourceGroup, gomock.Any()).Return(msi, returnError) + } + } + getDeploymentMock := func(resourceGroup string, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + d.EXPECT().Get(ctx, resourceGroup, gomock.Any()).Return(deployment, returnError) + } + } + getSecretsMock := func(secretItems []azkeyvault.SecretItem, returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } + } + getSecretMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecret(ctx, gomock.Any()).Return(newSecretBundle, nil) + } + setSecretMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return(nil) + } + vmssListMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmss.EXPECT().List(ctx, gomock.Any()).Return(vmsss, nil) + } + vmssVMsListMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().List(ctx, gomock.Any(), tp.vmssName, "", "", "").Return(vms, nil) + } + vmRestartMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().RunCommandAndWait(ctx, gomock.Any(), tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(nil) + } + instanceViewMock := func(d *mock_features.MockDeploymentsClient, rg *mock_features.MockResourceGroupsClient, m *mock_msi.MockUserAssignedIdentitiesClient, k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), gomock.Any(), tp.vmssName, tp.instanceID).Return(healthyVMSS, nil) + } + + for _, tt := range []struct { + name string + acrReplicaDisabled bool + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "don't continue if Global Subscription RBAC DeploymentFailed", + testParams: testParams{ + location: location, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if Global Subscription RBAC Deployment is Successful but SubscriptionResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if SubscriptionResourceGroup creation is Successful but GlobalResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if GlobalResourceGroup creation is Successful but RPResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if RPResourceGroup creation is successful but GatewayResourceGroup creation fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if GatewayResourceGroup is successful but rp-subscription template deployment fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if rp-subscription template deployment is successful but rp managed identity get fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if rp managed identity get is successful but gateway managed identity get fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if rpglobal deployment fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if rpglobal deployment fails twice with DeploymentFailed", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, deploymentFailedError), createOrUpdateAndWaitMock(globalRGName, deploymentFailedError), + }, + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, + }, + { + name: "don't continue if ACR Replication fails", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), createOrUpdateAndWaitMock(globalRGName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if skipping ACR Replication due to no ACRLocationOverride but failing gateway predeploy", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if skipping ACR Replication due to ACRLocationOverride same as GlobalResourceGroupLocation but failing gateway predeploy", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: location, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue if skipping ACR Replication due to ACRReplicaDisabled but failing gateway predeploy", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "don't continue gateway predeploy is successful but rp predeploy failed", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "get error for the configureServiceSecrets", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "Everything is successful", + testParams: testParams{ + location: location, + resourceGroups: resourceGroups{ + subscriptionRGName: subscriptionRGName, + globalResourceGroup: globalRGName, + rpResourceGroupName: rpRgName, + gatewayResourceGroupName: gatewayRgName, + }, + overrideLocation: overrideLocation, + acrReplicaDisabled: true, + vmssName: vmssName, + instanceID: instanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{ + createOrUpdateAtSubscriptionScopeAndWaitMock(nil), createOrUpdateMock(subscriptionRGName, group, nil), createOrUpdateMock(globalRGName, group, nil), createOrUpdateMock(rpRgName, group, nil), createOrUpdateMock(gatewayRgName, group, nil), createOrUpdateAndWaitMock(subscriptionRGName, nil), createOrUpdateAndWaitMock(rpRgName, nil), msiGetMock(rpRgName, nil), createOrUpdateAndWaitMock(gatewayRgName, nil), msiGetMock(gatewayRgName, nil), createOrUpdateAndWaitMock(globalRGName, nil), getDeploymentMock(gatewayRgName, deploymentNotFoundError), createOrUpdateAndWaitMock(gatewayRgName, nil), createOrUpdateAndWaitMock(rpRgName, nil), getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretMock, getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), getSecretsMock(oneMissingSecretItems, nil), vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, vmssListMock, vmssVMsListMock, vmRestartMock, instanceViewMock, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + mockResourceGroups := mock_features.NewMockResourceGroupsClient(controller) + mockMSIs := mock_msi.NewMockUserAssignedIdentitiesClient(controller) + mockKV := mock_keyvault.NewMockManager(controller) + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetsClient(controller) + mockVMSSVM := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + globaldeployments: mockDeployments, + deployments: mockDeployments, + groups: mockResourceGroups, + globalgroups: mockResourceGroups, + userassignedidentities: mockMSIs, + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupLocation: &tt.testParams.location, + SubscriptionResourceGroupLocation: &tt.testParams.location, + SubscriptionResourceGroupName: &tt.testParams.resourceGroups.subscriptionRGName, + GlobalResourceGroupName: &tt.testParams.resourceGroups.globalResourceGroup, + ACRLocationOverride: &tt.testParams.overrideLocation, + ACRReplicaDisabled: &tt.testParams.acrReplicaDisabled, + }, + RPResourceGroupName: tt.testParams.resourceGroups.rpResourceGroupName, + GatewayResourceGroupName: tt.testParams.resourceGroups.gatewayResourceGroupName, + Location: tt.testParams.location, + }, + serviceKeyvault: mockKV, + portalKeyvault: mockKV, + vmss: mockVMSS, + vmssvms: mockVMSSVM, + } + + for _, m := range tt.mocks { + m(mockDeployments, mockResourceGroups, mockMSIs, mockKV, mockVMSS, mockVMSSVM, tt.testParams) + } + + err := d.PreDeploy(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPGlobalSubscription(t *testing.T) { + ctx := context.Background() + + type testParams struct { + location string + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + createOrUpdateAtSubscriptionScopeAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAtSubscriptionScopeAndWait(ctx, "rp-global-subscription-"+tp.location, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails with error other than DeploymentFailed", + testParams: testParams{location: location}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Don't continue if deployment fails with error DeploymentFailed five times", + testParams: testParams{location: location}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError)}, + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, + }, + { + name: "Pass successfully when deployment is successfulin second attempt", + testParams: testParams{location: location}, + mocks: []mock{createOrUpdateAtSubscriptionScopeAndWaitMock(deploymentFailedError), createOrUpdateAtSubscriptionScopeAndWaitMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupLocation: &tt.testParams.location, + }, + Location: tt.testParams.location, + }, + globaldeployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments, tt.testParams) + } + + err := d.deployRPGlobalSubscription(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPSubscription(t *testing.T) { + ctx := context.Background() + + type testParams struct { + resourceGroup string + location string + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, "rp-production-subscription-"+tp.location, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails", + testParams: testParams{ + location: location, + resourceGroup: subscriptionRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Pass successfully when deployment is successful", + testParams: testParams{ + location: location, + resourceGroup: subscriptionRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + SubscriptionResourceGroupName: &tt.testParams.resourceGroup, + }, + Location: tt.testParams.location, + }, + deployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments, tt.testParams) + } + + err := d.deployRPSubscription(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployManagedIdentity(t *testing.T) { + ctx := context.Background() + + type testParams struct { + resourceGroup string + deploymentFileName string + deploymentName string + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, tp.deploymentName, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment file does not exist", + testParams: testParams{ + deploymentFileName: notExistingFileName, + }, + wantErr: "open " + notExistingFileName + ": file does not exist", + }, + { + name: "Don't continue if deployment fails", + testParams: testParams{ + deploymentFileName: existingFileName, + deploymentName: existingFileDeploymentName, + resourceGroup: rgName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Pass successfully when deployment is successful", + testParams: testParams{ + deploymentFileName: existingFileName, + deploymentName: existingFileDeploymentName, + resourceGroup: rgName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{}, + }, + deployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments, tt.testParams) + } + + err := d.deployManagedIdentity(ctx, tt.testParams.resourceGroup, tt.testParams.deploymentFileName) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPGlobal(t *testing.T) { + ctx := context.Background() + rpSPID := "rpSPIDTest" + gwySPID := "gwySPIDTest" + + type testParams struct { + resourceGroup string + location string + rpSPID string + gwySPID string + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, "rp-global-"+tp.location, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails with error other than DeploymentFailed", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + rpSPID: rpSPID, + gwySPID: gwySPID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Don't continue if deployment fails with DeploymentFailed error twice", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + rpSPID: rpSPID, + gwySPID: gwySPID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(deploymentFailedError), CreateOrUpdateAndWaitMock(deploymentFailedError)}, + wantErr: `Code="DeploymentFailed" Message="" Details=[{}]`, + }, + { + name: "Pass successfully when deployment is successful in second attempt", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + rpSPID: rpSPID, + gwySPID: gwySPID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(deploymentFailedError), CreateOrUpdateAndWaitMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupName: to.StringPtr(tt.testParams.resourceGroup), + }, + Location: tt.testParams.location, + }, + globaldeployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments, tt.testParams) + } + + err := d.deployRPGlobal(ctx, tt.testParams.rpSPID, tt.testParams.gwySPID) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployRPGlobalACRReplication(t *testing.T) { + ctx := context.Background() + + type testParams struct { + resourceGroup string + location string + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, "rp-global-acr-replication-"+tp.location, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment fails", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Pass when deployment is successful", + testParams: testParams{ + location: location, + resourceGroup: globalRGName, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{ + GlobalResourceGroupName: to.StringPtr(tt.testParams.resourceGroup), + }, + Location: tt.testParams.location, + }, + globaldeployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments, tt.testParams) + } + + err := d.deployRPGlobalACRReplication(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestDeployPreDeploy(t *testing.T) { + ctx := context.Background() + spIDName := "testSPIDName" + spID := "testSPID" + + type testParams struct { + resourceGroup string + deploymentFileName string + deploymentName string + spIDName string + spID string + isCreate bool + } + type mock func(*mock_features.MockDeploymentsClient, testParams) + CreateOrUpdateAndWaitMock := func(returnError error) mock { + return func(d *mock_features.MockDeploymentsClient, tp testParams) { + d.EXPECT().CreateOrUpdateAndWait(ctx, tp.resourceGroup, tp.deploymentName, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Don't continue if deployment file does not exist", + testParams: testParams{ + resourceGroup: rgName, + deploymentFileName: notExistingFileName, + spIDName: spIDName, + spID: spID, + }, + wantErr: "open " + notExistingFileName + ": file does not exist", + }, + { + name: "Don't continue if deployment fails", + testParams: testParams{ + resourceGroup: rgName, + deploymentFileName: existingFileName, + deploymentName: existingFileDeploymentName, + spIDName: spIDName, + spID: spID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Pass when deployment is successful", + testParams: testParams{ + resourceGroup: rgName, + deploymentFileName: existingFileName, + deploymentName: existingFileDeploymentName, + spIDName: spIDName, + spID: spID, + }, + mocks: []mock{CreateOrUpdateAndWaitMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockDeployments := mock_features.NewMockDeploymentsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + Configuration: &Configuration{}, + GatewayResourceGroupName: tt.testParams.resourceGroup, + }, + deployments: mockDeployments, + } + + for _, m := range tt.mocks { + m(mockDeployments, tt.testParams) + } + + err := d.deployPreDeploy(ctx, tt.testParams.resourceGroup, tt.testParams.deploymentFileName, tt.testParams.spIDName, tt.testParams.spID, tt.testParams.isCreate) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestConfigureServiceSecrets(t *testing.T) { + ctx := context.Background() + vmssName := rpVMSSPrefix + "test" + vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: to.StringPtr(vmssName)}} + oneMissingSecrets := []string{env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} + oneMissingSecretItems := []azkeyvault.SecretItem{} + for _, secret := range oneMissingSecrets { + oneMissingSecretItems = append(oneMissingSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) + } + allSecrets := []string{env.EncryptionSecretV2Name, env.FrontendEncryptionSecretV2Name, env.PortalServerSessionKeySecretName, env.EncryptionSecretName, env.FrontendEncryptionSecretName, env.PortalServerSSHKeySecretName} + allSecretItems := []azkeyvault.SecretItem{} + for _, secret := range allSecrets { + allSecretItems = append(allSecretItems, azkeyvault.SecretItem{ID: to.StringPtr(secret)}) + } + + type testParams struct { + vmssName string + instanceID string + resourceGroup string + restartScript string + } + type mock func(*mock_keyvault.MockManager, *mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getSecretsMock := func(secretItems []azkeyvault.SecretItem, returnError error) mock { + return func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } + } + getSecretMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().GetSecret(ctx, gomock.Any()).Return(newSecretBundle, nil) + } + setSecretMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + k.EXPECT().SetSecret(ctx, gomock.Any(), gomock.Any()).Return(nil) + } + vmssListMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmss.EXPECT().List(ctx, tp.resourceGroup).Return(vmsss, returnError).AnyTimes() + } + } + vmssVMsListMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().List(ctx, tp.resourceGroup, tp.vmssName, "", "", "").Return(vms, nil).AnyTimes() + } + vmRestartMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().RunCommandAndWait(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(nil).AnyTimes() + } + instanceViewMock := func(k *mock_keyvault.MockManager, vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), tp.resourceGroup, tp.vmssName, tp.instanceID).Return(healthyVMSS, nil).AnyTimes() + } + + for _, tt := range []struct { + name string + secretToFind string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "return error if ensureAndRotateSecret fails", + mocks: []mock{ + getSecretsMock(allSecretItems, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret passes without rotating any secret but ensureSecret fails", + mocks: []mock{ + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret passes with rotating a missing secret but ensureSecret fails", + mocks: []mock{ + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "return error if ensureAndRotateSecret, ensureSecret passes without rotating a secret but ensureSecretKey fails", + mocks: []mock{ + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, errGeneric), + }, + wantErr: "generic error", + }, + { + name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes without rotating a secret", + mocks: []mock{ + getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), + }, + }, + { + name: "return error if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret in each ensure function call but restartoldscaleset failing", + testParams: testParams{ + vmssName: vmssName, + instanceID: instanceID, + resourceGroup: rgName, + }, + mocks: []mock{ + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), vmssListMock(errGeneric), + }, + wantErr: "generic error", + }, + { + name: "return nil if ensureAndRotateSecret, ensureSecret, ensureSecretKey passes with rotating secret and restartoldscaleset passess successfully", + testParams: testParams{ + vmssName: vmssName, + instanceID: instanceID, + resourceGroup: rgName, + restartScript: rpRestartScript, + }, + mocks: []mock{ + getSecretsMock(oneMissingSecretItems, nil), setSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretMock, getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), getSecretsMock(allSecretItems, nil), vmssListMock(nil), vmssVMsListMock, vmRestartMock, instanceViewMock, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetsClient(controller) + mockVMSSVM := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + config: &RPConfig{ + RPResourceGroupName: tt.testParams.resourceGroup, + GatewayResourceGroupName: tt.testParams.resourceGroup, + }, + serviceKeyvault: mockKV, + portalKeyvault: mockKV, + vmss: mockVMSS, + vmssvms: mockVMSSVM, + } + + for _, m := range tt.mocks { + m(mockKV, mockVMSS, mockVMSSVM, tt.testParams) + } + + err := d.configureServiceSecrets(ctx) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestEnsureAndRotateSecret(t *testing.T) { + ctx := context.Background() + oldUnixTime := date.NewUnixTimeFromSeconds(float64(time.Now().Add(-rotateSecretAfter).Unix())) + oldSecretBundle := azkeyvault.SecretBundle{ + Attributes: &azkeyvault.SecretAttributes{Created: &oldUnixTime}, + } + + type testParams struct { + secretToFind string + } + type mock func(*mock_keyvault.MockManager, testParams) + getSecretsMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } + } + getSecretMock := func(secretBundle azkeyvault.SecretBundle, returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecret(ctx, tp.secretToFind).Return(secretBundle, returnError) + } + } + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToFind, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + wantBool bool + }{ + { + name: "return false and error if GetSecrets fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(errGeneric)}, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and error if GetSecrets passes but GetSecret fails for the found secret", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(newSecretBundle, errGeneric)}, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and nil if GetSecrets and GetSecret passes and the secret is not too old", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(newSecretBundle, nil)}, + wantBool: false, + }, + { + name: "return true and error if GetSecrets & GetSecret passes and the secret is old but new secret creation fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(oldSecretBundle, nil), setSecretMock(errGeneric)}, + wantBool: true, + wantErr: "generic error", + }, + { + name: "return true and nil if GetSecrets & GetSecret passes and the secret is old and new secret creation passes", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil), getSecretMock(oldSecretBundle, nil), setSecretMock(nil)}, + wantBool: true, + }, + { + name: "return true and nil if the secret is not present and new secret creation passes", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(nil)}, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV, tt.testParams) + } + + got, err := d.ensureAndRotateSecret(ctx, mockKV, tt.testParams.secretToFind, 8) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} + +func TestEnsureSecret(t *testing.T) { + ctx := context.Background() + + type testParams struct { + secretToFind string + } + type mock func(*mock_keyvault.MockManager, testParams) + getSecretsMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } + } + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToFind, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + wantBool bool + }{ + { + name: "return false and error if GetSecrets fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(errGeneric)}, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and nil if GetSecrets passes and secret is found", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil)}, + wantBool: false, + }, + { + name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(errGeneric)}, + wantBool: true, + wantErr: "generic error", + }, + { + name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(nil)}, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV, tt.testParams) + } + + got, err := d.ensureSecret(ctx, mockKV, tt.testParams.secretToFind, 8) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} + +func TestCreateSecret(t *testing.T) { + ctx := context.Background() + + type testParams struct { + secretToCreate string + } + type mock func(*mock_keyvault.MockManager, testParams) + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToCreate, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "return error if new secret creation fails", + testParams: testParams{ + secretToCreate: noSecretExists, + }, + mocks: []mock{setSecretMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "return nil new secret creation passes", + testParams: testParams{ + secretToCreate: noSecretExists, + }, + mocks: []mock{setSecretMock(nil)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV, tt.testParams) + } + + err := d.createSecret(ctx, mockKV, tt.testParams.secretToCreate, 8) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestEnsureSecretKey(t *testing.T) { + ctx := context.Background() + + type testParams struct { + secretToFind string + } + type mock func(*mock_keyvault.MockManager, testParams) + getSecretsMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().GetSecrets(ctx).Return(secretItems, returnError) + } + } + setSecretMock := func(returnError error) mock { + return func(k *mock_keyvault.MockManager, tp testParams) { + k.EXPECT().SetSecret(ctx, tp.secretToFind, gomock.Any()).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + wantBool bool + }{ + { + name: "return false and error if GetSecrets fails", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(errGeneric)}, + wantBool: false, + wantErr: "generic error", + }, + { + name: "return false and nil if GetSecrets passes and secret is found", + testParams: testParams{secretToFind: secretExists}, + mocks: []mock{getSecretsMock(nil)}, + wantBool: false, + }, + { + name: "return true and error if GetSecrets passes but secret is not found and new secret creation fails", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(errGeneric)}, + wantBool: true, + wantErr: "generic error", + }, + { + name: "return true and nil if GetSecrets passes but secret is not found and new secret creation also passes", + testParams: testParams{secretToFind: noSecretExists}, + mocks: []mock{getSecretsMock(nil), setSecretMock(nil)}, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockKV := mock_keyvault.NewMockManager(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + } + + for _, m := range tt.mocks { + m(mockKV, tt.testParams) + } + + got, err := d.ensureSecretKey(ctx, mockKV, tt.testParams.secretToFind) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} + +func TestRestartOldScalesets(t *testing.T) { + ctx := context.Background() + rpVMSSName := rpVMSSPrefix + "test" + invalidVMSSName := "other-vmss" + invalidVMSSs := []mgmtcompute.VirtualMachineScaleSet{{Name: &invalidVMSSName}} + vmsss := []mgmtcompute.VirtualMachineScaleSet{{Name: &rpVMSSName}} + + type testParams struct { + resourceGroup string + vmssName string + instanceID string + restartScript string + } + type mock func(*mock_compute.MockVirtualMachineScaleSetsClient, *mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + listVMSSMock := func(returnVMSS []mgmtcompute.VirtualMachineScaleSet, returnError error) mock { + return func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmss.EXPECT().List(ctx, tp.resourceGroup).Return(returnVMSS, returnError) + } + } + listVMSSVMMock := func(returnError error) mock { + return func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().List(ctx, tp.resourceGroup, tp.vmssName, "", "", "").Return(vms, returnError) + } + } + vmRestartMock := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().RunCommandAndWait(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(nil) + } + getInstanceViewMock := func(vmss *mock_compute.MockVirtualMachineScaleSetsClient, vmssvms *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + vmssvms.EXPECT().GetInstanceView(gomock.Any(), tp.resourceGroup, tp.vmssName, tp.instanceID).Return(healthyVMSS, nil) + } + + for _, tt := range []struct { + name string + mocks []mock + testParams testParams + wantErr string + }{ + { + name: "Don't continue if vmss list fails", + testParams: testParams{resourceGroup: rgName}, + mocks: []mock{listVMSSMock(vmsss, errGeneric)}, + wantErr: "generic error", + }, + { + name: "Don't continue if vmss list has an invalid vmss name", + testParams: testParams{resourceGroup: rgName}, + mocks: []mock{listVMSSMock(invalidVMSSs, nil)}, + wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", + }, + { + name: "Don't continue if vmssvms list fails", + testParams: testParams{ + resourceGroup: rgName, + vmssName: rpVMSSName, + }, + mocks: []mock{listVMSSMock(vmsss, nil), listVMSSVMMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "Restart is successful for the VMs in VMSS", + testParams: testParams{ + resourceGroup: rgName, + vmssName: rpVMSSName, + instanceID: instanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{listVMSSMock(vmsss, nil), listVMSSVMMock(nil), vmRestartMock, getInstanceViewMock}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetsClient(controller) + mockVMSSVM := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmss: mockVMSS, + vmssvms: mockVMSSVM, + } + + for _, m := range tt.mocks { + m(mockVMSS, mockVMSSVM, tt.testParams) + } + + err := d.restartOldScalesets(ctx, tt.testParams.resourceGroup) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestRestartOldScaleset(t *testing.T) { + ctx := context.Background() + otherVMSSName := "other-vmss" + gwyVMSSName := gatewayVMSSPrefix + "test" + rpVMSSName := rpVMSSPrefix + "test" + + type testParams struct { + resourceGroup string + vmssName string + instanceID string + restartScript string + } + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getInstanceViewMock := func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().GetInstanceView(gomock.Any(), tp.resourceGroup, tp.vmssName, tp.instanceID).Return(healthyVMSS, nil) + } + listVMSSVMMock := func(returnError error) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().List(ctx, tp.resourceGroup, tp.vmssName, "", "", "").Return(vms, returnError) + } + } + vmRestartMock := func(returnError error) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().RunCommandAndWait(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID, mgmtcompute.RunCommandInput{ + CommandID: to.StringPtr("RunShellScript"), + Script: &[]string{tp.restartScript}, + }).Return(returnError) + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "Return an error if the VMSS is not gateway or RP", + testParams: testParams{vmssName: otherVMSSName}, + wantErr: "400: InvalidResource: : provided vmss other-vmss does not match RP or gateway prefix", + }, + { + name: "list VMSS failed", + testParams: testParams{ + resourceGroup: rgName, + vmssName: gwyVMSSName, + instanceID: instanceID, + }, + mocks: []mock{listVMSSVMMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "gateway restart script failed", + testParams: testParams{ + resourceGroup: rgName, + vmssName: gwyVMSSName, + instanceID: instanceID, + restartScript: gatewayRestartScript, + }, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "rp restart script failed", + testParams: testParams{ + resourceGroup: rgName, + vmssName: rpVMSSName, + instanceID: instanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(errGeneric)}, + wantErr: "generic error", + }, + { + name: "restart script passes and wait for readiness is successful", + testParams: testParams{ + resourceGroup: rgName, + vmssName: rpVMSSName, + instanceID: instanceID, + restartScript: rpRestartScript, + }, + mocks: []mock{listVMSSVMMock(nil), vmRestartMock(nil), getInstanceViewMock}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmssvms: mockVMSS, + } + + for _, m := range tt.mocks { + m(mockVMSS, tt.testParams) + } + + err := d.restartOldScaleset(ctx, tt.testParams.vmssName, tt.testParams.resourceGroup) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestWaitForReadiness(t *testing.T) { + ctxTimeout, cancel := context.WithTimeout(context.Background(), 11*time.Second) + vmssName := "testVMSS" + + type testParams struct { + resourceGroup string + vmssName string + vmInstanceID string + ctx context.Context + cancel context.CancelFunc + } + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getInstanceViewMock := func(vm mgmtcompute.VirtualMachineScaleSetVMInstanceView) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().GetInstanceView(tp.ctx, tp.resourceGroup, tp.vmssName, tp.vmInstanceID).Return(vm, nil).AnyTimes() + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantErr string + }{ + { + name: "fail after context times out", + testParams: testParams{ + resourceGroup: rgName, + vmssName: vmssName, + vmInstanceID: instanceID, + ctx: ctxTimeout, + }, + mocks: []mock{getInstanceViewMock(unhealthyVMSS)}, + wantErr: "timed out waiting for the condition", + }, + { + name: "run successfully after confirming healthy status", + testParams: testParams{ + resourceGroup: rgName, + vmssName: vmssName, + vmInstanceID: instanceID, + ctx: ctxTimeout, + cancel: cancel, + }, + mocks: []mock{getInstanceViewMock(healthyVMSS)}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmssvms: mockVMSS, + } + + for _, m := range tt.mocks { + m(mockVMSS, tt.testParams) + } + + defer cancel() + err := d.waitForReadiness(tt.testParams.ctx, tt.testParams.resourceGroup, tt.testParams.vmssName, tt.testParams.vmInstanceID) + utilerror.AssertErrorMessage(t, err, tt.wantErr) + }) + } +} + +func TestIsVMInstanceHealthy(t *testing.T) { + ctx := context.Background() + vmssName := "testVMSS" + vmInstanceID := "testVMInstanceID" + rpRGName := "testRPRG" + gatewayRGName := "testGatewayRG" + + type testParams struct { + resourceGroup string + vmssName string + instanceID string + } + type mock func(*mock_compute.MockVirtualMachineScaleSetVMsClient, testParams) + getInstanceViewMock := func(vm mgmtcompute.VirtualMachineScaleSetVMInstanceView, returnError error) mock { + return func(c *mock_compute.MockVirtualMachineScaleSetVMsClient, tp testParams) { + c.EXPECT().GetInstanceView(ctx, tp.resourceGroup, tp.vmssName, tp.instanceID).Return(vm, returnError).AnyTimes() + } + } + + for _, tt := range []struct { + name string + testParams testParams + mocks []mock + wantBool bool + }{ + { + name: "return false if GetInstanceView failed for RP resource group", + testParams: testParams{ + resourceGroup: rpRGName, + vmssName: vmssName, + instanceID: vmInstanceID, + }, + mocks: []mock{getInstanceViewMock(healthyVMSS, errGeneric)}, + wantBool: false, + }, + { + name: "return false if GetInstanceView failed for Gateway resource group", + testParams: testParams{ + resourceGroup: gatewayRGName, + vmssName: vmssName, + instanceID: vmInstanceID, + }, + mocks: []mock{getInstanceViewMock(healthyVMSS, errGeneric)}, + wantBool: false, + }, + { + name: "return false if GetInstanceView return unhealthy VM", + testParams: testParams{ + resourceGroup: rpRGName, + vmssName: vmssName, + instanceID: vmInstanceID, + }, + mocks: []mock{getInstanceViewMock(unhealthyVMSS, nil)}, + wantBool: false, + }, + { + name: "return true if GetInstanceView return healthy VM", + testParams: testParams{ + resourceGroup: rpRGName, + vmssName: vmssName, + instanceID: vmInstanceID, + }, + mocks: []mock{getInstanceViewMock(healthyVMSS, nil)}, + wantBool: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + mockVMSS := mock_compute.NewMockVirtualMachineScaleSetVMsClient(controller) + + d := deployer{ + log: logrus.NewEntry(logrus.StandardLogger()), + vmssvms: mockVMSS, + } + + for _, m := range tt.mocks { + m(mockVMSS, tt.testParams) + } + + got := d.isVMInstanceHealthy(ctx, tt.testParams.resourceGroup, tt.testParams.vmssName, tt.testParams.instanceID) + if tt.wantBool != got { + t.Errorf("%#v", got) + } + }) + } +} diff --git a/pkg/deploy/upgrade_gateway.go b/pkg/deploy/upgrade_gateway.go index a5eedb87d07..bd7c35f8820 100644 --- a/pkg/deploy/upgrade_gateway.go +++ b/pkg/deploy/upgrade_gateway.go @@ -40,10 +40,7 @@ func (d *deployer) gatewayWaitForReadiness(ctx context.Context, vmssName string) d.log.Printf("waiting for %s instances to be healthy", vmssName) return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { for _, vm := range scalesetVMs { - r, err := d.vmssvms.GetInstanceView(ctx, d.config.GatewayResourceGroupName, vmssName, *vm.InstanceID) - instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" - if err != nil || instanceUnhealthy { - d.log.Printf("instance %s status %s", *vm.InstanceID, *r.VMHealth.Status.Code) + if !d.isVMInstanceHealthy(ctx, d.config.GatewayResourceGroupName, vmssName, *vm.InstanceID) { return false, nil } } diff --git a/pkg/deploy/upgrade_rp.go b/pkg/deploy/upgrade_rp.go index 875d61c28c1..d03c474808d 100644 --- a/pkg/deploy/upgrade_rp.go +++ b/pkg/deploy/upgrade_rp.go @@ -40,10 +40,7 @@ func (d *deployer) rpWaitForReadiness(ctx context.Context, vmssName string) erro d.log.Printf("waiting for %s instances to be healthy", vmssName) return wait.PollImmediateUntil(10*time.Second, func() (bool, error) { for _, vm := range scalesetVMs { - r, err := d.vmssvms.GetInstanceView(ctx, d.config.RPResourceGroupName, vmssName, *vm.InstanceID) - instanceUnhealthy := r.VMHealth != nil && r.VMHealth.Status != nil && r.VMHealth.Status.Code != nil && *r.VMHealth.Status.Code != "HealthState/healthy" - if err != nil || instanceUnhealthy { - d.log.Printf("instance %s status %s", *vm.InstanceID, *r.VMHealth.Status.Code) + if !d.isVMInstanceHealthy(ctx, d.config.RPResourceGroupName, vmssName, *vm.InstanceID) { return false, nil } } diff --git a/pkg/util/azureclient/mgmt/msi/generate.go b/pkg/util/azureclient/mgmt/msi/generate.go new file mode 100644 index 00000000000..f472fe1efcd --- /dev/null +++ b/pkg/util/azureclient/mgmt/msi/generate.go @@ -0,0 +1,8 @@ +package msi + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +//go:generate rm -rf ../../../../util/mocks/$GOPACKAGE +//go:generate go run ../../../../../vendor/github.com/golang/mock/mockgen -destination=../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/$GOPACKAGE UserAssignedIdentitiesClient +//go:generate go run ../../../../../vendor/golang.org/x/tools/cmd/goimports -local=github.com/Azure/ARO-RP -e -w ../../../../util/mocks/azureclient/mgmt/$GOPACKAGE/$GOPACKAGE.go diff --git a/pkg/util/mocks/azureclient/mgmt/msi/msi.go b/pkg/util/mocks/azureclient/mgmt/msi/msi.go new file mode 100644 index 00000000000..d66bb62005a --- /dev/null +++ b/pkg/util/mocks/azureclient/mgmt/msi/msi.go @@ -0,0 +1,51 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/msi (interfaces: UserAssignedIdentitiesClient) + +// Package mock_msi is a generated GoMock package. +package mock_msi + +import ( + context "context" + reflect "reflect" + + msi "github.com/Azure/azure-sdk-for-go/services/msi/mgmt/2018-11-30/msi" + gomock "github.com/golang/mock/gomock" +) + +// MockUserAssignedIdentitiesClient is a mock of UserAssignedIdentitiesClient interface. +type MockUserAssignedIdentitiesClient struct { + ctrl *gomock.Controller + recorder *MockUserAssignedIdentitiesClientMockRecorder +} + +// MockUserAssignedIdentitiesClientMockRecorder is the mock recorder for MockUserAssignedIdentitiesClient. +type MockUserAssignedIdentitiesClientMockRecorder struct { + mock *MockUserAssignedIdentitiesClient +} + +// NewMockUserAssignedIdentitiesClient creates a new mock instance. +func NewMockUserAssignedIdentitiesClient(ctrl *gomock.Controller) *MockUserAssignedIdentitiesClient { + mock := &MockUserAssignedIdentitiesClient{ctrl: ctrl} + mock.recorder = &MockUserAssignedIdentitiesClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUserAssignedIdentitiesClient) EXPECT() *MockUserAssignedIdentitiesClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockUserAssignedIdentitiesClient) Get(arg0 context.Context, arg1, arg2 string) (msi.Identity, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) + ret0, _ := ret[0].(msi.Identity) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockUserAssignedIdentitiesClientMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockUserAssignedIdentitiesClient)(nil).Get), arg0, arg1, arg2) +}