Skip to content

Commit

Permalink
Authorize Persistent Agent
Browse files Browse the repository at this point in the history
The Persistence Agent authorizes itself based on the namespace and the current user.

Fixes: #7818
  • Loading branch information
difince committed Jun 16, 2022
1 parent abbb2ab commit f6f7944
Show file tree
Hide file tree
Showing 16 changed files with 416 additions and 48 deletions.
85 changes: 85 additions & 0 deletions backend/src/agent/persistence/client/fake_namespace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package client

import (
"context"
"errors"
"github.com/golang/glog"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
types "k8s.io/apimachinery/pkg/types"
watch "k8s.io/apimachinery/pkg/watch"
corev1 "k8s.io/client-go/applyconfigurations/core/v1"
)

// FakeNamespaceClient is a test double for the namespace client. Only Get has
// real behavior; it answers for the single namespace/user pair configured
// through SetReturnValues.
type FakeNamespaceClient struct {
	// namespace is the only namespace name that Get resolves.
	// user is the value Get reports in the "owner" annotation.
	namespace, user string
}

// SetReturnValues configures the namespace name that Get recognizes and the
// user it reports as that namespace's owner.
func (f *FakeNamespaceClient) SetReturnValues(namespace string, user string) {
	f.namespace, f.user = namespace, user
}

// Get returns a fake Namespace for the configured namespace/user pair.
//
// When name equals the namespace set via SetReturnValues and a non-empty user
// was configured, the result carries that user in the "owner" annotation.
// Any other lookup fails with a "failed to get namespace" error.
func (f FakeNamespaceClient) Get(ctx context.Context, name string, opts metav1.GetOptions) (*v1.Namespace, error) {
	if f.namespace != name || len(f.user) == 0 {
		return nil, errors.New("failed to get namespace")
	}
	return &v1.Namespace{
		ObjectMeta: metav1.ObjectMeta{
			// Namespaces are cluster-scoped and identified by metadata.name,
			// so populate it to make the fake resemble a real lookup result.
			Name: f.namespace,
			// Kept so existing readers of this field see the same value.
			Namespace: f.namespace,
			Annotations: map[string]string{
				"owner": f.user,
			},
		},
	}, nil
}

// Create is an unimplemented placeholder: it logs an error through glog and
// returns a nil Namespace together with a nil error.
func (f FakeNamespaceClient) Create(ctx context.Context, namespace *v1.Namespace, opts metav1.CreateOptions) (*v1.Namespace, error) {
	var created *v1.Namespace // stays nil; nothing is created
	glog.Error("This fake method is not yet implemented.")
	return created, nil
}

// Update is an unimplemented placeholder: it logs an error through glog and
// returns a nil Namespace together with a nil error.
func (f FakeNamespaceClient) Update(ctx context.Context, namespace *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) {
	var updated *v1.Namespace // stays nil; nothing is updated
	glog.Error("This fake method is not yet implemented.")
	return updated, nil
}

// UpdateStatus is an unimplemented placeholder: it logs an error through glog
// and returns a nil Namespace together with a nil error.
func (f FakeNamespaceClient) UpdateStatus(ctx context.Context, namespace *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) {
	var updated *v1.Namespace // stays nil; no status is written
	glog.Error("This fake method is not yet implemented.")
	return updated, nil
}

// Delete is an unimplemented placeholder: it logs an error through glog and
// reports success (nil error) without deleting anything.
func (f FakeNamespaceClient) Delete(ctx context.Context, name string, opts metav1.DeleteOptions) error {
	var noErr error // nil interface value
	glog.Error("This fake method is not yet implemented.")
	return noErr
}

// List is an unimplemented placeholder: it logs an error through glog and
// returns a nil NamespaceList together with a nil error.
func (f FakeNamespaceClient) List(ctx context.Context, opts metav1.ListOptions) (*v1.NamespaceList, error) {
	var listed *v1.NamespaceList // stays nil; nothing is listed
	glog.Error("This fake method is not yet implemented.")
	return listed, nil
}

// Watch is an unimplemented placeholder: it logs an error through glog and
// returns a nil watch.Interface together with a nil error.
func (f FakeNamespaceClient) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) {
	var watcher watch.Interface // nil interface value; no watch is started
	glog.Error("This fake method is not yet implemented.")
	return watcher, nil
}

