Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CBG-2894: Reject user auth when channel threshold is over 500 #6214

Merged
merged 14 commits into from
May 9, 2023
91 changes: 86 additions & 5 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ type Authenticator struct {
}

type AuthenticatorOptions struct {
ClientPartitionWindow time.Duration
ChannelsWarningThreshold *uint32
SessionCookieName string
BcryptCost int
LogCtx context.Context
ClientPartitionWindow time.Duration
ChannelsWarningThreshold *uint32
ServerlessChannelThreshold uint32
SessionCookieName string
BcryptCost int
LogCtx context.Context

// Collections defines the set of collections used by the authenticator when rebuilding channels.
// Channels are only recomputed for collections included in this set.
Expand Down Expand Up @@ -196,6 +197,17 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal)
}
changed = true
}
// If the channel threshold has been set we need to check the inherited channels across all scopes and collections against the limit
if auth.ServerlessChannelThreshold != 0 {
channelsLength, err := auth.getInheritedChannelsLength(user)
if err != nil {
return nil, nil, false, err
}
err = auth.checkChannelLimits(channelsLength, user)
if err != nil {
return nil, nil, false, err
}
}
}

if changed {
Expand Down Expand Up @@ -223,13 +235,81 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal)
return princ, nil
}

// inheritedCollectionChannels returns channels for a given scope + collection
func (auth *Authenticator) inheritedCollectionChannels(user User, scope, collection string) (ch.TimedSet, error) {
gregns1 marked this conversation as resolved.
Show resolved Hide resolved
roles, err := auth.getUserRoles(user)
if err != nil {
return nil, err
}

channels := user.CollectionChannels(scope, collection)
for _, role := range roles {
roleSince := user.RoleNames()[role.Name()]
channels.AddAtSequence(role.CollectionChannels(scope, collection), roleSince.Sequence)
}
return channels, nil
}

// getInheritedChannelsLength returns number of channels a user has access to across all collections
func (auth *Authenticator) getInheritedChannelsLength(user User) (int, error) {
gregns1 marked this conversation as resolved.
Show resolved Hide resolved
var cumulativeChannels int
for scope, collections := range auth.Collections {
for collection := range collections {
channels, err := auth.inheritedCollectionChannels(user, scope, collection)
if err != nil {
return 0, err
}
cumulativeChannels += len(channels)
}
}
return cumulativeChannels, nil
}

// checkChannelLimits logs a warning when the warning threshold is met and will return an error when the channel limit is met
func (auth *Authenticator) checkChannelLimits(channels int, user User) error {
gregns1 marked this conversation as resolved.
Show resolved Hide resolved
// Error if ServerlessChannelThreshold is set and is >= than the threshold
if uint32(channels) >= auth.ServerlessChannelThreshold {
base.ErrorfCtx(auth.LogCtx, "User ID: %v channel count: %d exceeds %d for channels per user threshold. Auth will be rejected until rectified",
base.UD(user.Name()), channels, auth.ServerlessChannelThreshold)
return base.ErrMaximumChannelsForUserExceeded
}

// This function is likely to be called once per session when a channel limit is applied, the sync once
// applied here ensures we don't fill logs with warnings about being over warning threshold. We may want
// to revisit this implementation around the warning threshold in future
user.GetWarnChanSync().Do(func() {
if channelsPerUserThreshold := auth.ChannelsWarningThreshold; channelsPerUserThreshold != nil {
if uint32(channels) >= *channelsPerUserThreshold {
base.WarnfCtx(auth.LogCtx, "User ID: %v channel count: %d exceeds %d for channels per user warning threshold",
base.UD(user.Name()), channels, *channelsPerUserThreshold)
}
}
})
return nil
}

// getUserRoles gets all roles a user has been granted
func (auth *Authenticator) getUserRoles(user User) ([]Role, error) {
roles := make([]Role, 0, len(user.RoleNames()))
for name := range user.RoleNames() {
role, err := auth.GetRole(name)
if err != nil {
return nil, err
} else if role != nil {
roles = append(roles, role)
}
}
return roles, nil
}

