Skip to content

Commit

Permalink
Merge pull request #3655 from Azure/ARO-8608
Browse files Browse the repository at this point in the history
Add tenant ID to internal apis for CMSI usage
  • Loading branch information
mociarain authored Jul 4, 2024
2 parents 66a4a9b + 786e0cf commit 6b91187
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 54 deletions.
1 change: 1 addition & 0 deletions pkg/api/openshiftcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -818,4 +818,5 @@ type Identity struct {
Type string `json:"type,omitempty"`
UserAssignedIdentities UserAssignedIdentities `json:"userAssignedIdentities,omitempty"`
IdentityURL string `json:"identityURL,omitempty" mutable:"true"`
TenantID string `json:"tenantId,omitempty" mutable:"true"`
}
45 changes: 29 additions & 16 deletions pkg/frontend/openshiftcluster_putorpatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"github.com/Azure/ARO-RP/pkg/util/version"
)

var errMissingIdentityURL error = fmt.Errorf("identityURL not provided but required for workload identity cluster")
var errMissingIdentityParameter error = fmt.Errorf("identity parameter not provided but required for workload identity cluster")

func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand All @@ -44,19 +44,21 @@ func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Req
resourceProviderNamespace := chi.URLParam(r, "resourceProviderNamespace")

identityURL := r.Header.Get("x-ms-identity-url")
identityTenantID := r.Header.Get("x-ms-home-tenant-id")

apiVersion := r.URL.Query().Get(api.APIVersionKey)
err := cosmosdb.RetryOnPreconditionFailed(func() error {
var err error
b, err = f._putOrPatchOpenShiftCluster(ctx, log, body, correlationData, systemData, r.URL.Path, originalPath, r.Method, referer, &header, f.apis[apiVersion].OpenShiftClusterConverter, f.apis[apiVersion].OpenShiftClusterStaticValidator, subId, resourceProviderNamespace, apiVersion, identityURL)
b, err = f._putOrPatchOpenShiftCluster(ctx, log, body, correlationData, systemData, r.URL.Path, originalPath, r.Method, referer, &header, f.apis[apiVersion].OpenShiftClusterConverter, f.apis[apiVersion].OpenShiftClusterStaticValidator, subId, resourceProviderNamespace, apiVersion, identityURL, identityTenantID)
return err
})

frontendOperationResultLog(log, r.Method, err)
reply(log, w, header, b, err)
}

func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.Entry, body []byte, correlationData *api.CorrelationData, systemData *api.SystemData, path, originalPath, method, referer string, header *http.Header, converter api.OpenShiftClusterConverter, staticValidator api.OpenShiftClusterStaticValidator, subId, resourceProviderNamespace string, apiVersion string, identityURL string) ([]byte, error) {
// TODO - refactor this function to reduce the number of parameters
func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.Entry, body []byte, correlationData *api.CorrelationData, systemData *api.SystemData, path, originalPath, method, referer string, header *http.Header, converter api.OpenShiftClusterConverter, staticValidator api.OpenShiftClusterStaticValidator, subId, resourceProviderNamespace string, apiVersion string, identityURL string, identityTenantID string) ([]byte, error) {
subscription, err := f.validateSubscriptionState(ctx, path, api.SubscriptionStateRegistered)
if err != nil {
return nil, err
Expand Down Expand Up @@ -96,9 +98,18 @@ func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.
}
}

err = validateIdentityUrl(doc.OpenShiftCluster, identityURL, isCreate)
if err != nil {
return nil, err
if isCreate {
// Persist identity URL and tenant ID only for managed/workload identity cluster create
// We don't support updating cluster managed identity after cluster creation
// TODO - use a common function to check if the cluster is a managed/workload identity cluster
if !(doc.OpenShiftCluster.Properties.ServicePrincipalProfile != nil || doc.OpenShiftCluster.Identity == nil) {
if err := validateIdentityUrl(doc.OpenShiftCluster, identityURL); err != nil {
return nil, err
}
if err := validateIdentityTenantID(doc.OpenShiftCluster, identityTenantID); err != nil {
return nil, err
}
}
}

doc.CorrelationData = correlationData
Expand Down Expand Up @@ -299,24 +310,26 @@ func enrichClusterSystemData(doc *api.OpenShiftClusterDocument, systemData *api.
}
}

