Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend) Enable auth between persistence agent and pipelineAPI (ReportServer) #9699

Merged
merged 3 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"os"
"strings"
"time"

"github.com/kubeflow/pipelines/backend/src/apiserver/common"
Expand Down Expand Up @@ -46,11 +47,13 @@ type PipelineClient struct {
timeout time.Duration
reportServiceClient api.ReportServiceClient
runServiceClient api.RunServiceClient
tokenRefresher TokenRefresherInterface
}

func NewPipelineClient(
initializeTimeout time.Duration,
timeout time.Duration,
tokenRefresher TokenRefresherInterface,
basePath string,
mlPipelineServiceName string,
mlPipelineServiceHttpPort string,
Expand All @@ -71,13 +74,18 @@ func NewPipelineClient(
return &PipelineClient{
initializeTimeout: initializeTimeout,
timeout: timeout,
tokenRefresher: tokenRefresher,
reportServiceClient: api.NewReportServiceClient(connection),
runServiceClient: api.NewRunServiceClient(connection),
}, nil
}

func (p *PipelineClient) ReportWorkflow(workflow util.ExecutionSpec) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
pctx := context.Background()
pctx = metadata.AppendToOutgoingContext(pctx, "Authorization",
"Bearer "+p.tokenRefresher.GetToken())

ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

_, err := p.reportServiceClient.ReportWorkflowV1(ctx, &api.ReportWorkflowRequest{
Expand All @@ -96,6 +104,15 @@ func (p *PipelineClient) ReportWorkflow(workflow util.ExecutionSpec) error {
statusCode.Message(),
err.Error(),
workflow.ToStringForStore())
} else if statusCode.Code() == codes.Unauthenticated && strings.Contains(err.Error(), "service account token has expired") {
// If unauthenticated because SA token is expired, re-read/refresh the token and try again
p.tokenRefresher.RefreshToken()
return util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I see you're following an existing pattern, for my education: how would a user discover such an error? Would this be surfaced back by the apiserver?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Sorry for my late response.
The user could observe the status of the Run and inspect the logs of the Persistence Agent. The errors are not directly displayed in the UI.
If a "transient" error occurs, the Persistence Agent will retry with an exponential delay.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, if an error occurs on token refresh, I just log the error and let the worker thread retry. Any other suggestions on what to do if the token refresh fails?

"Error while reporting workflow resource (code: %v, message: %v): %v, %+v",
statusCode.Code(),
statusCode.Message(),
err.Error(),
workflow.ToStringForStore())
} else {
// Retry otherwise
return util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
Expand All @@ -110,7 +127,11 @@ func (p *PipelineClient) ReportWorkflow(workflow util.ExecutionSpec) error {
}

func (p *PipelineClient) ReportScheduledWorkflow(swf *util.ScheduledWorkflow) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
pctx := context.Background()
pctx = metadata.AppendToOutgoingContext(pctx, "Authorization",
"Bearer "+p.tokenRefresher.GetToken())

ctx, cancel := context.WithTimeout(pctx, time.Minute)
defer cancel()

_, err := p.reportServiceClient.ReportScheduledWorkflowV1(ctx,
Expand All @@ -128,6 +149,15 @@ func (p *PipelineClient) ReportScheduledWorkflow(swf *util.ScheduledWorkflow) er
statusCode.Message(),
err.Error(),
swf.ScheduledWorkflow)
} else if statusCode.Code() == codes.Unauthenticated && strings.Contains(err.Error(), "service account token has expired") {
// If unauthenticated because SA token is expired, re-read/refresh the token and try again
p.tokenRefresher.RefreshToken()
return util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
"Error while reporting workflow resource (code: %v, message: %v): %v, %+v",
statusCode.Code(),
statusCode.Message(),
err.Error(),
swf.ScheduledWorkflow)
} else {
// Retry otherwise
return util.NewCustomError(err, util.CUSTOM_CODE_TRANSIENT,
Expand Down
78 changes: 78 additions & 0 deletions backend/src/agent/persistence/client/token_refresher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package client

import (
log "github.com/sirupsen/logrus"
"os"
"sync"
"time"
)

// TokenRefresherInterface abstracts the source of the bearer token that
// PipelineClient attaches to its outgoing report requests.
type TokenRefresherInterface interface {
// GetToken returns the currently held token.
GetToken() string
// RefreshToken reloads the token and returns a non-nil error when the
// reload fails.
RefreshToken() error
}

// SaTokenFile is the path from which tokenRefresher.RefreshToken reads the
// persistence agent's service account token.
const SaTokenFile = "/var/run/secrets/kubeflow/tokens/persistenceagent-sa-token"

// FileReader is the file-access seam used by tokenRefresher so that the
// token file read can be replaced (for example by a fake in tests).
type FileReader interface {
// ReadFile returns the contents of filename or an error.
ReadFile(filename string) ([]byte, error)
}

// tokenRefresher caches the service account token read from SaTokenFile and
// implements GetToken/RefreshToken on top of that cache.
//
// NOTE(review): seconds and fileReader are stored as pointers to a
// time.Duration and to an interface value respectively; plain values would
// presumably work as well, but every method dereferences them, so changing
// the field types has to be done together with those methods.
type tokenRefresher struct {
// mu is read-locked by GetToken and write-locked by RefreshToken.
mu sync.RWMutex
// seconds is the period handed to time.NewTicker in StartTokenRefreshTicker.
seconds *time.Duration
// token is the most recently loaded token; empty until the first successful refresh.
token string
// fileReader is the reader RefreshToken uses to load SaTokenFile.
fileReader *FileReader
}

// FileReaderImpl reads files from the real filesystem.
type FileReaderImpl struct{}

// ReadFile returns the full contents of filename, or the error reported by
// os.ReadFile.
func (*FileReaderImpl) ReadFile(filename string) ([]byte, error) {
	contents, err := os.ReadFile(filename)
	return contents, err
}

// NewTokenRefresher builds a tokenRefresher whose ticker period is seconds
// and whose token file is read through fileReader. A nil fileReader selects
// the filesystem-backed FileReaderImpl. No token is loaded here.
func NewTokenRefresher(seconds time.Duration, fileReader FileReader) *tokenRefresher {
	reader := fileReader
	if reader == nil {
		reader = &FileReaderImpl{}
	}

	return &tokenRefresher{
		seconds:    &seconds,
		fileReader: &reader,
	}
}

// StartTokenRefreshTicker loads the token once and then keeps reloading it
// in a background goroutine on every tick of a ticker with period *tr.seconds.
//
// It returns the error of the initial RefreshToken call; when that call
// fails, no ticker and no goroutine are started.
//
// NOTE(review): the ticker is never stopped and the goroutine has no exit
// path in this code, so both live until the process ends. time.NewTicker
// panics for a non-positive period, so *tr.seconds must be > 0.
func (tr *tokenRefresher) StartTokenRefreshTicker() error {
err := tr.RefreshToken()
if err != nil {
return err
}

ticker := time.NewTicker(*tr.seconds)
go func() {
for range ticker.C {
Copy link
Member Author

@difince difince Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was unsure about the use of the stopCh, so I removed it. WDYT? Please inspect the changes introduced by my third commit.

// The returned error is discarded here; RefreshToken logs its own failures.
tr.RefreshToken()
}
}()
// err is necessarily nil at this point (the non-nil case returned above).
return err
}

// GetToken returns the token stored by the most recent successful
// RefreshToken, reading it under the read lock.
func (tr *tokenRefresher) GetToken() string {
	tr.mu.RLock()
	current := tr.token
	tr.mu.RUnlock()
	return current
}

// RefreshToken re-reads SaTokenFile through the configured FileReader and
// stores the contents as the current token. The write lock is held for the
// whole call, including the file read. On a read failure the error is logged
// and returned, and the previously stored token is left unchanged.
func (tr *tokenRefresher) RefreshToken() error {
	tr.mu.Lock()
	defer tr.mu.Unlock()

	reader := *tr.fileReader
	raw, readErr := reader.ReadFile(SaTokenFile)
	if readErr == nil {
		tr.token = string(raw)
		return nil
	}

	log.Errorf("Error reading persistence agent service account token '%s': %v", SaTokenFile, readErr)
	return readErr
}
111 changes: 111 additions & 0 deletions backend/src/agent/persistence/client/token_refresher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package client

import (
"fmt"
"io/fs"
"log"
"syscall"
"testing"
"time"
)

// refreshInterval is the ticker period used by the tests that only check the
// initial load or a forced refresh and never wait for a tick.
const refreshInterval = 2 * time.Second

// FileReaderFake is an in-memory reader for tests. Each successful read
// returns "<Data>-<n>", where n counts the successful reads so far, so a
// caller can tell which read produced a given token.
type FileReaderFake struct {
	Data        string
	Err         error
	readCounter int
}

// ReadFile ignores filename. When Err is set it returns (nil, Err) without
// advancing the counter; otherwise it returns Data suffixed with the current
// counter value and then increments the counter.
func (m *FileReaderFake) ReadFile(filename string) ([]byte, error) {
	if m.Err == nil {
		seq := m.readCounter
		m.readCounter = seq + 1
		return []byte(fmt.Sprintf("%s-%v", m.Data, seq)), nil
	}
	return nil, m.Err
}

// Test_token_refresher checks the initial token load performed by
// StartTokenRefreshTicker, both when the token file is readable and when
// reading it fails with a *fs.PathError.
func Test_token_refresher(t *testing.T) {
	tests := []struct {
		name      string
		baseToken string
		wanted    string // expected token; only meaningful when err is nil
		err       error  // error injected into the fake reader, nil for success
	}{
		{
			name:      "TestTokenRefresher_GetToken_Success",
			baseToken: "rightToken",
			wanted:    "rightToken-0",
		},
		{
			name:      "TestTokenRefresher_GetToken_Failed_PathError",
			baseToken: "rightToken",
			err:       &fs.PathError{Err: syscall.ENOENT},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// setup
			fakeFileReader := &FileReaderFake{
				Data: tt.baseToken,
				Err:  tt.err,
			}
			tr := NewTokenRefresher(refreshInterval, fakeFileReader)
			err := tr.StartTokenRefreshTicker()

			// Failure case: the injected error must be surfaced with its type intact.
			if tt.err != nil {
				if _, sameType := err.(*fs.PathError); !sameType {
					t.Fatalf("%v(): got = %v, wanted %v", tt.name, err, tt.err)
				}
				return
			}

			// Success case: any error, including a PathError, is a failure.
			if err != nil {
				t.Fatalf("%v(): error starting Service Account Token Refresh Ticker: %v", tt.name, err)
			}
			if got := tr.GetToken(); got != tt.wanted {
				t.Errorf("%v(): got %v, wanted %v", tt.name, got, tt.wanted)
			}
		})
	}
}

// TestTokenRefresher_GetToken_After_TickerRefresh_Success waits for one tick
// of a one-second refresh ticker and expects the token produced by the
// second read of the fake file.
func TestTokenRefresher_GetToken_After_TickerRefresh_Success(t *testing.T) {
	const want = "Token-1"

	reader := &FileReaderFake{Data: "Token"}
	refresher := NewTokenRefresher(time.Second, reader)
	if startErr := refresher.StartTokenRefreshTicker(); startErr != nil {
		log.Fatalf("Error starting Service Account Token Refresh Ticker: %v", startErr)
	}

	// Sleep past the first tick (1s) but well short of the second (2s).
	time.Sleep(1200 * time.Millisecond)

	got := refresher.GetToken()
	if got == want {
		return
	}
	t.Errorf("%v(): got %v, wanted 'refreshed baseToken' %v", t.Name(), got, want)
}

// TestTokenRefresher_GetToken_After_ForceRefresh_Success verifies that an
// explicit RefreshToken call replaces the token loaded at start-up, without
// waiting for the ticker.
func TestTokenRefresher_GetToken_After_ForceRefresh_Success(t *testing.T) {
	fakeFileReader := &FileReaderFake{
		Data: "Token",
		Err:  nil,
	}
	tr := NewTokenRefresher(refreshInterval, fakeFileReader)
	if err := tr.StartTokenRefreshTicker(); err != nil {
		// t.Fatalf fails only this test; log.Fatalf would exit the whole test binary.
		t.Fatalf("Error starting Service Account Token Refresh Ticker: %v", err)
	}

	// The forced refresh is the behavior under test, so its error must not be dropped.
	if err := tr.RefreshToken(); err != nil {
		t.Fatalf("%v(): forced RefreshToken failed: %v", t.Name(), err)
	}

	expectedToken := "Token-1"
	if got := tr.GetToken(); got != expectedToken {
		t.Errorf("%v(): got %v, wanted 'refreshed baseToken' %v", t.Name(), got, expectedToken)
	}
}
16 changes: 15 additions & 1 deletion backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
numWorker int
clientQPS float64
clientBurst int
saTokenRefreshInterval float64
)

