Skip to content

Commit

Permalink
CBG-2894: Reject user auth when channel threshold is over 500 (#6214)
Browse files Browse the repository at this point in the history
* CBG-2894: Reject user auth when channel threshold is over 500 in serverless mode

* fix panic where authetciator was needed and it wasn't availible

* linter issue

* linter issue again

* remove extra methods off interface

* pass user into function

* rebase

* ensure 500 code is retruned for http error added

* updates based off comments

* fix panic

* updates based off comments

* updates based off dicussion yesterday

* lint error

* updates based of comments
  • Loading branch information
gregns1 authored and bbrks committed Mar 28, 2024
1 parent 1ef14c4 commit a16d696
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 12 deletions.
91 changes: 86 additions & 5 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ type Authenticator struct {
}

type AuthenticatorOptions struct {
ClientPartitionWindow time.Duration
ChannelsWarningThreshold *uint32
SessionCookieName string
BcryptCost int
LogCtx context.Context
ClientPartitionWindow time.Duration
ChannelsWarningThreshold *uint32
ServerlessChannelThreshold uint32
SessionCookieName string
BcryptCost int
LogCtx context.Context

// Collections defines the set of collections used by the authenticator when rebuilding channels.
// Channels are only recomputed for collections included in this set.
Expand Down Expand Up @@ -196,6 +197,17 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal)
}
changed = true
}
// If the channel threshold has been set we need to check the inherited channels across all scopes and collections against the limit
if auth.ServerlessChannelThreshold != 0 {
channelsLength, err := auth.getInheritedChannelsLength(user)
if err != nil {
return nil, nil, false, err
}
err = auth.checkChannelLimits(channelsLength, user)
if err != nil {
return nil, nil, false, err
}
}
}

if changed {
Expand Down Expand Up @@ -223,13 +235,81 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal)
return princ, nil
}

// inheritedCollectionChannels returns channels for a given scope + collection
func (auth *Authenticator) inheritedCollectionChannels(user User, scope, collection string) (ch.TimedSet, error) {
roles, err := auth.getUserRoles(user)
if err != nil {
return nil, err
}

channels := user.CollectionChannels(scope, collection)
for _, role := range roles {
roleSince := user.RoleNames()[role.Name()]
channels.AddAtSequence(role.CollectionChannels(scope, collection), roleSince.Sequence)
}
return channels, nil
}

// getInheritedChannelsLength returns number of channels a user has access to across all collections
func (auth *Authenticator) getInheritedChannelsLength(user User) (int, error) {
var cumulativeChannels int
for scope, collections := range auth.Collections {
for collection := range collections {
channels, err := auth.inheritedCollectionChannels(user, scope, collection)
if err != nil {
return 0, err
}
cumulativeChannels += len(channels)
}
}
return cumulativeChannels, nil
}

// checkChannelLimits logs a warning when the warning threshold is met and will return an error when the channel limit is met
func (auth *Authenticator) checkChannelLimits(channels int, user User) error {
// Error if ServerlessChannelThreshold is set and is >= than the threshold
if uint32(channels) >= auth.ServerlessChannelThreshold {
base.ErrorfCtx(auth.LogCtx, "User ID: %v channel count: %d exceeds %d for channels per user threshold. Auth will be rejected until rectified",
base.UD(user.Name()), channels, auth.ServerlessChannelThreshold)
return base.ErrMaximumChannelsForUserExceeded
}

// This function is likely to be called once per session when a channel limit is applied, the sync once
// applied here ensures we don't fill logs with warnings about being over warning threshold. We may want
// to revisit this implementation around the warning threshold in future
user.GetWarnChanSync().Do(func() {
if channelsPerUserThreshold := auth.ChannelsWarningThreshold; channelsPerUserThreshold != nil {
if uint32(channels) >= *channelsPerUserThreshold {
base.WarnfCtx(auth.LogCtx, "User ID: %v channel count: %d exceeds %d for channels per user warning threshold",
base.UD(user.Name()), channels, *channelsPerUserThreshold)
}
}
})
return nil
}

// getUserRoles gets all roles a user has been granted
func (auth *Authenticator) getUserRoles(user User) ([]Role, error) {
roles := make([]Role, 0, len(user.RoleNames()))
for name := range user.RoleNames() {
role, err := auth.GetRole(name)
if err != nil {
return nil, err
} else if role != nil {
roles = append(roles, role)
}
}
return roles, nil
}

