From 2cb41efd975df770bd6ff5d56eb5b069234f9eaa Mon Sep 17 00:00:00 2001 From: Diana Atanasova Date: Thu, 16 Jun 2022 15:46:05 +0300 Subject: [PATCH] Add authorization when Persistent Agent communicates with the api-server Persistent Agent authorizes itself based on the namespace and the current user Fixes: #7818 --- .../persistence/client/fake_namespace.go | 85 +++++++++++++++++++ .../persistence/client/kubernetes_core.go | 83 ++++++++++++++++++ .../client/kubernetes_core_fake.go | 37 ++++++++ .../persistence/client/pipeline_client.go | 35 ++++++-- .../client/pipeline_client_fake.go | 7 +- backend/src/agent/persistence/main.go | 9 ++ .../agent/persistence/persistence_agent.go | 3 +- .../persistence/worker/metrics_reporter.go | 14 +-- .../worker/metrics_reporter_test.go | 80 ++++++++++++++--- .../worker/persistence_worker_test.go | 23 +++-- .../persistence/worker/workflow_saver.go | 12 ++- .../persistence/worker/workflow_saver_test.go | 53 ++++++++++-- .../persistence-agent/cluster-role.yaml | 7 ++ .../persistence-agent/deployment-patch.yaml | 4 + ...-pipeline-persistenceagent-deployment.yaml | 4 + .../ml-pipeline-persistenceagent-role.yaml | 8 +- 16 files changed, 416 insertions(+), 48 deletions(-) create mode 100644 backend/src/agent/persistence/client/fake_namespace.go create mode 100644 backend/src/agent/persistence/client/kubernetes_core.go create mode 100644 backend/src/agent/persistence/client/kubernetes_core_fake.go diff --git a/backend/src/agent/persistence/client/fake_namespace.go b/backend/src/agent/persistence/client/fake_namespace.go new file mode 100644 index 000000000000..bbc8c8e0224a --- /dev/null +++ b/backend/src/agent/persistence/client/fake_namespace.go @@ -0,0 +1,85 @@ +package client + +import ( + "context" + "errors" + "github.com/golang/glog" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" + corev1 
"k8s.io/client-go/applyconfigurations/core/v1" +) + +type FakeNamespaceClient struct { + namespace string + user string +} + +func (f *FakeNamespaceClient) SetReturnValues(namespace string, user string) { + f.namespace = namespace + f.user = user +} + +func (f FakeNamespaceClient) Get(ctx context.Context, name string, opts metav1.GetOptions) (*v1.Namespace, error) { + if f.namespace == name && len(f.user) != 0 { + ns := v1.Namespace{ObjectMeta: metav1.ObjectMeta{ + Namespace: f.namespace, + Annotations: map[string]string{ + "owner": f.user, + }, + }} + return &ns, nil + } + return nil, errors.New("failed to get namespace") +} + +func (f FakeNamespaceClient) Create(ctx context.Context, namespace *v1.Namespace, opts metav1.CreateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Update(ctx context.Context, namespace *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) UpdateStatus(ctx context.Context, namespace *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Delete(ctx context.Context, name string, opts metav1.DeleteOptions) error { + glog.Error("This fake method is not yet implemented.") + return nil +} + +func (f FakeNamespaceClient) List(ctx context.Context, opts metav1.ListOptions) (*v1.NamespaceList, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Patch(ctx context.Context, name string, pt types.PatchType, data []byte, opts metav1.PatchOptions, subresources ...string) 
(result *v1.Namespace, err error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Apply(ctx context.Context, namespace *corev1.NamespaceApplyConfiguration, opts metav1.ApplyOptions) (result *v1.Namespace, err error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) ApplyStatus(ctx context.Context, namespace *corev1.NamespaceApplyConfiguration, opts metav1.ApplyOptions) (result *v1.Namespace, err error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} + +func (f FakeNamespaceClient) Finalize(ctx context.Context, item *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) { + glog.Error("This fake method is not yet implemented.") + return nil, nil +} diff --git a/backend/src/agent/persistence/client/kubernetes_core.go b/backend/src/agent/persistence/client/kubernetes_core.go new file mode 100644 index 000000000000..da44c9449082 --- /dev/null +++ b/backend/src/agent/persistence/client/kubernetes_core.go @@ -0,0 +1,83 @@ +package client + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff" + "github.com/golang/glog" + "github.com/pkg/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + v1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/client-go/rest" + + "github.com/kubeflow/pipelines/backend/src/common/util" +) + +type KubernetesCoreInterface interface { + NamespaceClient() v1.NamespaceInterface + GetNamespaceOwner(namespace string) (string, error) +} + +type KubernetesCore struct { + coreV1Client v1.CoreV1Interface +} + +func (c *KubernetesCore) NamespaceClient() v1.NamespaceInterface { + return c.coreV1Client.Namespaces() +} + +func (c *KubernetesCore) GetNamespaceOwner(namespace string) (string, error) { + ns, err := c.NamespaceClient().Get(context.Background(), namespace, metav1.GetOptions{}) + if err != nil { + return "", err + } + 
owner, ok := ns.Annotations["owner"] + if !ok { + return "", errors.New(fmt.Sprintf("namespace '%v' has no owner in the annotations", namespace)) + } + return owner, nil +} + +func createKubernetesCore(clientParams util.ClientParameters) (KubernetesCoreInterface, error) { + clientSet, err := getKubernetesClientset(clientParams) + if err != nil { + return nil, err + } + return &KubernetesCore{clientSet.CoreV1()}, nil +} + +// CreateKubernetesCoreOrFatal creates a new client for the Kubernetes pod. +func CreateKubernetesCoreOrFatal(initConnectionTimeout time.Duration, clientParams util.ClientParameters) KubernetesCoreInterface { + var client KubernetesCoreInterface + var err error + var operation = func() error { + client, err = createKubernetesCore(clientParams) + return err + } + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = initConnectionTimeout + err = backoff.Retry(operation, b) + + if err != nil { + glog.Fatalf("Failed to create namespace client. Error: %v", err) + } + return client +} + +func getKubernetesClientset(clientParams util.ClientParameters) (*kubernetes.Clientset, error) { + restConfig, err := rest.InClusterConfig() + if err != nil { + return nil, errors.Wrap(err, "Failed to initialize kubernetes client.") + } + restConfig.QPS = float32(clientParams.QPS) + restConfig.Burst = clientParams.Burst + + clientSet, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return nil, errors.Wrap(err, "Failed to initialize kubernetes client set.") + } + return clientSet, nil +} diff --git a/backend/src/agent/persistence/client/kubernetes_core_fake.go b/backend/src/agent/persistence/client/kubernetes_core_fake.go new file mode 100644 index 000000000000..73fa0e34fef5 --- /dev/null +++ b/backend/src/agent/persistence/client/kubernetes_core_fake.go @@ -0,0 +1,37 @@ +package client + +import ( + "context" + "errors" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +type 
KubernetesCoreFake struct { + coreV1ClientFake *FakeNamespaceClient +} + +func (c *KubernetesCoreFake) NamespaceClient() v1.NamespaceInterface { + return c.coreV1ClientFake +} + +func (c *KubernetesCoreFake) GetNamespaceOwner(namespace string) (string, error) { + ns, err := c.NamespaceClient().Get(context.Background(), namespace, metav1.GetOptions{}) + if err != nil { + return "", err + } + owner, ok := ns.Annotations["owner"] + if !ok { + return "", errors.New(fmt.Sprintf("namespace '%v' has no owner in the annotations", namespace)) + } + return owner, nil +} + +func NewKubernetesCoreFake() *KubernetesCoreFake { + return &KubernetesCoreFake{&FakeNamespaceClient{}} +} +func (c *KubernetesCoreFake) Set(namespaceToReturn string, userToReturn string) { + c.coreV1ClientFake.SetReturnValues(namespaceToReturn, userToReturn) +} diff --git a/backend/src/agent/persistence/client/pipeline_client.go b/backend/src/agent/persistence/client/pipeline_client.go index 3884bcd013e5..39b9e2507a8a 100644 --- a/backend/src/agent/persistence/client/pipeline_client.go +++ b/backend/src/agent/persistence/client/pipeline_client.go @@ -17,6 +17,9 @@ package client import ( "context" "fmt" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "google.golang.org/grpc/metadata" + "os" "time" api "github.com/kubeflow/pipelines/backend/api/go_client" @@ -33,8 +36,8 @@ const ( type PipelineClientInterface interface { ReportWorkflow(workflow *util.Workflow) error ReportScheduledWorkflow(swf *util.ScheduledWorkflow) error - ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) - ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) + ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) + ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) } type PipelineClient struct { @@ -139,8 +142,10 @@ func (p *PipelineClient) 
ReportScheduledWorkflow(swf *util.ScheduledWorkflow) er // ReadArtifact reads artifact content from run service. If the artifact is not present, returns // nil response. -func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) +func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) { + pctx := metadata.AppendToOutgoingContext(context.Background(), getKubeflowUserIDHeader(), + getKubeflowUserIDPrefix()+user) + ctx, cancel := context.WithTimeout(pctx, time.Minute) defer cancel() response, err := p.runServiceClient.ReadArtifact(ctx, request) @@ -153,8 +158,10 @@ func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.Re } // ReportRunMetrics reports run metrics to run service. -func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) +func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) { + pctx := metadata.AppendToOutgoingContext(context.Background(), getKubeflowUserIDHeader(), + getKubeflowUserIDPrefix()+user) + ctx, cancel := context.WithTimeout(pctx, time.Minute) defer cancel() response, err := p.runServiceClient.ReportRunMetrics(ctx, request) @@ -166,3 +173,19 @@ func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) } return response, nil } + +//TODO use config file & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDHeader()" +func getKubeflowUserIDHeader() string { + if value, ok := os.LookupEnv(common.KubeflowUserIDHeader); ok { + return value + } + return common.GoogleIAPUserIdentityHeader +} + +//TODO use of viper & viper and 
"github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDPrefix()" +func getKubeflowUserIDPrefix() string { + if value, ok := os.LookupEnv(common.KubeflowUserIDPrefix); ok { + return value + } + return common.GoogleIAPUserIdentityPrefix +} diff --git a/backend/src/agent/persistence/client/pipeline_client_fake.go b/backend/src/agent/persistence/client/pipeline_client_fake.go index 4215478948b9..f87f43353d38 100644 --- a/backend/src/agent/persistence/client/pipeline_client_fake.go +++ b/backend/src/agent/persistence/client/pipeline_client_fake.go @@ -57,12 +57,15 @@ func (p *PipelineClientFake) ReportScheduledWorkflow(swf *util.ScheduledWorkflow return nil } -func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) { +func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) { + if p.err != nil { + return nil, p.err + } p.readArtifactRequest = request return p.artifacts[request.String()], nil } -func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) { +func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) { p.reportedMetricsRequest = request return p.reportMetricsResponseStub, p.reportMetricsErrorStub } diff --git a/backend/src/agent/persistence/main.go b/backend/src/agent/persistence/main.go index 3e88065bf728..be6f96f6f380 100644 --- a/backend/src/agent/persistence/main.go +++ b/backend/src/agent/persistence/main.go @@ -63,6 +63,10 @@ const ( clientBurstFlagName = "clientBurst" ) +const ( + DefaultConnectionTimeout = 6 * time.Minute +) + func main() { flag.Parse() @@ -95,6 +99,10 @@ func main() { swfInformerFactory = swfinformers.NewFilteredSharedInformerFactory(swfClient, time.Second*30, namespace, nil) workflowInformerFactory = 
workflowinformers.NewFilteredSharedInformerFactory(workflowClient, time.Second*30, namespace, nil) } + k8sCoreClient := client.CreateKubernetesCoreOrFatal(DefaultConnectionTimeout, util.ClientParameters{ + QPS: clientQPS, + Burst: clientBurst, + }) pipelineClient, err := client.NewPipelineClient( initializeTimeout, @@ -111,6 +119,7 @@ func main() { swfInformerFactory, workflowInformerFactory, pipelineClient, + k8sCoreClient, util.NewRealTime()) go swfInformerFactory.Start(stopCh) diff --git a/backend/src/agent/persistence/persistence_agent.go b/backend/src/agent/persistence/persistence_agent.go index fdf0e602e24c..14332f43202f 100644 --- a/backend/src/agent/persistence/persistence_agent.go +++ b/backend/src/agent/persistence/persistence_agent.go @@ -47,6 +47,7 @@ func NewPersistenceAgent( swfInformerFactory swfinformers.SharedInformerFactory, workflowInformerFactory workflowinformers.SharedInformerFactory, pipelineClient *client.PipelineClient, + k8sCoreClient client.KubernetesCoreInterface, time util.TimeInterface) *PersistenceAgent { // obtain references to shared informers swfInformer := swfInformerFactory.Scheduledworkflow().V1beta1().ScheduledWorkflows() @@ -64,7 +65,7 @@ func NewPersistenceAgent( workflowWorker := worker.NewPersistenceWorker(time, workflowregister.WorkflowKind, workflowInformer.Informer(), true, - worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish)) + worker.NewWorkflowSaver(workflowClient, pipelineClient, k8sCoreClient, ttlSecondsAfterWorkflowFinish)) agent := &PersistenceAgent{ swfClient: swfClient, diff --git a/backend/src/agent/persistence/worker/metrics_reporter.go b/backend/src/agent/persistence/worker/metrics_reporter.go index 619ba7b91180..d689dd33dcbc 100644 --- a/backend/src/agent/persistence/worker/metrics_reporter.go +++ b/backend/src/agent/persistence/worker/metrics_reporter.go @@ -45,7 +45,7 @@ func NewMetricsReporter(pipelineClient client.PipelineClientInterface) *MetricsR } // ReportMetrics 
reports workflow metrics to pipeline server. -func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { +func (r MetricsReporter) ReportMetrics(workflow *util.Workflow, user string) error { if workflow.Status.Nodes == nil { return nil } @@ -57,7 +57,7 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { runMetrics := []*api.RunMetric{} partialFailures := []error{} for _, nodeStatus := range workflow.Status.Nodes { - nodeMetrics, err := r.collectNodeMetricsOrNil(runID, nodeStatus) + nodeMetrics, err := r.collectNodeMetricsOrNil(runID, nodeStatus, user) if err != nil { partialFailures = append(partialFailures, err) continue @@ -79,7 +79,7 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { reportMetricsResponse, err := r.pipelineClient.ReportRunMetrics(&api.ReportRunMetricsRequest{ RunId: runID, Metrics: runMetrics, - }) + }, user) if err != nil { return err } @@ -89,12 +89,12 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error { } func (r MetricsReporter) collectNodeMetricsOrNil( - runID string, nodeStatus workflowapi.NodeStatus) ( + runID string, nodeStatus workflowapi.NodeStatus, user string) ( []*api.RunMetric, error) { if !nodeStatus.Completed() { return nil, nil } - metricsJSON, err := r.readNodeMetricsJSONOrEmpty(runID, nodeStatus) + metricsJSON, err := r.readNodeMetricsJSONOrEmpty(runID, nodeStatus, user) if err != nil || metricsJSON == "" { return nil, err } @@ -126,7 +126,7 @@ func (r MetricsReporter) collectNodeMetricsOrNil( return reportMetricsRequest.GetMetrics(), nil } -func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus workflowapi.NodeStatus) (string, error) { +func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus workflowapi.NodeStatus, user string) (string, error) { if nodeStatus.Outputs == nil || nodeStatus.Outputs.Artifacts == nil { return "", nil // No output artifacts, skip the reporting } @@ -146,7 +146,7 @@ func (r 
MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus wor NodeId: nodeStatus.ID, ArtifactName: metricsArtifactName, } - artifactResponse, err := r.pipelineClient.ReadArtifact(artifactRequest) + artifactResponse, err := r.pipelineClient.ReadArtifact(artifactRequest, user) if err != nil { return "", err } diff --git a/backend/src/agent/persistence/worker/metrics_reporter_test.go b/backend/src/agent/persistence/worker/metrics_reporter_test.go index 35a0db5b9f4c..c1e117c8ec78 100644 --- a/backend/src/agent/persistence/worker/metrics_reporter_test.go +++ b/backend/src/agent/persistence/worker/metrics_reporter_test.go @@ -16,6 +16,7 @@ package worker import ( "encoding/json" + "errors" "fmt" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -31,6 +32,11 @@ import ( "k8s.io/apimachinery/pkg/types" ) +const ( + NamespaceName = "kf-namespace" + USER = "test-user@example.com" +) + func TestReportMetrics_NoCompletedNode_NoOP(t *testing.T) { pipelineFake := client.NewPipelineClientFake() @@ -51,7 +57,7 @@ func TestReportMetrics_NoCompletedNode_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) } @@ -76,7 +82,7 @@ func TestReportMetrics_NoRunID_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReadArtifactRequest()) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) @@ -103,7 +109,7 @@ func TestReportMetrics_NoArtifact_NoOP(t *testing.T) { }, }, }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReadArtifactRequest()) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) @@ -133,7 +139,7 @@ func TestReportMetrics_NoMetricsArtifact_NoOP(t *testing.T) { }, }, }) - err := 
reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.Nil(t, err) assert.Nil(t, pipelineFake.GetReadArtifactRequest()) assert.Nil(t, pipelineFake.GetReportedMetricsRequest()) @@ -176,9 +182,9 @@ func TestReportMetrics_Succeed(t *testing.T) { Results: []*api.ReportRunMetricsResponse_ReportRunMetricResult{}, }, nil) - err := reporter.ReportMetrics(workflow) + err1 := reporter.ReportMetrics(workflow, USER) - assert.Nil(t, err) + assert.Nil(t, err1) expectedMetricsRequest := &api.ReportRunMetricsRequest{ RunId: "run-1", Metrics: []*api.RunMetric{ @@ -197,7 +203,7 @@ func TestReportMetrics_Succeed(t *testing.T) { got := pipelineFake.GetReportedMetricsRequest() if diff := cmp.Diff(expectedMetricsRequest, got, cmpopts.EquateEmpty(), protocmp.Transform()); diff != "" { t.Errorf("parseRuntimeInfo() = %+v, want %+v\nDiff (-want, +got)\n%s", got, expectedMetricsRequest, diff) - s, _ := json.MarshalIndent(expectedMetricsRequest ,"", " ") + s, _ := json.MarshalIndent(expectedMetricsRequest, "", " ") fmt.Printf("Want %s", s) } } @@ -235,7 +241,7 @@ func TestReportMetrics_EmptyArchive_Fail(t *testing.T) { Data: []byte(artifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -278,7 +284,7 @@ func TestReportMetrics_MultipleFilesInArchive_Fail(t *testing.T) { Data: []byte(artifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -320,7 +326,7 @@ func TestReportMetrics_InvalidMetricsJSON_Fail(t *testing.T) { Data: []byte(artifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -381,7 +387,7 @@ func 
TestReportMetrics_InvalidMetricsJSON_PartialFail(t *testing.T) { Data: []byte(validArtifactData), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) // Partial failure is reported while valid metrics are reported. assert.NotNil(t, err) @@ -404,7 +410,7 @@ func TestReportMetrics_InvalidMetricsJSON_PartialFail(t *testing.T) { got := pipelineFake.GetReportedMetricsRequest() if diff := cmp.Diff(expectedMetricsRequest, got, cmpopts.EquateEmpty(), protocmp.Transform()); diff != "" { t.Errorf("parseRuntimeInfo() = %+v, want %+v\nDiff (-want, +got)\n%s", got, expectedMetricsRequest, diff) - s, _ := json.MarshalIndent(expectedMetricsRequest ,"", " ") + s, _ := json.MarshalIndent(expectedMetricsRequest, "", " ") fmt.Printf("Want %s", s) } } @@ -441,7 +447,7 @@ func TestReportMetrics_CorruptedArchiveFile_Fail(t *testing.T) { Data: []byte("invalid tgz content"), }) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_PERMANENT)) @@ -505,8 +511,54 @@ func TestReportMetrics_MultiplMetricErrors_TransientErrowWin(t *testing.T) { }, }, nil) - err := reporter.ReportMetrics(workflow) + err := reporter.ReportMetrics(workflow, USER) assert.NotNil(t, err) assert.True(t, util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT)) } + +func TestReportMetrics_Unauthorized(t *testing.T) { + pipelineFake := client.NewPipelineClientFake() + reporter := NewMetricsReporter(pipelineFake) + k8sFake := client.NewKubernetesCoreFake() + k8sFake.Set(NamespaceName, USER) + + workflow := util.NewWorkflow(&workflowapi.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "MY_NAMESPACE", + Name: "MY_NAME", + UID: types.UID("run-1"), + Labels: map[string]string{util.LabelKeyWorkflowRunId: "run-1"}, + }, + Status: workflowapi.WorkflowStatus{ + Nodes: map[string]workflowapi.NodeStatus{ + "node-1": workflowapi.NodeStatus{ + ID: "node-1", + Phase: 
workflowapi.NodeSucceeded, + Outputs: &workflowapi.Outputs{ + Artifacts: []workflowapi.Artifact{{Name: "mlpipeline-metrics"}}, + }, + }, + }, + }, + }) + metricsJSON := `{"metrics": [{"name": "accuracy", "numberValue": 0.77}, {"name": "logloss", "numberValue": 1.2}]}` + artifactData, _ := util.ArchiveTgz(map[string]string{"file": metricsJSON}) + pipelineFake.StubArtifact( + &api.ReadArtifactRequest{ + RunId: "run-1", + NodeId: "node-1", + ArtifactName: "mlpipeline-metrics", + }, + &api.ReadArtifactResponse{ + Data: []byte(artifactData), + }) + pipelineFake.StubReportRunMetrics(&api.ReportRunMetricsResponse{ + Results: []*api.ReportRunMetricsResponse_ReportRunMetricResult{}, + }, errors.New("failed to read artifacts")) + + err1 := reporter.ReportMetrics(workflow, USER) + + assert.NotNil(t, err1) + assert.Contains(t, err1.Error(), "failed to read artifacts") +} diff --git a/backend/src/agent/persistence/worker/persistence_worker_test.go b/backend/src/agent/persistence/worker/persistence_worker_test.go index bde3ef7e4e66..e29226d1407e 100644 --- a/backend/src/agent/persistence/worker/persistence_worker_test.go +++ b/backend/src/agent/persistence/worker/persistence_worker_test.go @@ -53,9 +53,11 @@ func TestPersistenceWorker_Success(t *testing.T) { // Set up pipeline client pipelineClient := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -81,11 +83,12 @@ func TestPersistenceWorker_NotFoundError(t *testing.T) { }) workflowClient := client.NewWorkflowClientFake() - // Set up pipeline client + // Set up pipeline client and kubernetes client pipelineClient := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Set 
up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -112,11 +115,12 @@ func TestPersistenceWorker_GetWorklowError(t *testing.T) { workflowClient := client.NewWorkflowClientFake() workflowClient.Put("MY_NAMESPACE", "MY_NAME", nil) - // Set up pipeline client + // Set up pipeline client and kubernetes client pipelineClient := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -148,9 +152,12 @@ func TestPersistenceWorker_ReportWorkflowRetryableError(t *testing.T) { pipelineClient := client.NewPipelineClientFake() pipelineClient.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_TRANSIENT, "My Retriable Error")) + //Set up kubernetes client + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), @@ -181,9 +188,11 @@ func TestPersistenceWorker_ReportWorkflowNonRetryableError(t *testing.T) { pipelineClient := client.NewPipelineClientFake() pipelineClient.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, "My Permanent Error")) + // Set up kubernetes client + k8sClient := client.NewKubernetesCoreFake() // Set up peristence worker - saver := NewWorkflowSaver(workflowClient, pipelineClient, 100) + saver := NewWorkflowSaver(workflowClient, 
pipelineClient, k8sClient, 100) eventHandler := NewFakeEventHandler() worker := NewPersistenceWorker( util.NewFakeTimeForEpoch(), diff --git a/backend/src/agent/persistence/worker/workflow_saver.go b/backend/src/agent/persistence/worker/workflow_saver.go index 9635d020194d..b79fbeae0921 100644 --- a/backend/src/agent/persistence/worker/workflow_saver.go +++ b/backend/src/agent/persistence/worker/workflow_saver.go @@ -26,15 +26,17 @@ import ( type WorkflowSaver struct { client client.WorkflowClientInterface pipelineClient client.PipelineClientInterface + k8sClient client.KubernetesCoreInterface metricsReporter *MetricsReporter ttlSecondsAfterWorkflowFinish int64 } func NewWorkflowSaver(client client.WorkflowClientInterface, - pipelineClient client.PipelineClientInterface, ttlSecondsAfterWorkflowFinish int64) *WorkflowSaver { + pipelineClient client.PipelineClientInterface, k8sClient client.KubernetesCoreInterface, ttlSecondsAfterWorkflowFinish int64) *WorkflowSaver { return &WorkflowSaver{ client: client, pipelineClient: pipelineClient, + k8sClient: k8sClient, metricsReporter: NewMetricsReporter(pipelineClient), ttlSecondsAfterWorkflowFinish: ttlSecondsAfterWorkflowFinish, } @@ -66,6 +68,12 @@ func (s *WorkflowSaver) Save(key string, namespace string, name string, nowEpoch log.Infof("Skip syncing Workflow (%v): workflow marked as persisted.", name) return nil } + + user, err1 := s.k8sClient.GetNamespaceOwner(namespace) + if err1 != nil { + return util.Wrapf(err1, "Failed get '%v' namespace", namespace) + } + // Save this Workflow to the database. 
err = s.pipelineClient.ReportWorkflow(wf) retry := util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT) @@ -85,5 +93,5 @@ func (s *WorkflowSaver) Save(key string, namespace string, name string, nowEpoch log.WithFields(log.Fields{ "Workflow": name, }).Infof("Syncing Workflow (%v): success, processing complete.", name) - return s.metricsReporter.ReportMetrics(wf) + return s.metricsReporter.ReportMetrics(wf, user) } diff --git a/backend/src/agent/persistence/worker/workflow_saver_test.go b/backend/src/agent/persistence/worker/workflow_saver_test.go index 358f36600c54..10a16b7ccdad 100644 --- a/backend/src/agent/persistence/worker/workflow_saver_test.go +++ b/backend/src/agent/persistence/worker/workflow_saver_test.go @@ -30,6 +30,8 @@ import ( func TestWorkflow_Save_Success(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) workflow := util.NewWorkflow(&workflowapi.Workflow{ ObjectMeta: metav1.ObjectMeta{ @@ -41,7 +43,7 @@ func TestWorkflow_Save_Success(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -52,8 +54,10 @@ func TestWorkflow_Save_Success(t *testing.T) { func TestWorkflow_Save_NotFoundDuringGet(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -65,10 +69,12 @@ func TestWorkflow_Save_NotFoundDuringGet(t *testing.T) { func TestWorkflow_Save_ErrorDuringGet(t *testing.T) { 
workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) workflowFake.Put("MY_NAMESPACE", "MY_NAME", nil) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -80,6 +86,8 @@ func TestWorkflow_Save_ErrorDuringGet(t *testing.T) { func TestWorkflow_Save_PermanentFailureWhileReporting(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, "My Permanent Error")) @@ -94,7 +102,7 @@ func TestWorkflow_Save_PermanentFailureWhileReporting(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -106,6 +114,8 @@ func TestWorkflow_Save_PermanentFailureWhileReporting(t *testing.T) { func TestWorkflow_Save_TransientFailureWhileReporting(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_TRANSIENT, "My Transient Error")) @@ -120,7 +130,7 @@ func TestWorkflow_Save_TransientFailureWhileReporting(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -132,6 +142,7 @@ 
func TestWorkflow_Save_TransientFailureWhileReporting(t *testing.T) { func TestWorkflow_Save_SkippedDueToFinalStatue(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Add this will result in failure unless reporting is skipped pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, @@ -150,7 +161,7 @@ func TestWorkflow_Save_SkippedDueToFinalStatue(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) @@ -161,6 +172,8 @@ func TestWorkflow_Save_SkippedDueToFinalStatue(t *testing.T) { func TestWorkflow_Save_FinalStatueNotSkippedDueToExceedTTL(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("MY_NAMESPACE", USER) // Add this will result in failure unless reporting is skipped pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, @@ -182,7 +195,7 @@ func TestWorkflow_Save_FinalStatueNotSkippedDueToExceedTTL(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 1) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 1) // Sleep 2 seconds to make sure workflow passed TTL time.Sleep(2 * time.Second) @@ -197,6 +210,7 @@ func TestWorkflow_Save_FinalStatueNotSkippedDueToExceedTTL(t *testing.T) { func TestWorkflow_Save_SkippedDDueToMissingRunID(t *testing.T) { workflowFake := client.NewWorkflowClientFake() pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() // Add this will result in failure unless reporting is skipped 
pipelineFake.SetError(util.NewCustomError(fmt.Errorf("Error"), util.CUSTOM_CODE_PERMANENT, @@ -211,10 +225,33 @@ func TestWorkflow_Save_SkippedDDueToMissingRunID(t *testing.T) { workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) - saver := NewWorkflowSaver(workflowFake, pipelineFake, 100) + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) assert.Equal(t, false, util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT)) assert.Equal(t, nil, err) } + +func TestWorkflow_Save_FailedToGetUser(t *testing.T) { + workflowFake := client.NewWorkflowClientFake() + pipelineFake := client.NewPipelineClientFake() + k8sClient := client.NewKubernetesCoreFake() + k8sClient.Set("ORIGINAL_NAMESPACE", USER) + + workflow := util.NewWorkflow(&workflowapi.Workflow{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "MY_NAMESPACE", + Name: "MY_NAME", + Labels: map[string]string{util.LabelKeyWorkflowRunId: "MY_UUID"}, + }, + }) + + workflowFake.Put("MY_NAMESPACE", "MY_NAME", workflow) + + saver := NewWorkflowSaver(workflowFake, pipelineFake, k8sClient, 100) + + err := saver.Save("MY_KEY", "MY_NAMESPACE", "MY_NAME", 20) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("Failed get '%v' namespace", "MY_NAMESPACE")) +} diff --git a/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml b/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml index b3053317b536..ff580968eed4 100644 --- a/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml +++ b/manifests/kustomize/base/installs/multi-user/persistence-agent/cluster-role.yaml @@ -11,6 +11,13 @@ rules: - get - list - watch + - patch +- apiGroups: + - '' + resources: + - namespaces + verbs: + - get - apiGroups: - kubeflow.org resources: diff --git a/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml 
b/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml index 1e165def422e..2e32f26646cc 100644 --- a/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml +++ b/manifests/kustomize/base/installs/multi-user/persistence-agent/deployment-patch.yaml @@ -11,3 +11,7 @@ spec: - name: NAMESPACE value: '' valueFrom: null + - name: KUBEFLOW_USERID_HEADER + value: kubeflow-userid + - name: KUBEFLOW_USERID_PREFIX + value: "" \ No newline at end of file diff --git a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml index bc5032e51a85..74c19c9d793e 100644 --- a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml +++ b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-deployment.yaml @@ -25,6 +25,10 @@ spec: value: "86400" - name: NUM_WORKERS value: "2" + - name: KUBEFLOW_USERID_HEADER + value: kubeflow-userid + - name: KUBEFLOW_USERID_PREFIX + value: "" image: gcr.io/ml-pipeline/persistenceagent:dummy imagePullPolicy: IfNotPresent name: ml-pipeline-persistenceagent diff --git a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml index 830ee8b14e7e..39c6a8026bab 100644 --- a/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml +++ b/manifests/kustomize/base/pipeline/ml-pipeline-persistenceagent-role.yaml @@ -18,4 +18,10 @@ rules: verbs: - get - list - - watch \ No newline at end of file + - watch +- apiGroups: + - '' + resources: + - namespaces + verbs: + - get \ No newline at end of file