// Patch is an unimplemented placeholder: it logs an error through glog and
// returns its zero-valued results (nil Namespace, nil error).
func (f FakeNamespaceClient) Patch(ctx context.Context, name string, pt types.PatchType, data []byte, opts metav1.PatchOptions, subresources ...string) (result *v1.Namespace, err error) {
	glog.Error("This fake method is not yet implemented.")
	// Both named results are still at their zero values here.
	return result, err
}

// Apply is an unimplemented placeholder: it logs an error through glog and
// returns its zero-valued results (nil Namespace, nil error).
func (f FakeNamespaceClient) Apply(ctx context.Context, namespace *corev1.NamespaceApplyConfiguration, opts metav1.ApplyOptions) (result *v1.Namespace, err error) {
	glog.Error("This fake method is not yet implemented.")
	// Both named results are still at their zero values here.
	return result, err
}

// ApplyStatus is an unimplemented placeholder: it logs an error through glog
// and returns its zero-valued results (nil Namespace, nil error).
func (f FakeNamespaceClient) ApplyStatus(ctx context.Context, namespace *corev1.NamespaceApplyConfiguration, opts metav1.ApplyOptions) (result *v1.Namespace, err error) {
	glog.Error("This fake method is not yet implemented.")
	// Both named results are still at their zero values here.
	return result, err
}

// Finalize is an unimplemented placeholder: it logs an error through glog and
// returns a nil Namespace together with a nil error.
func (f FakeNamespaceClient) Finalize(ctx context.Context, item *v1.Namespace, opts metav1.UpdateOptions) (*v1.Namespace, error) {
	var finalized *v1.Namespace // stays nil; nothing is finalized
	glog.Error("This fake method is not yet implemented.")
	return finalized, nil
}
83 changes: 83 additions & 0 deletions backend/src/agent/persistence/client/kubernetes_core.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package client

import (
"context"
"fmt"
"time"

"github.com/cenkalti/backoff"
"github.com/golang/glog"
"github.com/pkg/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
v1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/rest"

"github.com/kubeflow/pipelines/backend/src/common/util"
)

// KubernetesCoreInterface exposes the namespace operations used by the
// persistence agent.
type KubernetesCoreInterface interface {
	// GetNamespaceOwner returns the value of the "owner" annotation of the
	// given namespace, or an error if it cannot be determined.
	GetNamespaceOwner(namespace string) (string, error)

	// NamespaceClient returns the client used to access Namespace objects.
	NamespaceClient() v1.NamespaceInterface
}

// KubernetesCore implements KubernetesCoreInterface on top of a CoreV1 client.
type KubernetesCore struct {
	// coreV1Client supplies the namespace client returned by NamespaceClient.
	coreV1Client v1.CoreV1Interface
}

// NamespaceClient returns the namespace client of the wrapped CoreV1 client.
func (c *KubernetesCore) NamespaceClient() v1.NamespaceInterface {
	core := c.coreV1Client
	return core.Namespaces()
}

// GetNamespaceOwner fetches the named namespace and returns the value of its
// "owner" annotation. It returns the lookup error unchanged if the namespace
// cannot be fetched, and a descriptive error if the annotation is absent.
func (c *KubernetesCore) GetNamespaceOwner(namespace string) (string, error) {
	nsObj, getErr := c.NamespaceClient().Get(context.Background(), namespace, metav1.GetOptions{})
	if getErr != nil {
		return "", getErr
	}
	if owner, found := nsObj.Annotations["owner"]; found {
		return owner, nil
	}
	msg := fmt.Sprintf("namespace '%v' has no owner in the annotations", namespace)
	return "", errors.New(msg)
}

// createKubernetesCore builds a KubernetesCore backed by the CoreV1 client of
// a clientset created from clientParams. It returns the clientset creation
// error unchanged on failure.
func createKubernetesCore(clientParams util.ClientParameters) (KubernetesCoreInterface, error) {
	cs, csErr := getKubernetesClientset(clientParams)
	if csErr != nil {
		return nil, csErr
	}
	core := &KubernetesCore{coreV1Client: cs.CoreV1()}
	return core, nil
}