// Rebuild channels computes the full set of channels for all collections defined for the authenticator.
// For each collection in Authenticator.collections:
// - if there is no CollectionAccess on the principal for the collection, rebuilds channels for that collection
// - If CollectionAccess on the principal has been invalidated, rebuilds channels for that collection
func (auth *Authenticator) rebuildChannels(princ Principal) (changed bool, err error) {

changed = false

for scope, collections := range auth.Collections {
for collection, _ := range collections {
// If collection channels are nil, they have been invalidated and must be rebuilt
Expand All @@ -242,6 +322,7 @@ func (auth *Authenticator) rebuildChannels(princ Principal) (changed bool, err e
}
}
}

return changed, nil
}

Expand Down
113 changes: 113 additions & 0 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2759,6 +2759,119 @@ func TestObtainChannelsForDeletedRole(t *testing.T) {
}
}

func TestServerlessChannelLimitsRoles(t *testing.T) {
testCases := []struct {
Name string
Collection bool
}{
{
Name: "Single role",
},
{
Name: "Muliple roles",
},
}
for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
ctx := base.TestCtx(t)
testBucket := base.GetTestBucket(t)
defer testBucket.Close(ctx)
dataStore := testBucket.GetSingleDataStore()
var role2 Role

opts := DefaultAuthenticatorOptions(ctx)
opts.ServerlessChannelThreshold = 5
opts.Collections = map[string]map[string]struct{}{
"scope1": {"collection1": struct{}{}, "collection2": struct{}{}},
}
auth := NewAuthenticator(dataStore, nil, opts)
user1, err := auth.NewUser("user1", "pass", ch.BaseSetOf(t, "ABC"))
require.NoError(t, err)
err = auth.Save(user1)
require.NoError(t, err)
_, err = auth.AuthenticateUser("user1", "pass")
require.NoError(t, err)

role1, err := auth.NewRole("role1", nil)
require.NoError(t, err)
if testCase.Name == "Single role" {
user1.SetExplicitRoles(ch.TimedSet{"role1": ch.NewVbSimpleSequence(1)}, 1)
require.NoError(t, auth.Save(user1))
_, err = auth.AuthenticateUser("user1", "pass")
require.NoError(t, err)

role1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1)
require.NoError(t, auth.Save(role1))
} else {
role2, err = auth.NewRole("role2", nil)
require.NoError(t, err)
user1.SetExplicitRoles(ch.TimedSet{"role1": ch.NewVbSimpleSequence(1), "role2": ch.NewVbSimpleSequence(1)}, 1)
require.NoError(t, auth.Save(user1))
role1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1)
role2.SetCollectionExplicitChannels("scope1", "collection2", ch.AtSequence(ch.BaseSetOf(t, "MNO", "PQR"), 1), 1)
require.NoError(t, auth.Save(role1))
require.NoError(t, auth.Save(role2))
}
_, err = auth.AuthenticateUser("user1", "pass")
require.Error(t, err)
})
}
}

func TestServerlessChannelLimits(t *testing.T) {

testCases := []struct {
Name string
Collection bool
}{
{
Name: "Collection not enabled",
Collection: false,
},
{
Name: "Collection is enabled",
Collection: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
ctx := base.TestCtx(t)
testBucket := base.GetTestBucket(t)
defer testBucket.Close(ctx)
dataStore := testBucket.GetSingleDataStore()

opts := DefaultAuthenticatorOptions(ctx)
opts.ServerlessChannelThreshold = 5
if testCase.Collection {
opts.Collections = map[string]map[string]struct{}{
"scope1": {"collection1": struct{}{}, "collection2": struct{}{}},
}
}
auth := NewAuthenticator(dataStore, nil, opts)
user1, err := auth.NewUser("user1", "pass", ch.BaseSetOf(t, "ABC"))
require.NoError(t, err)
err = auth.Save(user1)
require.NoError(t, err)
_, err = auth.AuthenticateUser("user1", "pass")
require.NoError(t, err)

if !testCase.Collection {
user1.SetCollectionExplicitChannels("_default", "_default", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL", "MNO", "PQR"), 1), 1)
err = auth.Save(user1)
require.NoError(t, err)
} else {
user1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1)
user1.SetCollectionExplicitChannels("scope1", "collection2", ch.AtSequence(ch.BaseSetOf(t, "MNO", "PQR"), 1), 1)
err = auth.Save(user1)
require.NoError(t, err)
}
_, err = auth.AuthenticateUser("user1", "pass")
require.Error(t, err)
assert.Contains(t, err.Error(), base.ErrMaximumChannelsForUserExceeded.Error())
})
}
}

