diff --git a/auth/auth.go b/auth/auth.go index ce4a5bc442..aa3d22ed4a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -32,11 +32,12 @@ type Authenticator struct { } type AuthenticatorOptions struct { - ClientPartitionWindow time.Duration - ChannelsWarningThreshold *uint32 - SessionCookieName string - BcryptCost int - LogCtx context.Context + ClientPartitionWindow time.Duration + ChannelsWarningThreshold *uint32 + ServerlessChannelThreshold uint32 + SessionCookieName string + BcryptCost int + LogCtx context.Context // Collections defines the set of collections used by the authenticator when rebuilding channels. // Channels are only recomputed for collections included in this set. @@ -196,6 +197,17 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal) } changed = true } + // If the channel threshold has been set we need to check the inherited channels across all scopes and collections against the limit + if auth.ServerlessChannelThreshold != 0 { + channelsLength, err := auth.getInheritedChannelsLength(user) + if err != nil { + return nil, nil, false, err + } + err = auth.checkChannelLimits(channelsLength, user) + if err != nil { + return nil, nil, false, err + } + } } if changed { @@ -223,6 +235,73 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal) return princ, nil } +// inheritedCollectionChannels returns channels for a given scope + collection +func (auth *Authenticator) inheritedCollectionChannels(user User, scope, collection string) (ch.TimedSet, error) { + roles, err := auth.getUserRoles(user) + if err != nil { + return nil, err + } + + channels := user.CollectionChannels(scope, collection) + for _, role := range roles { + roleSince := user.RoleNames()[role.Name()] + channels.AddAtSequence(role.CollectionChannels(scope, collection), roleSince.Sequence) + } + return channels, nil +} + +// getInheritedChannelsLength returns number of channels a user has access to across all collections +func (auth *Authenticator) getInheritedChannelsLength(user User) (int, error) { + var cumulativeChannels int + for scope, collections := range auth.Collections { + for collection := range collections { + channels, err := auth.inheritedCollectionChannels(user, scope, collection) + if err != nil { + return 0, err + } + cumulativeChannels += len(channels) + } + } + return cumulativeChannels, nil +} + +// checkChannelLimits logs a warning when the warning threshold is met and will return an error when the channel limit is met +func (auth *Authenticator) checkChannelLimits(channels int, user User) error { + // Error if ServerlessChannelThreshold is set and is >= than the threshold + if uint32(channels) >= auth.ServerlessChannelThreshold { + base.ErrorfCtx(auth.LogCtx, "User ID: %v channel count: %d exceeds %d for channels per user threshold. Auth will be rejected until rectified", + base.UD(user.Name()), channels, auth.ServerlessChannelThreshold) + return base.ErrMaximumChannelsForUserExceeded + } + + // This function is likely to be called once per session when a channel limit is applied, the sync once + // applied here ensures we don't fill logs with warnings about being over warning threshold. We may want + // to revisit this implementation around the warning threshold in future + user.GetWarnChanSync().Do(func() { + if channelsPerUserThreshold := auth.ChannelsWarningThreshold; channelsPerUserThreshold != nil { + if uint32(channels) >= *channelsPerUserThreshold { + base.WarnfCtx(auth.LogCtx, "User ID: %v channel count: %d exceeds %d for channels per user warning threshold", + base.UD(user.Name()), channels, *channelsPerUserThreshold) + } + } + }) + return nil +} + +// getUserRoles gets all roles a user has been granted +func (auth *Authenticator) getUserRoles(user User) ([]Role, error) { + roles := make([]Role, 0, len(user.RoleNames())) + for name := range user.RoleNames() { + role, err := auth.GetRole(name) + if err != nil { + return nil, err + } else if role != nil { + roles = append(roles, role) + } + } + return roles, nil +} + // Rebuild channels computes the full set of channels for all collections defined for the authenticator. // For each collection in Authenticator.collections: // - if there is no CollectionAccess on the principal for the collection, rebuilds channels for that collection @@ -230,6 +309,7 @@ func (auth *Authenticator) getPrincipal(docID string, factory func() Principal) func (auth *Authenticator) rebuildChannels(princ Principal) (changed bool, err error) { changed = false + for scope, collections := range auth.Collections { for collection, _ := range collections { // If collection channels are nil, they have been invalidated and must be rebuilt @@ -242,6 +322,7 @@ func (auth *Authenticator) rebuildChannels(princ Principal) (changed bool, err e } } } + return changed, nil } diff --git a/auth/auth_test.go b/auth/auth_test.go index 7d2ddc41e7..b6dd83d8aa 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -2759,6 +2759,117 @@ func TestObtainChannelsForDeletedRole(t *testing.T) { } } +func TestServerlessChannelLimitsRoles(t *testing.T) { + testCases := []struct { + Name string + Collection bool + }{ + { + Name: "Single role", + }, + { + Name: "Muliple roles", + }, + } + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + testBucket := base.GetTestBucket(t) + defer testBucket.Close() + dataStore := testBucket.GetSingleDataStore() + var role2 Role + + opts := DefaultAuthenticatorOptions() + opts.ServerlessChannelThreshold = 5 + opts.Collections = map[string]map[string]struct{}{ + "scope1": {"collection1": struct{}{}, "collection2": struct{}{}}, + } + auth := NewAuthenticator(dataStore, nil, opts) + user1, err := auth.NewUser("user1", "pass", ch.BaseSetOf(t, "ABC")) + require.NoError(t, err) + err = auth.Save(user1) + require.NoError(t, err) + _, err = auth.AuthenticateUser("user1", "pass") + require.NoError(t, err) + + role1, err := auth.NewRole("role1", nil) + require.NoError(t, err) + if testCase.Name == "Single role" { + user1.SetExplicitRoles(ch.TimedSet{"role1": ch.NewVbSimpleSequence(1)}, 1) + require.NoError(t, auth.Save(user1)) + _, err = auth.AuthenticateUser("user1", "pass") + require.NoError(t, err) + + role1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1) + require.NoError(t, auth.Save(role1)) + } else { + role2, err = auth.NewRole("role2", nil) + require.NoError(t, err) + user1.SetExplicitRoles(ch.TimedSet{"role1": ch.NewVbSimpleSequence(1), "role2": ch.NewVbSimpleSequence(1)}, 1) + require.NoError(t, auth.Save(user1)) + role1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1) + role2.SetCollectionExplicitChannels("scope1", "collection2", ch.AtSequence(ch.BaseSetOf(t, "MNO", "PQR"), 1), 1) + require.NoError(t, auth.Save(role1)) + require.NoError(t, auth.Save(role2)) + } + _, err = auth.AuthenticateUser("user1", "pass") + require.Error(t, err) + }) + } +} + +func TestServerlessChannelLimits(t *testing.T) { + + testCases := []struct { + Name string + Collection bool + }{ + { + Name: "Collection not enabled", + Collection: false, + }, + { + Name: "Collection is enabled", + Collection: true, + }, + } + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + testBucket := base.GetTestBucket(t) + defer testBucket.Close() + dataStore := testBucket.GetSingleDataStore() + + opts := DefaultAuthenticatorOptions() + opts.ServerlessChannelThreshold = 5 + if testCase.Collection { + opts.Collections = map[string]map[string]struct{}{ + "scope1": {"collection1": struct{}{}, "collection2": struct{}{}}, + } + } + auth := NewAuthenticator(dataStore, nil, opts) + user1, err := auth.NewUser("user1", "pass", ch.BaseSetOf(t, "ABC")) + require.NoError(t, err) + err = auth.Save(user1) + require.NoError(t, err) + _, err = auth.AuthenticateUser("user1", "pass") + require.NoError(t, err) + + if !testCase.Collection { + user1.SetCollectionExplicitChannels("_default", "_default", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL", "MNO", "PQR"), 1), 1) + err = auth.Save(user1) + require.NoError(t, err) + } else { + user1.SetCollectionExplicitChannels("scope1", "collection1", ch.AtSequence(ch.BaseSetOf(t, "ABC", "DEF", "GHI", "JKL"), 1), 1) + user1.SetCollectionExplicitChannels("scope1", "collection2", ch.AtSequence(ch.BaseSetOf(t, "MNO", "PQR"), 1), 1) + err = auth.Save(user1) + require.NoError(t, err) + } + _, err = auth.AuthenticateUser("user1", "pass") + require.Error(t, err) + assert.Contains(t, err.Error(), base.ErrMaximumChannelsForUserExceeded.Error()) + }) + } +} + func TestInvalidateRoles(t *testing.T) { ctx := base.TestCtx(t) testBucket := base.GetTestBucket(t) diff --git a/auth/principal.go b/auth/principal.go index 0f9d83ef7d..39c439c749 100644 --- a/auth/principal.go +++ b/auth/principal.go @@ -9,6 +9,7 @@ package auth import ( + "sync" "time" "github.com/couchbase/sync_gateway/base" @@ -125,6 +126,8 @@ type User interface { InitializeRoles() + GetWarnChanSync() *sync.Once + revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels // Obtains the period over which the user had access to the given channel. Either directly or via a role. diff --git a/auth/user.go b/auth/user.go index 34582e76d3..54a2ca9090 100644 --- a/auth/user.go +++ b/auth/user.go @@ -183,6 +183,10 @@ func (user *userImpl) SetEmail(email string) error { return nil } +func (user *userImpl) GetWarnChanSync() *sync.Once { + return &user.warnChanThresholdOnce +} + func (user *userImpl) RoleNames() ch.TimedSet { if user.RoleInvalSeq != 0 { return nil diff --git a/db/database.go b/db/database.go index 8508e4bbf3..0934654843 100644 --- a/db/database.go +++ b/db/database.go @@ -835,16 +835,21 @@ func (context *DatabaseContext) Authenticator(ctx context.Context) *auth.Authent if context.Options.UnsupportedOptions != nil && context.Options.UnsupportedOptions.WarningThresholds != nil { channelsWarningThreshold = context.Options.UnsupportedOptions.WarningThresholds.ChannelsPerUser } + var channelServerlessThreshold uint32 + if context.IsServerless() { + channelServerlessThreshold = base.ServerlessChannelLimit + } // Authenticators are lightweight & stateless, so it's OK to return a new one every time authenticator := auth.NewAuthenticator(context.MetadataStore, context, auth.AuthenticatorOptions{ - ClientPartitionWindow: context.Options.ClientPartitionWindow, - ChannelsWarningThreshold: channelsWarningThreshold, - SessionCookieName: sessionCookieName, - BcryptCost: context.Options.BcryptCost, - LogCtx: ctx, - Collections: context.CollectionNames, - MetaKeys: context.MetadataKeys, + ClientPartitionWindow: context.Options.ClientPartitionWindow, + ChannelsWarningThreshold: channelsWarningThreshold, + ServerlessChannelThreshold: channelServerlessThreshold, + SessionCookieName: sessionCookieName, + BcryptCost: context.Options.BcryptCost, + LogCtx: ctx, + Collections: context.CollectionNames, + MetaKeys: context.MetadataKeys, }) return authenticator