// CreateKubernetesCoreOrFatal builds a KubernetesCoreInterface, retrying the
// creation with an exponential backoff whose MaxElapsedTime is
// initConnectionTimeout. If the retries end with an error, glog.Fatalf is
// called with that error.
func CreateKubernetesCoreOrFatal(initConnectionTimeout time.Duration, clientParams util.ClientParameters) KubernetesCoreInterface {
	var core KubernetesCoreInterface

	retryPolicy := backoff.NewExponentialBackOff()
	retryPolicy.MaxElapsedTime = initConnectionTimeout

	// Each attempt overwrites core, so the value from the last attempt wins.
	attempt := func() error {
		created, createErr := createKubernetesCore(clientParams)
		core = created
		return createErr
	}
	if retryErr := backoff.Retry(attempt, retryPolicy); retryErr != nil {
		glog.Fatalf("Failed to create namespace client. Error: %v", retryErr)
	}
	return core
}

// getKubernetesClientset creates a Kubernetes clientset from the in-cluster
// configuration, with QPS and Burst taken from clientParams. Errors from
// loading the configuration or building the clientset are wrapped with a
// short description.
func getKubernetesClientset(clientParams util.ClientParameters) (*kubernetes.Clientset, error) {
	cfg, cfgErr := rest.InClusterConfig()
	if cfgErr != nil {
		return nil, errors.Wrap(cfgErr, "Failed to initialize kubernetes client.")
	}
	// Apply the caller-supplied rate-limit settings to the REST config.
	cfg.QPS, cfg.Burst = float32(clientParams.QPS), clientParams.Burst

	cs, csErr := kubernetes.NewForConfig(cfg)
	if csErr != nil {
		return nil, errors.Wrap(csErr, "Failed to initialize kubernetes client set.")
	}
	return cs, nil
}
37 changes: 37 additions & 0 deletions backend/src/agent/persistence/client/kubernetes_core_fake.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package client

import (
"context"
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
v1 "k8s.io/client-go/kubernetes/typed/core/v1"
)

// KubernetesCoreFake is a fake Kubernetes core client whose namespace client
// is a FakeNamespaceClient.
type KubernetesCoreFake struct {
	// coreV1ClientFake is the fake namespace client handed out by
	// NamespaceClient and configured through Set.
	coreV1ClientFake *FakeNamespaceClient
}

// NamespaceClient returns the underlying FakeNamespaceClient as a
// v1.NamespaceInterface.
func (c *KubernetesCoreFake) NamespaceClient() v1.NamespaceInterface {
	var nsClient v1.NamespaceInterface = c.coreV1ClientFake
	return nsClient
}

// GetNamespaceOwner looks the namespace up through the fake namespace client
// and returns the value of its "owner" annotation. The lookup error is
// returned unchanged; a missing annotation produces a descriptive error.
func (c *KubernetesCoreFake) GetNamespaceOwner(namespace string) (string, error) {
	nsObj, getErr := c.NamespaceClient().Get(context.Background(), namespace, metav1.GetOptions{})
	if getErr != nil {
		return "", getErr
	}
	owner, found := nsObj.Annotations["owner"]
	if found {
		return owner, nil
	}
	msg := fmt.Sprintf("namespace '%v' has no owner in the annotations", namespace)
	return "", errors.New(msg)
}

// NewKubernetesCoreFake returns a KubernetesCoreFake wrapping a fresh,
// unconfigured FakeNamespaceClient.
func NewKubernetesCoreFake() *KubernetesCoreFake {
	nsFake := &FakeNamespaceClient{}
	return &KubernetesCoreFake{coreV1ClientFake: nsFake}
}
// Set forwards the namespace and user to the fake namespace client's
// SetReturnValues, which determines what later lookups return.
func (c *KubernetesCoreFake) Set(namespaceToReturn string, userToReturn string) {
	nsFake := c.coreV1ClientFake
	nsFake.SetReturnValues(namespaceToReturn, userToReturn)
}
35 changes: 29 additions & 6 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ package client
import (
"context"
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"google.golang.org/grpc/metadata"
"os"
"time"

api "github.com/kubeflow/pipelines/backend/api/go_client"
Expand All @@ -33,8 +36,8 @@ const (
type PipelineClientInterface interface {
ReportWorkflow(workflow *util.Workflow) error
ReportScheduledWorkflow(swf *util.ScheduledWorkflow) error
ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error)
ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error)
ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error)
ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error)
}