func TestInvalidateRoles(t *testing.T) {
ctx := base.TestCtx(t)
testBucket := base.GetTestBucket(t)
Expand Down
3 changes: 3 additions & 0 deletions auth/principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package auth

import (
"sync"
"time"

"github.com/couchbase/sync_gateway/base"
Expand Down Expand Up @@ -125,6 +126,8 @@ type User interface {

InitializeRoles()

GetWarnChanSync() *sync.Once

revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels

// Obtains the period over which the user had access to the given channel. Either directly or via a role.
Expand Down
4 changes: 4 additions & 0 deletions auth/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ func (user *userImpl) SetEmail(email string) error {
return nil
}

func (user *userImpl) GetWarnChanSync() *sync.Once {
return &user.warnChanThresholdOnce
}

func (user *userImpl) RoleNames() ch.TimedSet {
if user.RoleInvalSeq != 0 {
return nil
Expand Down
3 changes: 3 additions & 0 deletions base/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ const (
// If set to zero, timeout is disabled.
DefaultJavascriptTimeoutSecs = uint32(0)

// ServerlessChannelLimit is hard limit on channels allowed per user when running in serverless mode
ServerlessChannelLimit = 500

// FromConnStrWarningThreshold determines the amount of time it should take before we warn about parsing a connstr (mostly for DNS resolution)
FromConnStrWarningThreshold = 10 * time.Second
)
Expand Down
3 changes: 3 additions & 0 deletions base/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ var (
// ErrConfigRegistryReloadRequired is returned when a db config fetch requires a registry reload based on version mismatch (config is newer)
ErrConfigRegistryReloadRequired = &sgError{"Config registry reload required"}

// ErrMaximumChannelsForUserExceeded is returned when running in serverless mode and the user has more than 500 channels granted to them
ErrMaximumChannelsForUserExceeded = &sgError{fmt.Sprintf("User has exceeded maximum of %d channels", ServerlessChannelLimit)}

// ErrReplicationLimitExceeded is returned when then replication connection threshold is exceeded
ErrReplicationLimitExceeded = &sgError{"Replication limit exceeded. Try again later."}
)
Expand Down
19 changes: 12 additions & 7 deletions db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,16 +842,21 @@ func (context *DatabaseContext) Authenticator(ctx context.Context) *auth.Authent
if context.Options.UnsupportedOptions != nil && context.Options.UnsupportedOptions.WarningThresholds != nil {
channelsWarningThreshold = context.Options.UnsupportedOptions.WarningThresholds.ChannelsPerUser
}
var channelServerlessThreshold uint32
if context.IsServerless() {
channelServerlessThreshold = base.ServerlessChannelLimit
}

// Authenticators are lightweight & stateless, so it's OK to return a new one every time
authenticator := auth.NewAuthenticator(context.MetadataStore, context, auth.AuthenticatorOptions{
ClientPartitionWindow: context.Options.ClientPartitionWindow,
ChannelsWarningThreshold: channelsWarningThreshold,
SessionCookieName: sessionCookieName,
BcryptCost: context.Options.BcryptCost,
LogCtx: ctx,
Collections: context.CollectionNames,
MetaKeys: context.MetadataKeys,
ClientPartitionWindow: context.Options.ClientPartitionWindow,
ChannelsWarningThreshold: channelsWarningThreshold,
ServerlessChannelThreshold: channelServerlessThreshold,
SessionCookieName: sessionCookieName,
BcryptCost: context.Options.BcryptCost,
LogCtx: ctx,
Collections: context.CollectionNames,
MetaKeys: context.MetadataKeys,
})

return authenticator
Expand Down

0 comments on commit a16d696

Please sign in to comment.