func validateIdentityUrl(cluster *api.OpenShiftCluster, identityURL string, isCreate bool) error {
// Don't persist identity URL in non-wimi clusters
if cluster.Properties.ServicePrincipalProfile != nil || cluster.Identity == nil {
return nil
}

func validateIdentityUrl(cluster *api.OpenShiftCluster, identityURL string) error {
if identityURL == "" {
if isCreate {
return errMissingIdentityURL
}
return nil
return fmt.Errorf("%w: %s", errMissingIdentityParameter, "identity URL")
}

cluster.Identity.IdentityURL = identityURL

return nil
}

func validateIdentityTenantID(cluster *api.OpenShiftCluster, identityTenantID string) error {
if identityTenantID == "" {
return fmt.Errorf("%w: %s", errMissingIdentityParameter, "identity tenant ID")
}

cluster.Identity.TenantID = identityTenantID

return nil
}

func (f *frontend) ValidateNewCluster(ctx context.Context, subscription *api.SubscriptionDocument, cluster *api.OpenShiftCluster, staticValidator api.OpenShiftClusterStaticValidator, ext interface{}, path string) error {
err := staticValidator.Static(ext, nil, f.env.Location(), f.env.Domain(), f.env.FeatureIsSet(env.FeatureRequireD2sV3Workers), path)
if err != nil {
Expand Down
78 changes: 40 additions & 38 deletions pkg/frontend/openshiftcluster_putorpatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package frontend
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
Expand Down Expand Up @@ -3312,71 +3313,72 @@ func TestValidateIdentityUrl(t *testing.T) {
identityURL string
cluster *api.OpenShiftCluster
expected *api.OpenShiftCluster
isCreate bool
wantError error
}{
{
name: "identity URL is empty, is not wi/mi cluster create",
name: "identity URL is empty",
identityURL: "",
cluster: &api.OpenShiftCluster{},
expected: &api.OpenShiftCluster{},
isCreate: false,
wantError: errMissingIdentityParameter,
},
{
name: "identity URL is empty, is wi/mi cluster create",
identityURL: "",
cluster: &api.OpenShiftCluster{},
expected: &api.OpenShiftCluster{},
isCreate: true,
wantError: errMissingIdentityURL,
},
{
name: "cluster is not wi/mi, identityURL passed",
identityURL: "http://foo.bar",
name: "pass - identity URL passed",
cluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
},
Identity: &api.Identity{},
},
identityURL: "http://foo.bar",
expected: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
Identity: &api.Identity{
IdentityURL: "http://foo.bar",
},
},
isCreate: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
err := validateIdentityUrl(tt.cluster, tt.identityURL)
if !errors.Is(err, tt.wantError) {
t.Error(cmp.Diff(err, tt.wantError))
}

if !reflect.DeepEqual(tt.cluster, tt.expected) {
t.Error(cmp.Diff(tt.cluster, tt.expected))
}
})
}
}

func TestValidateIdentityTenantID(t *testing.T) {
for _, tt := range []struct {
name string
tenantID string
cluster *api.OpenShiftCluster
expected *api.OpenShiftCluster
wantError error
}{
{
name: "cluster is not wi/mi, identityURL not passed",
identityURL: "",
cluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
},
},
expected: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
},
},
isCreate: true,
name: "tenantID is empty",
tenantID: "",
cluster: &api.OpenShiftCluster{},
expected: &api.OpenShiftCluster{},
wantError: errMissingIdentityParameter,
},
{
name: "pass - identity URL passed on wi/mi cluster",
name: "pass - tenantID passed",
cluster: &api.OpenShiftCluster{
Identity: &api.Identity{},
},
identityURL: "http://foo.bar",
tenantID: "bogus",
expected: &api.OpenShiftCluster{
Identity: &api.Identity{
IdentityURL: "http://foo.bar",
TenantID: "bogus",
},
},
isCreate: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
err := validateIdentityUrl(tt.cluster, tt.identityURL, tt.isCreate)
if err != nil && err != tt.wantError {
err := validateIdentityTenantID(tt.cluster, tt.tenantID)
if !errors.Is(err, tt.wantError) {
t.Error(cmp.Diff(err, tt.wantError))
}

Expand Down

0 comments on commit 6b91187

Please sign in to comment.