Skip to content

Commit

Permalink
fix: update token invalidation for recent API changes (#1110)
Browse files Browse the repository at this point in the history
- Update the token invalidation code to use the latest API on SaaS
- Return `jwt.ErrTokenExpired` when the token has been invalidated so that it is treated as an auth error and tiggers re-auth
  • Loading branch information
hanyucui authored and determined-ci committed Apr 18, 2024
1 parent 401af27 commit 6fefbb3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
2 changes: 1 addition & 1 deletion master/internal/user/external_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func ByExternalToken(ctx context.Context, tokenText string,
claims := token.Claims.(*model.JWT)

if ext.Validate(claims) != nil {
return nil, nil, errors.New("token has been invalidated")
return nil, nil, jwt.ErrTokenExpired
}

var isAdmin bool
Expand Down
35 changes: 22 additions & 13 deletions master/pkg/model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import (
"sync"
"time"

"github.com/golang-jwt/jwt/v4"

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/guregu/null.v3"

"github.com/uptrace/bun"

"github.com/determined-ai/determined/proto/pkg/userv1"
)

Expand Down Expand Up @@ -188,9 +189,9 @@ func (e *ExternalSessions) Validate(claims *JWT) error {
return nil
}
d := time.Unix(claims.IssuedAt, 0)
v := e.Invalidations.ValidFrom(claims.UserID)
v := e.Invalidations.GetInvalidatonTime(claims.UserID)
if d.Before(v) {
return errors.New("token has been invalidated")
return jwt.ErrTokenExpired
}
return nil
}
Expand Down Expand Up @@ -254,18 +255,26 @@ func (e *ExternalSessions) StartInvalidationPoll(cert *tls.Certificate) {

// InvalidationMap tracks times before which users should be considered invalid.
type InvalidationMap struct {
DefaultTime time.Time `json:"defaultTime"`
LastUpdated time.Time `json:"lastUpdated"`
Overrides map[string]time.Time `json:"overrides"`
DefaultTime time.Time `json:"defaultTime"`
LastUpdated time.Time `json:"lastUpdated"`
InvalidationTimes map[string]map[string]time.Time `json:"invalidationTimes"`
}

// ValidFrom returns the time from which tokens for the specified user are valid.
func (im *InvalidationMap) ValidFrom(id string) time.Time {
ts, ok := im.Overrides[id]
if ok {
return ts
// GetInvalidatonTime returns which the token invalidation time for the specified user.
func (im *InvalidationMap) GetInvalidatonTime(id string) time.Time {
times, ok := im.InvalidationTimes[id]
if !ok {
return im.DefaultTime
}
return im.DefaultTime

var latest time.Time
for _, t := range times {
if latest.IsZero() || latest.Before(t) {
latest = t
}
}

return latest
}

// UserWebSetting is a record of user web setting.
Expand Down
39 changes: 39 additions & 0 deletions master/pkg/model/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
Expand All @@ -19,6 +20,44 @@ func TestUserNonNilLastAuthAtProto(t *testing.T) {
require.WithinDuration(t, expectedTime, u.Proto().LastAuthAt.AsTime(), time.Millisecond)
}

func TestInvalidation(t *testing.T) {
tt := time.Date(2023, 2, 2, 15, 4, 5, 0, time.UTC)
extSess := ExternalSessions{
Invalidations: &InvalidationMap{
DefaultTime: tt.Add(3 * time.Hour),
LastUpdated: tt.Add(9 * time.Hour),
InvalidationTimes: map[string]map[string]time.Time{
"user-1": {
"logout": tt.Add(9 * time.Hour),
"new-perm": tt.Add(5 * time.Hour),
},
},
},
}

j := JWT{
StandardClaims: jwt.StandardClaims{ //nolint
IssuedAt: tt.Unix(),
ExpiresAt: tt.Add(20 * time.Hour).Unix(),
},
UserID: "user-1",
Email: "[email protected]",
Name: "Test User",
OrgRoles: map[OrgID]OrgRoleClaims{},
}
require.ErrorIs(t, extSess.Validate(&j), jwt.ErrTokenExpired)

j.StandardClaims.IssuedAt = tt.Add(10 * time.Hour).Unix()
require.NoError(t, extSess.Validate(&j))

j.StandardClaims.IssuedAt = tt.Add(6 * time.Hour).Unix()
require.ErrorIs(t, extSess.Validate(&j), jwt.ErrTokenExpired)

// No such user, using default time
j.UserID = "user-23"
require.NoError(t, extSess.Validate(&j))
}

func TestUserWebSetting_Proto(t *testing.T) {
in := UserWebSetting{
UserID: UserID(42),
Expand Down

0 comments on commit 6fefbb3

Please sign in to comment.