Skip to content

Commit

Permalink
Address review comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Haus <[email protected]>
  • Loading branch information
dhaus67 committed Mar 3, 2022
1 parent ca7e536 commit 55658b5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 26 deletions.
50 changes: 25 additions & 25 deletions connector/gitlab/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ package gitlab
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -64,7 +64,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
type connectorData struct {
// Support GitLab's Access Tokens and Refresh tokens.
AccessToken string `json:"accessToken"`
RefreshToken []byte
RefreshToken string `json:"refreshToken"`
}

var (
Expand Down Expand Up @@ -174,7 +174,7 @@ func (c *gitlabConnector) identity(ctx context.Context, s connector.Scopes, toke
}

if s.OfflineAccess {
data := connectorData{RefreshToken: []byte(token.RefreshToken), AccessToken: token.AccessToken}
data := connectorData{RefreshToken: token.RefreshToken, AccessToken: token.AccessToken}
connData, err := json.Marshal(data)
if err != nil {
return identity, fmt.Errorf("gitlab: marshal connector data: %v", err)
Expand All @@ -196,29 +196,29 @@ func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
}

t := &oauth2.Token{
AccessToken: data.AccessToken,
RefreshToken: string(data.RefreshToken),
}

// Try accessing the GitLab API without refreshing the token at first.
id, err := c.identity(ctx, s, t)
if err == nil {
return id, nil
}

// Do not retry with refreshing token if a non-unauthorized error occurred.
if !strings.Contains(err.Error(), strconv.Itoa(http.StatusUnauthorized)) {
return id, err
}

// Refresh the token and retry.
t.Expiry = time.Now().Add(-time.Hour)
token, err := oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err)
switch {
case data.RefreshToken != "":
{
t := &oauth2.Token{
RefreshToken: data.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
token, err := oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err)
}
return c.identity(ctx, s, token)
}
case data.AccessToken != "":
{
token := &oauth2.Token{
AccessToken: data.AccessToken,
}
return c.identity(ctx, s, token)
}
default:
return ident, errors.New("no refresh or access token found")
}
return c.identity(ctx, s, token)
}

func (c *gitlabConnector) groupsRequired(groupScope bool) bool {
Expand Down
30 changes: 29 additions & 1 deletion connector/gitlab/gitlab_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func TestRefresh(t *testing.T) {
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}

expectedConnectorData, err := json.Marshal(connectorData{
RefreshToken: []byte("oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC"),
RefreshToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
AccessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
})
expectNil(t, err)
Expand All @@ -221,6 +221,34 @@ func TestRefresh(t *testing.T) {
expectEquals(t, identity.ConnectorData, expectedConnectorData)
}

func TestRefreshWithEmptyConnectorData(t *testing.T) {
s := newTestServer(map[string]interface{}{
"/api/v4/user": gitlabUser{Email: "[email protected]", ID: 12345678},
"/oauth/token": map[string]interface{}{
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
"expires_in": "30",
},
"/oauth/userinfo": userInfo{
Groups: []string{"team-1"},
},
})
defer s.Close()

emptyConnectorData, err := json.Marshal(connectorData{
RefreshToken: "",
AccessToken: "",
})
expectNil(t, err)

c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
emptyIdentity := connector.Identity{ConnectorData: emptyConnectorData}

identity, err := c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, emptyIdentity)
expectNotNil(t, err, "Refresh error")
expectEquals(t, emptyIdentity, identity)
}

func newTestServer(responses map[string]interface{}) *httptest.Server {
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := responses[r.RequestURI]
Expand Down

0 comments on commit 55658b5

Please sign in to comment.