// Rebuild channels computes the full set of channels for all collections defined for the authenticator.
// For each collection in Authenticator.collections:
// - if there is no CollectionAccess on the principal for the collection, rebuilds channels for that collection
// - If CollectionAccess on the principal has been invalidated, rebuilds channels for that collection
func (auth *Authenticator) rebuildChannels(princ Principal) (changed bool, err error) {

changed = false

for scope, collections := range auth.Collections {
for collection, _ := range collections {
// If collection channels are nil, they have been invalidated and must be rebuilt
Expand All @@ -242,6 +322,7 @@ func (auth *Authenticator) rebuildChannels(princ Principal) (changed bool, err e
}
}
}

return changed, nil
}

Expand Down
111 changes: 111 additions & 0 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2752,6 +2752,117 @@ func TestObtainChannelsForDeletedRole(t *testing.T) {
}
}

func TestServerlessChannelLimitsRoles(t *testing.T) {
testCases := []struct {
Name string
Collection bool
}{
{
Name: "Single role",
},
{
Name: "Muliple roles",
},
}
for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
testBucket := base.GetTestBucket(t)
defer testBucket.Close()
dataStore := testBucket.GetSingleDataStore()
var role2 Role

opts := DefaultAuthenticatorOptions()
opts.ServerlessChannelThreshold = 5
opts.Collections = map[string]map[string]struct{}{
"scope1": {"collection1": struct{}{}, "collection2": struct{}{}},
}
auth := NewAuthenticator(dataStore, nil, opts)
user1, err := auth.NewUser("user1", "pass", ch.BaseSetOf(t, "ABC"))
require.NoError(t, err)
err = auth.Save(user1)
require.NoError(t, err)
_, err = auth.AuthenticateUser("user1", "pass")
require.NoError(t, err)

role1, err := auth.NewRole("role1", nil)
require.NoError(t, err)
if testCase.Name == "Single role" {
user1.SetExplicitRoles(ch.TimedSet{"role1": ch.NewVbSimpleSequence(1)}, 1)
require.NoError(t, auth.Save(user1))
_, err = auth.AuthenticateUser("user1", "pass")
require.NoError(t, err)

role1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1)
require.NoError(t, auth.Save(role1))
} else {
role2, err = auth.NewRole("role2", nil)
require.NoError(t, err)
user1.SetExplicitRoles(ch.TimedSet{"role1": ch.NewVbSimpleSequence(1), "role2": ch.NewVbSimpleSequence(1)}, 1)
require.NoError(t, auth.Save(user1))
role1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1)
role2.SetCollectionExplicitChannels("scope1", "collection2", ch.AtSequence(ch.BaseSetOf(t, "MNO", "PQR"), 1), 1)
require.NoError(t, auth.Save(role1))
require.NoError(t, auth.Save(role2))
}
_, err = auth.AuthenticateUser("user1", "pass")
require.Error(t, err)
})
}
}

func TestServerlessChannelLimits(t *testing.T) {

testCases := []struct {
Name string
Collection bool
}{
{
Name: "Collection not enabled",
Collection: false,
},
{
Name: "Collection is enabled",
Collection: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
testBucket := base.GetTestBucket(t)
defer testBucket.Close()
dataStore := testBucket.GetSingleDataStore()

opts := DefaultAuthenticatorOptions()
opts.ServerlessChannelThreshold = 5
if testCase.Collection {
opts.Collections = map[string]map[string]struct{}{
"scope1": {"collection1": struct{}{}, "collection2": struct{}{}},
}
}
auth := NewAuthenticator(dataStore, nil, opts)
user1, err := auth.NewUser("user1", "pass", ch.BaseSetOf(t, "ABC"))
require.NoError(t, err)
err = auth.Save(user1)
require.NoError(t, err)
_, err = auth.AuthenticateUser("user1", "pass")
require.NoError(t, err)

if !testCase.Collection {
user1.SetCollectionExplicitChannels("_default", "_default", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL", "MNO", "PQR"), 1), 1)
err = auth.Save(user1)
require.NoError(t, err)
} else {
user1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1)
user1.SetCollectionExplicitChannels("scope1", "collection2", ch.AtSequence(ch.BaseSetOf(t, "MNO", "PQR"), 1), 1)
err = auth.Save(user1)
require.NoError(t, err)
}
_, err = auth.AuthenticateUser("user1", "pass")
require.Error(t, err)
assert.Contains(t, err.Error(), base.ErrMaximumChannelsForUserExceeded.Error())
})
}
}

