Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): enforce SA Token based auth b/w Persistence Agent and Pipeline API Server #9957

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 0 additions & 85 deletions backend/src/agent/persistence/client/fake_namespace.go

This file was deleted.

87 changes: 0 additions & 87 deletions backend/src/agent/persistence/client/kubernetes_core.go

This file was deleted.

37 changes: 0 additions & 37 deletions backend/src/agent/persistence/client/kubernetes_core_fake.go

This file was deleted.

60 changes: 30 additions & 30 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ package client
import (
"context"
"fmt"
"os"
"strings"
"time"

"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"google.golang.org/grpc/metadata"

api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client"
Expand All @@ -38,8 +36,8 @@ const (
type PipelineClientInterface interface {
ReportWorkflow(workflow util.ExecutionSpec) error
ReportScheduledWorkflow(swf *util.ScheduledWorkflow) error
ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error)
ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error)
ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error)
ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error)
}

type PipelineClient struct {
Expand Down Expand Up @@ -173,17 +171,26 @@ func (p *PipelineClient) ReportScheduledWorkflow(swf *util.ScheduledWorkflow) er

// ReadArtifact reads artifact content from run service. If the artifact is not present, returns
// nil response.
func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) {
func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) {
pctx := context.Background()
if user != "" {
pctx = metadata.AppendToOutgoingContext(pctx, getKubeflowUserIDHeader(),
getKubeflowUserIDPrefix()+user)
}
pctx = metadata.AppendToOutgoingContext(pctx, "Authorization",
"Bearer "+p.tokenRefresher.GetToken())

ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

response, err := p.runServiceClient.ReadArtifactV1(ctx, request)
if err != nil {
statusCode, _ := status.FromError(err)
if statusCode.Code() == codes.Unauthenticated && strings.Contains(err.Error(), "service account token has expired") {
// If unauthenticated because SA token is expired, re-read/refresh the token and try again
p.tokenRefresher.RefreshToken()
return nil, util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting workflow resource (code: %v, message: %v): %v",
statusCode.Code(),
statusCode.Message(),
err.Error())
}
// TODO(hongyes): check NotFound error code before skip the error.
return nil, nil
}
Expand All @@ -192,37 +199,30 @@ func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user str
}

// ReportRunMetrics reports run metrics to run service.
func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) {
func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) {
pctx := context.Background()
if user != "" {
pctx = metadata.AppendToOutgoingContext(pctx, getKubeflowUserIDHeader(),
getKubeflowUserIDPrefix()+user)
}
pctx = metadata.AppendToOutgoingContext(pctx, "Authorization",
"Bearer "+p.tokenRefresher.GetToken())

ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

response, err := p.runServiceClient.ReportRunMetricsV1(ctx, request)
if err != nil {
statusCode, _ := status.FromError(err)
if statusCode.Code() == codes.Unauthenticated && strings.Contains(err.Error(), "service account token has expired") {
// If unauthenticated because SA token is expired, re-read/refresh the token and try again
p.tokenRefresher.RefreshToken()
return nil, util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting workflow resource (code: %v, message: %v): %v",
statusCode.Code(),
statusCode.Message(),
err.Error())
}
// This call should always succeed unless the run doesn't exist or server is broken. In
// either cases, the job should retry at a later time.
return nil, util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting metrics (%+v): %+v", request, err)
}
return response, nil
}

// TODO use config file & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDHeader()"
func getKubeflowUserIDHeader() string {
if value, ok := os.LookupEnv(common.KubeflowUserIDHeader); ok {
return value
}
return common.GoogleIAPUserIdentityHeader
}

// TODO use of viper & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDPrefix()"
func getKubeflowUserIDPrefix() string {
if value, ok := os.LookupEnv(common.KubeflowUserIDPrefix); ok {
return value
}
return common.GoogleIAPUserIdentityPrefix
}
4 changes: 2 additions & 2 deletions backend/src/agent/persistence/client/pipeline_client_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ func (p *PipelineClientFake) ReportScheduledWorkflow(swf *util.ScheduledWorkflow
return nil
}

func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) {
func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) {
if p.err != nil {
return nil, p.err
}
p.readArtifactRequest = request
return p.artifacts[request.String()], nil
}

func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) {
func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) {
p.reportedMetricsRequest = request
return p.reportMetricsResponseStub, p.reportMetricsErrorStub
}
Expand Down
5 changes: 0 additions & 5 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ func main() {
} else {
swfInformerFactory = swfinformers.NewFilteredSharedInformerFactory(swfClient, time.Second*30, namespace, nil)
}
k8sCoreClient := client.CreateKubernetesCoreOrFatal(DefaultConnectionTimeout, util.ClientParameters{
QPS: clientQPS,
Burst: clientBurst,
})

tokenRefresher := client.NewTokenRefresher(time.Duration(saTokenRefreshIntervalInSecs)*time.Second, nil)
err = tokenRefresher.StartTokenRefreshTicker()
Expand All @@ -122,7 +118,6 @@ func main() {
swfInformerFactory,
execInformer,
pipelineClient,
k8sCoreClient,
util.NewRealTime())

go swfInformerFactory.Start(stopCh)
Expand Down
3 changes: 1 addition & 2 deletions backend/src/agent/persistence/persistence_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ func NewPersistenceAgent(
swfInformerFactory swfinformers.SharedInformerFactory,
execInformer util.ExecutionInformer,
pipelineClient *client.PipelineClient,
k8sCoreClient client.KubernetesCoreInterface,
time util.TimeInterface) *PersistenceAgent {
// obtain references to shared informers
swfInformer := swfInformerFactory.Scheduledworkflow().V1beta1().ScheduledWorkflows()
Expand All @@ -63,7 +62,7 @@ func NewPersistenceAgent(

workflowWorker := worker.NewPersistenceWorker(time, workflowregister.WorkflowKind,
execInformer, true,
worker.NewWorkflowSaver(workflowClient, pipelineClient, k8sCoreClient, ttlSecondsAfterWorkflowFinish))
worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish))

agent := &PersistenceAgent{
swfClient: swfClient,
Expand Down
6 changes: 3 additions & 3 deletions backend/src/agent/persistence/worker/metrics_reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewMetricsReporter(pipelineClient client.PipelineClientInterface) *MetricsR
}

// ReportMetrics reports workflow metrics to pipeline server.
func (r MetricsReporter) ReportMetrics(workflow util.ExecutionSpec, user string) error {
func (r MetricsReporter) ReportMetrics(workflow util.ExecutionSpec) error {
if !workflow.ExecutionStatus().HasMetrics() {
return nil
}
Expand All @@ -52,14 +52,14 @@ func (r MetricsReporter) ReportMetrics(workflow util.ExecutionSpec, user string)
// Skip reporting if the workflow doesn't have the run id label
return nil
}
runMetrics, partialFailures := workflow.ExecutionStatus().CollectionMetrics(r.pipelineClient.ReadArtifact, user)
runMetrics, partialFailures := workflow.ExecutionStatus().CollectionMetrics(r.pipelineClient.ReadArtifact)
if len(runMetrics) == 0 {
return aggregateErrors(partialFailures)
}
reportMetricsResponse, err := r.pipelineClient.ReportRunMetrics(&api.ReportRunMetricsRequest{
RunId: runID,
Metrics: runMetrics,
}, user)
})
if err != nil {
return err
}
Expand Down
Loading