type PipelineClient struct {
Expand Down Expand Up @@ -139,8 +142,10 @@ func (p *PipelineClient) ReportScheduledWorkflow(swf *util.ScheduledWorkflow) er

// ReadArtifact reads artifact content from run service. If the artifact is not present, returns
// nil response.
func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) {
pctx := metadata.AppendToOutgoingContext(context.Background(), getKubeflowUserIDHeader(),
getKubeflowUserIDPrefix()+user)
ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

response, err := p.runServiceClient.ReadArtifact(ctx, request)
Expand All @@ -153,8 +158,10 @@ func (p *PipelineClient) ReadArtifact(request *api.ReadArtifactRequest) (*api.Re
}

// ReportRunMetrics reports run metrics to run service.
func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) {
pctx := metadata.AppendToOutgoingContext(context.Background(), getKubeflowUserIDHeader(),
getKubeflowUserIDPrefix()+user)
ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

response, err := p.runServiceClient.ReportRunMetrics(ctx, request)
Expand All @@ -166,3 +173,19 @@ func (p *PipelineClient) ReportRunMetrics(request *api.ReportRunMetricsRequest)
}
return response, nil
}

// getKubeflowUserIDHeader returns the name of the request header that carries
// the user identity. The environment variable named by
// common.KubeflowUserIDHeader takes precedence; when it is unset or empty the
// function falls back to common.GoogleIAPUserIdentityHeader.
//
// TODO: use config file & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDHeader()"
func getKubeflowUserIDHeader() string {
	// An empty header name cannot identify a metadata entry, so an empty
	// value is treated the same as an unset variable.
	if value, ok := os.LookupEnv(common.KubeflowUserIDHeader); ok && value != "" {
		return value
	}
	return common.GoogleIAPUserIdentityHeader
}

// getKubeflowUserIDPrefix returns the prefix that is prepended to the user
// name in the identity header value. The environment variable named by
// common.KubeflowUserIDPrefix is used when it is set (even if empty);
// otherwise common.GoogleIAPUserIdentityPrefix is returned.
//
// TODO: use config file & viper and "github.com/kubeflow/pipelines/backend/src/apiserver/common.GetKubeflowUserIDPrefix()"
func getKubeflowUserIDPrefix() string {
	prefix, isSet := os.LookupEnv(common.KubeflowUserIDPrefix)
	if !isSet {
		return common.GoogleIAPUserIdentityPrefix
	}
	return prefix
}
7 changes: 5 additions & 2 deletions backend/src/agent/persistence/client/pipeline_client_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,15 @@ func (p *PipelineClientFake) ReportScheduledWorkflow(swf *util.ScheduledWorkflow
return nil
}

