Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
- unit tests added
- do not stop the ticker on stopCh

Signed-off-by: diana <[email protected]>
  • Loading branch information
difince committed Aug 16, 2023
1 parent 1de9bd0 commit bc9e1e0
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 25 deletions.
50 changes: 28 additions & 22 deletions backend/src/agent/persistence/client/token_refresher.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,51 @@ import (

type TokenRefresherInterface interface {
GetToken() string
RefreshToken()
RefreshToken() error
}

const SaTokenFile = "/var/run/secrets/kubeflow/tokens/persistenceagent-sa-token"

type FileReader interface {
ReadFile(filename string) ([]byte, error)
}

type tokenRefresher struct {
mu sync.RWMutex
seconds *time.Duration
token string
mu sync.RWMutex
seconds *time.Duration
token string
fileReader *FileReader
}

type FileReaderImpl struct{}

func (r *FileReaderImpl) ReadFile(filename string) ([]byte, error) {
return os.ReadFile(filename)
}

func NewTokenRefresher(seconds time.Duration) *tokenRefresher {
func NewTokenRefresher(seconds time.Duration, fileReader FileReader) *tokenRefresher {
if fileReader == nil {
fileReader = &FileReaderImpl{}
}

tokenRefresher := &tokenRefresher{
seconds: &seconds,
seconds: &seconds,
fileReader: &fileReader,
}

return tokenRefresher
}

func (tr *tokenRefresher) StartTokenRefreshTicker(stopCh <-chan struct{}) error {
err := tr.readToken()
func (tr *tokenRefresher) StartTokenRefreshTicker() error {
err := tr.RefreshToken()
if err != nil {
return err
}

ticker := time.NewTicker(*tr.seconds)
go func() {
for {
select {
case <-stopCh:
ticker.Stop()
return
case <-ticker.C:
tr.readToken()
}
for range ticker.C {
tr.RefreshToken()
}
}()
return err
Expand All @@ -55,14 +65,10 @@ func (tr *tokenRefresher) GetToken() string {
return tr.token
}

func (tr *tokenRefresher) RefreshToken() {
tr.readToken()
}

func (tr *tokenRefresher) readToken() error {
func (tr *tokenRefresher) RefreshToken() error {
tr.mu.Lock()
defer tr.mu.Unlock()
b, err := os.ReadFile(SaTokenFile)
b, err := (*tr.fileReader).ReadFile(SaTokenFile)
if err != nil {
log.Errorf("Error reading persistence agent service account token '%s': %v", SaTokenFile, err)
return err
Expand Down
111 changes: 111 additions & 0 deletions backend/src/agent/persistence/client/token_refresher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package client

import (
"fmt"
"io/fs"
"log"
"syscall"
"testing"
"time"
)

const refreshInterval = 2 * time.Second

type FileReaderFake struct {
Data string
Err error
readCounter int
}

func (m *FileReaderFake) ReadFile(filename string) ([]byte, error) {
if m.Err != nil {
return nil, m.Err
}
content := fmt.Sprintf("%s-%v", m.Data, m.readCounter)
m.readCounter++
return []byte(content), nil
}

func Test_token_refresher(t *testing.T) {
tests := []struct {
name string
baseToken string
wanted string
refreshedToken string
err error
}{
{
name: "TestTokenRefresher_GetToken_Success",
baseToken: "rightToken",
wanted: "rightToken-0",
err: nil,
},
{
name: "TestTokenRefresher_GetToken_Failed_PathError",
baseToken: "rightToken",
wanted: "rightToken-0",
err: &fs.PathError{Err: syscall.ENOENT},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// setup
fakeFileReader := &FileReaderFake{
Data: tt.baseToken,
Err: tt.err,
}
tr := NewTokenRefresher(refreshInterval, fakeFileReader)
err := tr.StartTokenRefreshTicker()
if err != nil {
got, sameType := err.(*fs.PathError)
if sameType != true {
t.Errorf("%v(): got = %v, wanted %v", tt.name, got, tt.err)
}
return
}
if err != nil {
log.Fatalf("Error starting Service Account Token Refresh Ticker: %v", err)
}

if got := tr.GetToken(); got != tt.wanted {
t.Errorf("%v(): got %v, wanted %v", tt.name, got, tt.wanted)
}
})
}
}

func TestTokenRefresher_GetToken_After_TickerRefresh_Success(t *testing.T) {
fakeFileReader := &FileReaderFake{
Data: "Token",
Err: nil,
}
tr := NewTokenRefresher(1*time.Second, fakeFileReader)
err := tr.StartTokenRefreshTicker()
if err != nil {
log.Fatalf("Error starting Service Account Token Refresh Ticker: %v", err)
}
time.Sleep(1200 * time.Millisecond)
expectedToken := "Token-1"

if got := tr.GetToken(); got != expectedToken {
t.Errorf("%v(): got %v, wanted 'refreshed baseToken' %v", t.Name(), got, expectedToken)
}
}

func TestTokenRefresher_GetToken_After_ForceRefresh_Success(t *testing.T) {
fakeFileReader := &FileReaderFake{
Data: "Token",
Err: nil,
}
tr := NewTokenRefresher(refreshInterval, fakeFileReader)
err := tr.StartTokenRefreshTicker()
if err != nil {
log.Fatalf("Error starting Service Account Token Refresh Ticker: %v", err)
}
tr.RefreshToken()
expectedToken := "Token-1"

if got := tr.GetToken(); got != expectedToken {
t.Errorf("%v(): got %v, wanted 'refreshed baseToken' %v", t.Name(), got, expectedToken)
}
}
6 changes: 3 additions & 3 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ func main() {
Burst: clientBurst,
})

tokenRefresher := client.NewTokenRefresher(time.Duration(saTokenRefreshInterval))
err = tokenRefresher.StartTokenRefreshTicker(stopCh)
tokenRefresher := client.NewTokenRefresher(time.Duration(saTokenRefreshInterval), nil)
err = tokenRefresher.StartTokenRefreshTicker()
if err != nil {
log.Fatalf("Error starting Service Account Token Refresh Ticker: %v", err)
log.Fatalf("Error starting Service Account Token Refresh Ticker due to: %v", err)
}

pipelineClient, err := client.NewPipelineClient(
Expand Down

0 comments on commit bc9e1e0

Please sign in to comment.