const (
Expand All @@ -59,10 +60,12 @@ const (
numWorkerName = "numWorker"
clientQPSFlagName = "clientQPS"
clientBurstFlagName = "clientBurst"
saTokenRefreshIntervalFlagName = "saTokenRefreshInterval"
)

const (
DefaultConnectionTimeout = 6 * time.Minute
DefaultConnectionTimeout = 6 * time.Minute
DefaultTokenRefresherInterval = 1 * time.Hour
)

func main() {
Expand Down Expand Up @@ -97,9 +100,16 @@ func main() {
Burst: clientBurst,
})

tokenRefresher := client.NewTokenRefresher(time.Duration(saTokenRefreshInterval), nil)
err = tokenRefresher.StartTokenRefreshTicker()
if err != nil {
log.Fatalf("Error starting Service Account Token Refresh Ticker due to: %v", err)
}

pipelineClient, err := client.NewPipelineClient(
initializeTimeout,
timeout,
tokenRefresher,
mlPipelineAPIServerBasePath,
mlPipelineAPIServerName,
mlPipelineServiceHttpPort,
Expand Down Expand Up @@ -140,4 +150,8 @@ func init() {
// k8s.io/client-go/rest/config.go#RESTClientFor
flag.Float64Var(&clientQPS, clientQPSFlagName, 5, "The maximum QPS to the master from this client.")
flag.IntVar(&clientBurst, clientBurstFlagName, 10, "Maximum burst for throttle from this client.")
// TODO use viper/config file instead. Sync `saTokenRefreshIntervalFlagName` with the value from manifest file by using ENV var.
flag.Float64Var(&saTokenRefreshInterval, saTokenRefreshIntervalFlagName, DefaultTokenRefresherInterval.Seconds(), "Persistence agent service account token read interval in seconds. "+
"Defines how often `/var/run/secrets/kubeflow/tokens/kubeflow-persistent_agent-api-token` to be read")

}
3 changes: 2 additions & 1 deletion backend/src/apiserver/auth/authenticator_token_review.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ func (tra *TokenReviewAuthenticator) doTokenReview(ctx context.Context, userIden
if !review.Status.Authenticated {
return nil, util.NewUnauthenticatedError(
errors.New("Failed to authenticate token review"),
"Review.Status.Authenticated is false",
"Review.Status.Authenticated is false. Error %s",
review.Status.Error,
)
}
if !tra.ensureAudience(review.Status.Audiences) {
Expand Down
15 changes: 9 additions & 6 deletions backend/src/apiserver/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ const (
RbacPipelinesGroup = "pipelines.kubeflow.org"
RbacPipelinesVersion = "v1beta1"

RbacResourceTypePipelines = "pipelines"
RbacResourceTypeExperiments = "experiments"
RbacResourceTypeRuns = "runs"
RbacResourceTypeJobs = "jobs"
RbacResourceTypeViewers = "viewers"
RbacResourceTypeVisualizations = "visualizations"
RbacResourceTypePipelines = "pipelines"
RbacResourceTypeExperiments = "experiments"
RbacResourceTypeRuns = "runs"
RbacResourceTypeJobs = "jobs"
RbacResourceTypeViewers = "viewers"
RbacResourceTypeVisualizations = "visualizations"
RbacResourceTypeScheduledWorkflows = "scheduledworkflows"
RbacResourceTypeWorkflows = "workflows"

RbacResourceVerbArchive = "archive"
RbacResourceVerbUpdate = "update"
Expand All @@ -39,6 +41,7 @@ const (
RbacResourceVerbUnarchive = "unarchive"
RbacResourceVerbReportMetrics = "reportMetrics"
RbacResourceVerbReadArtifact = "readArtifact"
RbacResourceVerbReport = "report"
)

const (
Expand Down
Loading