Skip to content

Commit

Permalink
Minor function refactors for default privileges
Browse files Browse the repository at this point in the history
Release note: None
  • Loading branch information
RichardJCai committed Jul 20, 2021
1 parent 78eabf5 commit 175b1e6
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pkg/ccl/backupccl/restore_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -2408,7 +2408,7 @@ func getRestoringPrivileges(

// TODO(dt): Make this more configurable.
updatedPrivileges = descpb.CreatePrivilegesFromDefaultPrivileges(
parentDB.GetID(), parentDB.DatabaseDesc().GetDefaultPrivileges(), user, tree.Tables,
parentDB.GetID(), parentDB.GetDefaultPrivileges(), user, tree.Tables,
parentDB.GetPrivileges(),
)
}
Expand Down
15 changes: 5 additions & 10 deletions pkg/sql/alter_default_privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,7 @@ func (n *alterDefaultPrivilegesNode) startExec(params runParams) error {
}
}

users, err := params.p.GetAllRoles(params.ctx)
if err != nil {
return err
}

if err = validateGrantees(users, targetRoles); err != nil {
if err := params.p.validateRoles(params.ctx, targetRoles, false /* isPublicValid */); err != nil {
return err
}

Expand All @@ -105,7 +100,7 @@ func (n *alterDefaultPrivilegesNode) startExec(params runParams) error {
granteesSQLUsername[i] = user
}

if err = validateGrantees(users, granteesSQLUsername); err != nil {
if err := params.p.validateRoles(params.ctx, granteesSQLUsername, true /* isPublicValid */); err != nil {
return err
}

Expand All @@ -132,11 +127,11 @@ func (n *alterDefaultPrivilegesNode) startExec(params runParams) error {
return err
}

if n.dbDesc.DatabaseDesc().GetDefaultPrivileges() == nil {
n.dbDesc.DatabaseDesc().DefaultPrivileges = descpb.InitDefaultPrivilegeDescriptor()
if n.dbDesc.GetDefaultPrivileges() == nil {
n.dbDesc.SetInitialDefaultPrivilegeDescriptor(descpb.InitDefaultPrivilegeDescriptor())
}

defaultPrivs := n.dbDesc.DatabaseDesc().GetDefaultPrivileges()
defaultPrivs := n.dbDesc.GetDefaultPrivileges()

for _, targetRole := range targetRoles {
if n.n.IsGrant {
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql/catalog/dbdesc/database_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,11 @@ func (desc *Mutable) SetRegionConfig(cfg *descpb.DatabaseDescriptor_RegionConfig
func (desc *Mutable) HasPostDeserializationChanges() bool {
return desc.changed
}

// SetInitialDefaultPrivilegeDescriptor sets the initial default privilege descriptor
// for the database.
func (desc *Mutable) SetInitialDefaultPrivilegeDescriptor(
defaultPrivilegeDescriptor *descpb.DefaultPrivilegeDescriptor,
) {
desc.DefaultPrivileges = defaultPrivilegeDescriptor
}
4 changes: 2 additions & 2 deletions pkg/sql/catalog/descpb/default_privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ func (p *DefaultPrivilegeDescriptor) Validate() error {
privilegeObjectType := targetObjectToPrivilegeObject[objectType]
valid, u, remaining := defaultPrivileges.IsValidPrivilegesForObjectType(privilegeObjectType)
if !valid {
return errors.AssertionFailedf("user %s must not have sv privileges on %s",
u.User(), privilege.ListFromBitField(remaining, privilege.Any), objectType)
return errors.AssertionFailedf("user %s must not have %s privileges on %s",
u.User(), privilege.ListFromBitField(remaining, privilege.Any), privilegeObjectType)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/catalog/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ type DatabaseDescriptor interface {
ForEachSchemaInfo(func(id descpb.ID, name string, isDropped bool) error) error
GetSchemaID(name string) descpb.ID
GetNonDroppedSchemaName(schemaID descpb.ID) string
GetDefaultPrivileges() *descpb.DefaultPrivilegeDescriptor
}

// TableDescriptor is an interface around the table descriptor types.
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/create_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func CreateUserDefinedSchemaDescriptor(
}

privs := descpb.CreatePrivilegesFromDefaultPrivileges(
db.GetID(), db.DatabaseDesc().GetDefaultPrivileges(), user, tree.Schemas, db.GetPrivileges(),
db.GetID(), db.GetDefaultPrivileges(), user, tree.Schemas, db.GetPrivileges(),
)

if !n.AuthRole.Undefined() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/create_sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func doCreateSequence(
}

privs := descpb.CreatePrivilegesFromDefaultPrivileges(
dbDesc.GetID(), dbDesc.DatabaseDesc().GetDefaultPrivileges(),
dbDesc.GetID(), dbDesc.GetDefaultPrivileges(),
params.SessionData().User(), tree.Sequences, dbDesc.GetPrivileges(),
)

Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ func (n *createTableNode) startExec(params runParams) error {
}

privs := descpb.CreatePrivilegesFromDefaultPrivileges(
n.dbDesc.GetID(), n.dbDesc.DatabaseDesc().GetDefaultPrivileges(),
n.dbDesc.GetID(), n.dbDesc.GetDefaultPrivileges(),
params.SessionData().User(), tree.Tables, n.dbDesc.GetPrivileges(),
)
var desc *tabledesc.Mutable
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/create_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func CreateEnumTypeDesc(
}

privs := descpb.CreatePrivilegesFromDefaultPrivileges(
dbDesc.GetID(), dbDesc.DatabaseDesc().GetDefaultPrivileges(),
dbDesc.GetID(), dbDesc.GetDefaultPrivileges(),
params.p.User(), tree.Types, dbDesc.GetPrivileges(),
)

Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/create_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (n *createViewNode) startExec(params runParams) error {
}

privs := descpb.CreatePrivilegesFromDefaultPrivileges(
n.dbDesc.GetID(), n.dbDesc.DatabaseDesc().GetDefaultPrivileges(),
n.dbDesc.GetID(), n.dbDesc.GetDefaultPrivileges(),
params.SessionData().User(), tree.Tables, n.dbDesc.GetPrivileges(),
)

Expand Down
31 changes: 16 additions & 15 deletions pkg/sql/grant_revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,13 @@ func (n *changePrivilegesNode) ReadingOwnWrites() {}
func (n *changePrivilegesNode) startExec(params runParams) error {
ctx := params.ctx
p := params.p
// Check whether grantees exists
users, err := p.GetAllRoles(ctx)
if err != nil {
return err
}

// We're allowed to grant/revoke privileges to/from the "public" role even though
// it does not exist: add it to the list of all users and roles.
users[security.PublicRoleName()] = true // isRole

if err = validateGrantees(users, n.grantees); err != nil {
if err := p.validateRoles(ctx, n.grantees, true /* isPublicValid */); err != nil {
return err
}

var descriptors []catalog.Descriptor
var err error
// DDL statements avoid the cache to avoid leases, and can view non-public descriptors.
// TODO(vivek): check if the cache can be used.
p.runWithOptions(resolveFlags{skipCache: true}, func() {
Expand Down Expand Up @@ -324,12 +316,21 @@ func getGrantOnObject(targets tree.TargetList, incIAMFunc func(on string)) privi
}
}

// validateGrantees checks that all the grantees are valid users.
// users should be returned the result of GetAllRoles.
func validateGrantees(users map[security.SQLUsername]bool, grantees []security.SQLUsername) error {
for i, grantee := range grantees {
// validateRoles checks that all the roles are valid users.
// isPublicValid determines whether or not Public is a valid role.
func (p *planner) validateRoles(
ctx context.Context, roles []security.SQLUsername, isPublicValid bool,
) error {
users, err := p.GetAllRoles(ctx)
if err != nil {
return err
}
if isPublicValid {
users[security.PublicRoleName()] = true // isRole
}
for i, grantee := range roles {
if _, ok := users[grantee]; !ok {
sqlName := tree.Name(grantees[i].Normalized())
sqlName := tree.Name(roles[i].Normalized())
return errors.Errorf("user or role %s does not exist", &sqlName)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,11 @@ database_name schema_name table_name grantee privilege_type
d public t12 admin ALL
d public t12 root ALL
d public t12 testuser CREATE

# Cannot specify PUBLIC as the target role.
statement error pq: user or role public does not exist
ALTER DEFAULT PRIVILEGES FOR ROLE public REVOKE SELECT ON TABLES FROM testuser2, testuser3

# Can specify PUBLIC as a grantee.
statement ok
ALTER DEFAULT PRIVILEGES REVOKE SELECT ON TABLES FROM public

0 comments on commit 175b1e6

Please sign in to comment.