func TestInvalidateRoles(t *testing.T) {
testBucket := base.GetTestBucket(t)
defer testBucket.Close()
Expand Down
3 changes: 3 additions & 0 deletions auth/principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package auth

import (
"sync"
"time"

"github.com/couchbase/sync_gateway/base"
Expand Down Expand Up @@ -125,6 +126,8 @@ type User interface {

InitializeRoles()

GetWarnChanSync() *sync.Once

revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels

// Obtains the period over which the user had access to the given channel. Either directly or via a role.
Expand Down
4 changes: 4 additions & 0 deletions auth/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ func (user *userImpl) SetEmail(email string) error {
return nil
}

func (user *userImpl) GetWarnChanSync() *sync.Once {
return &user.warnChanThresholdOnce
}

func (user *userImpl) RoleNames() ch.TimedSet {
if user.RoleInvalSeq != 0 {
return nil
Expand Down
3 changes: 3 additions & 0 deletions base/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ const (
// DefaultJavascriptTimeoutSecs is number of seconds before Javascript functions (i.e. the sync function or import filter) timeout
// If set to zero, timeout is disabled.
DefaultJavascriptTimeoutSecs = uint32(0)

// ServerlessChannelLimit is hard limit on channels allowed per user when running in serverless mode
ServerlessChannelLimit = 500
)

const (
Expand Down
5 changes: 5 additions & 0 deletions base/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ var (

// ErrConfigRegistryReloadRequired is returned when a db config fetch requires a registry reload based on version mismatch (config is newer)
ErrConfigRegistryReloadRequired = &sgError{"Config registry reload required"}

// ErrMaximumChannelsForUserExceeded is returned when running in serverless mode and the user has more than 500 channels granted to them
ErrMaximumChannelsForUserExceeded = &sgError{fmt.Sprintf("User has exceeded maximum of %d channels", ServerlessChannelLimit)}
)

func (e *sgError) Error() string {
Expand Down Expand Up @@ -115,6 +118,8 @@ func ErrorAsHTTPStatus(err error) (int, string) {
return http.StatusRequestEntityTooLarge, "Document too large!"
case ErrViewTimeoutError:
return http.StatusServiceUnavailable, unwrappedErr.Error()
case ErrMaximumChannelsForUserExceeded:
return http.StatusInternalServerError, "Maximum number of channels exceeded for this user"
}

// gocb V2 errors
Expand Down
19 changes: 12 additions & 7 deletions db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1063,16 +1063,21 @@ func (context *DatabaseContext) Authenticator(ctx context.Context) *auth.Authent
if context.Options.UnsupportedOptions != nil && context.Options.UnsupportedOptions.WarningThresholds != nil {
channelsWarningThreshold = context.Options.UnsupportedOptions.WarningThresholds.ChannelsPerUser
}
var channelServerlessThreshold uint32
if context.IsServerless() {
channelServerlessThreshold = base.ServerlessChannelLimit
}

// Authenticators are lightweight & stateless, so it's OK to return a new one every time
authenticator := auth.NewAuthenticator(context.MetadataStore, context, auth.AuthenticatorOptions{
ClientPartitionWindow: context.Options.ClientPartitionWindow,
ChannelsWarningThreshold: channelsWarningThreshold,
SessionCookieName: sessionCookieName,
BcryptCost: context.Options.BcryptCost,
LogCtx: ctx,
Collections: context.CollectionNames,
MetaKeys: context.MetadataKeys,
ClientPartitionWindow: context.Options.ClientPartitionWindow,
ChannelsWarningThreshold: channelsWarningThreshold,
ServerlessChannelThreshold: channelServerlessThreshold,
SessionCookieName: sessionCookieName,
BcryptCost: context.Options.BcryptCost,
LogCtx: ctx,
Collections: context.CollectionNames,
MetaKeys: context.MetadataKeys,
})

return authenticator
Expand Down