func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest) (*api.ReadArtifactResponse, error) {
// ReadArtifact returns the preconfigured error if one is set. Otherwise it
// records the request and returns the stored artifact response keyed by the
// request's String() form. The user argument is not used by the fake.
func (p *PipelineClientFake) ReadArtifact(request *api.ReadArtifactRequest, user string) (*api.ReadArtifactResponse, error) {
	if failure := p.err; failure != nil {
		return nil, failure
	}
	p.readArtifactRequest = request
	key := request.String()
	return p.artifacts[key], nil
}

func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest) (*api.ReportRunMetricsResponse, error) {
// ReportRunMetrics records the request and returns the stubbed response and
// error. The user argument is not used by the fake.
func (p *PipelineClientFake) ReportRunMetrics(request *api.ReportRunMetricsRequest, user string) (*api.ReportRunMetricsResponse, error) {
	p.reportedMetricsRequest = request
	stubResponse, stubErr := p.reportMetricsResponseStub, p.reportMetricsErrorStub
	return stubResponse, stubErr
}
Expand Down
9 changes: 9 additions & 0 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ const (
clientBurstFlagName = "clientBurst"
)

// DefaultConnectionTimeout is the connection timeout that main passes to
// client.CreateKubernetesCoreOrFatal.
const DefaultConnectionTimeout = 6 * time.Minute

func main() {
flag.Parse()

Expand Down Expand Up @@ -95,6 +99,10 @@ func main() {
swfInformerFactory = swfinformers.NewFilteredSharedInformerFactory(swfClient, time.Second*30, namespace, nil)
workflowInformerFactory = workflowinformers.NewFilteredSharedInformerFactory(workflowClient, time.Second*30, namespace, nil)
}
k8sCoreClient := client.CreateKubernetesCoreOrFatal(DefaultConnectionTimeout, util.ClientParameters{
QPS: clientQPS,
Burst: clientBurst,
})

pipelineClient, err := client.NewPipelineClient(
initializeTimeout,
Expand All @@ -111,6 +119,7 @@ func main() {
swfInformerFactory,
workflowInformerFactory,
pipelineClient,
k8sCoreClient,
util.NewRealTime())

go swfInformerFactory.Start(stopCh)
Expand Down
3 changes: 2 additions & 1 deletion backend/src/agent/persistence/persistence_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func NewPersistenceAgent(
swfInformerFactory swfinformers.SharedInformerFactory,
workflowInformerFactory workflowinformers.SharedInformerFactory,
pipelineClient *client.PipelineClient,
k8sCoreClient client.KubernetesCoreInterface,
time util.TimeInterface) *PersistenceAgent {
// obtain references to shared informers
swfInformer := swfInformerFactory.Scheduledworkflow().V1beta1().ScheduledWorkflows()
Expand All @@ -64,7 +65,7 @@ func NewPersistenceAgent(

workflowWorker := worker.NewPersistenceWorker(time, workflowregister.WorkflowKind,
workflowInformer.Informer(), true,
worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish))
worker.NewWorkflowSaver(workflowClient, pipelineClient, k8sCoreClient, ttlSecondsAfterWorkflowFinish))

agent := &PersistenceAgent{
swfClient: swfClient,
Expand Down
14 changes: 7 additions & 7 deletions backend/src/agent/persistence/worker/metrics_reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewMetricsReporter(pipelineClient client.PipelineClientInterface) *MetricsR
}

// ReportMetrics reports workflow metrics to pipeline server.
func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error {
func (r MetricsReporter) ReportMetrics(workflow *util.Workflow, user string) error {
if workflow.Status.Nodes == nil {
return nil
}
Expand All @@ -57,7 +57,7 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error {
runMetrics := []*api.RunMetric{}
partialFailures := []error{}
for _, nodeStatus := range workflow.Status.Nodes {
nodeMetrics, err := r.collectNodeMetricsOrNil(runID, nodeStatus)
nodeMetrics, err := r.collectNodeMetricsOrNil(runID, nodeStatus, user)
if err != nil {
partialFailures = append(partialFailures, err)
continue
Expand All @@ -79,7 +79,7 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error {
reportMetricsResponse, err := r.pipelineClient.ReportRunMetrics(&api.ReportRunMetricsRequest{
RunId: runID,
Metrics: runMetrics,
})
}, user)
if err != nil {
return err
}
Expand All @@ -89,12 +89,12 @@ func (r MetricsReporter) ReportMetrics(workflow *util.Workflow) error {
}

func (r MetricsReporter) collectNodeMetricsOrNil(
runID string, nodeStatus workflowapi.NodeStatus) (
runID string, nodeStatus workflowapi.NodeStatus, user string) (
[]*api.RunMetric, error) {
if !nodeStatus.Completed() {
return nil, nil
}
metricsJSON, err := r.readNodeMetricsJSONOrEmpty(runID, nodeStatus)
metricsJSON, err := r.readNodeMetricsJSONOrEmpty(runID, nodeStatus, user)
if err != nil || metricsJSON == "" {
return nil, err
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func (r MetricsReporter) collectNodeMetricsOrNil(
return reportMetricsRequest.GetMetrics(), nil
}

func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus workflowapi.NodeStatus) (string, error) {
func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus workflowapi.NodeStatus, user string) (string, error) {
if nodeStatus.Outputs == nil || nodeStatus.Outputs.Artifacts == nil {
return "", nil // No output artifacts, skip the reporting
}
Expand All @@ -146,7 +146,7 @@ func (r MetricsReporter) readNodeMetricsJSONOrEmpty(runID string, nodeStatus wor
NodeId: nodeStatus.ID,
ArtifactName: metricsArtifactName,
}
artifactResponse, err := r.pipelineClient.ReadArtifact(artifactRequest)
artifactResponse, err := r.pipelineClient.ReadArtifact(artifactRequest, user)
if err != nil {
return "", err
}
Expand Down
Loading

0 comments on commit f6f7944

Please sign in to comment.