diff --git a/master/internal/db/postgres_scim.go b/master/internal/db/postgres_scim.go deleted file mode 100644 index 5fcc3cb99ea9..000000000000 --- a/master/internal/db/postgres_scim.go +++ /dev/null @@ -1,414 +0,0 @@ -package db - -import ( - "context" - "database/sql" - "fmt" - "sort" - - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" - - "github.com/determined-ai/determined/master/pkg/model" -) - -// scimUserRow is a row in the SCIM table. The SCIM table contains the -// additional information needed to implement the SCIM protocol. This differs -// from model.SCIMUser because the latter is the result of joining a scimUserRow -// with a model.User. -type scimUserRow struct { - ID model.UUID `db:"id"` - UserID model.UserID `db:"user_id"` - ExternalID string `db:"external_id"` - Name model.SCIMName `db:"name"` - Emails model.SCIMEmails `db:"emails"` - RawAttributes map[string]interface{} `db:"raw_attributes"` -} - -// RetrofitSCIMUser "upgrades" an existing user to one tracked in the SCIM table. This is a -// temporary measure for SaaS clusters to migrate existing users to SCIM users. -func (db *PgDB) RetrofitSCIMUser(suser *model.SCIMUser, userID model.UserID) (*model.SCIMUser, - error, -) { - row := &scimUserRow{ - ExternalID: suser.ExternalID, - Emails: suser.Emails, - Name: suser.Name, - RawAttributes: suser.RawAttributes, - } - - tx, err := db.sql.Beginx() - if err != nil { - return nil, errors.WithStack(err) - } - - defer func() { - if tx == nil { - return - } - - if rErr := tx.Rollback(); rErr != nil { - log.Errorf("error during rollback: %v", rErr) - } - }() - - id, err := addSCIMUser(tx, userID, row) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, errors.WithStack(err) - } - - tx = nil - - added := *suser - added.ID = id - - return &added, nil -} - -// AddSCIMUser adds a user as well as additional SCIM-specific fields. If -// the user already exists, this function will return an error. -func (db *PgDB) AddSCIMUser(suser *model.SCIMUser) (*model.SCIMUser, error) { - row := &scimUserRow{ - ExternalID: suser.ExternalID, - Emails: suser.Emails, - Name: suser.Name, - RawAttributes: suser.RawAttributes, - } - - user := &model.User{ - Username: suser.Username, - Active: true, - PasswordHash: suser.PasswordHash, - Remote: true, - } - - tx, err := db.sql.Beginx() - if err != nil { - return nil, errors.WithStack(err) - } - - defer func() { - if tx == nil { - return - } - - if rErr := tx.Rollback(); rErr != nil { - log.Errorf("error during rollback: %v", rErr) - } - }() - - userID, err := HackAddUser(context.TODO(), user) - if err != nil { - return nil, err - } - - id, err := addSCIMUser(tx, userID, row) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, errors.WithStack(err) - } - - tx = nil - - added := *suser - added.ID = id - - return &added, nil -} - -func addSCIMUser(tx *sqlx.Tx, userID model.UserID, row *scimUserRow) (model.UUID, error) { - next := *row - next.UserID = userID - next.ID = model.NewUUID() - - stmt, err := tx.PrepareNamed(` -INSERT INTO scim.users -(id, user_id, external_id, name, emails, raw_attributes) -VALUES (:id, :user_id, :external_id, :name, :emails, :raw_attributes)`) - if err != nil { - return model.UUID{}, errors.WithStack(err) - } - defer stmt.Close() - - if _, err := stmt.Exec(next); err != nil { - return model.UUID{}, errors.WithStack(err) - } - - return next.ID, nil -} - -// SCIMUserList returns at most count SCIM users starting at startIndex -// (1-indexed). If username is set, restrict results to users with the matching -// username. -func (db *PgDB) SCIMUserList(startIndex, count int, username string) (*model.SCIMUsers, error) { - var rows *sqlx.Rows - var err error - - if len(username) == 0 { - rows, err = db.sql.Queryx(` -SELECT - s.id, u.username, s.external_id, s.name, s.emails, u.active -FROM users u, scim.users s -WHERE u.id = s.user_id -ORDER BY id`) - } else { - rows, err = db.sql.Queryx(` -SELECT - s.id, u.username, s.external_id, s.name, s.emails, u.active -FROM users u, scim.users s -WHERE u.id = s.user_id AND u.username = $1 -ORDER BY id`, username) - } - - if err != nil { - return nil, errors.WithStack(err) - } - defer rows.Close() - - var users []*model.SCIMUser - for rows.Next() { - var u model.SCIMUser - if err := rows.StructScan(&u); err != nil { - return nil, errors.WithStack(err) - } - users = append(users, &u) - } - - offset := startIndex - if offset > 0 { - // startIndex is 1-indexed according to the SCIM specification. - offset-- - } - - total := len(users) - if offset > total { - offset = total - } - if offset+count > total { - count = total - offset - } - - startIndex = offset + 1 - - return &model.SCIMUsers{ - TotalResults: total, - StartIndex: startIndex, - Resources: users[offset : offset+count], - ItemsPerPage: count, - }, nil -} - -// SCIMUserByID returns the SCIM user with the given ID. -func (db *PgDB) SCIMUserByID(id model.UUID) (*model.SCIMUser, error) { - tx, err := db.sql.Beginx() - if err != nil { - return nil, errors.WithStack(err) - } - - defer func() { - if tx == nil { - return - } - - if rErr := tx.Rollback(); rErr != nil { - log.Errorf("error during rollback: %v", rErr) - } - }() - - suser, err := db.scimUserByID(tx, id) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, errors.WithStack(err) - } - - tx = nil - - return suser, nil -} - -func (db *PgDB) scimUserByID(tx *sqlx.Tx, id model.UUID) (*model.SCIMUser, error) { - var suser model.SCIMUser - if err := tx.QueryRowx(` -SELECT - s.id, u.username, s.external_id, s.name, s.emails, u.active -FROM users u, scim.users s -WHERE u.id = s.user_id AND s.id = $1`, id).StructScan(&suser); err == sql.ErrNoRows { - return nil, errors.WithStack(ErrNotFound) - } else if err != nil { - return nil, errors.WithStack(err) - } - - return &suser, nil -} - -// SCIMUserByAttribute returns the SCIM user with the given value for the given attribute. -func (db *PgDB) SCIMUserByAttribute(name, value string) (*model.SCIMUser, error) { - var suser model.SCIMUser - err := db.sql.QueryRowx(` -SELECT - s.id, u.username, s.external_id, s.name, s.emails, u.active -FROM users u, scim.users s -WHERE u.id = s.user_id AND s.raw_attributes->>$1 = $2`, name, value).StructScan(&suser) - if err == sql.ErrNoRows { - return nil, errors.WithStack(ErrNotFound) - } else if err != nil { - return nil, errors.WithStack(err) - } - - return &suser, nil -} - -// UserBySCIMAttribute returns the user with the given value for the given SCIM attribute. -func (db *PgDB) UserBySCIMAttribute(name, value string) (*model.User, error) { - var user model.User - err := db.sql.QueryRowx(` -SELECT - u.* -FROM users u, scim.users s -WHERE u.id = s.user_id AND s.raw_attributes->>$1 = $2`, name, value).StructScan(&user) - if err == sql.ErrNoRows { - return nil, errors.WithStack(ErrNotFound) - } else if err != nil { - return nil, errors.WithStack(err) - } - - return &user, nil -} - -// SetSCIMUser updates fields on an existing SCIM user. -func (db *PgDB) SetSCIMUser(id string, user *model.SCIMUser) (*model.SCIMUser, error) { - return db.UpdateSCIMUser(id, user, - []string{ - "active", - "emails", - "external_id", - "name", - "username", - "password_hash", - "raw_attributes", - }) -} - -// UpdateSCIMUser updates some fields on an existing SCIM user. -func (db *PgDB) UpdateSCIMUser( - id string, - user *model.SCIMUser, - fields []string, -) (*model.SCIMUser, error) { - if e, f := id, user.ID.String(); e != f { - return nil, errors.Errorf("user ID %s does not match updated user ID %s", e, f) - } - - tx, err := db.sql.Beginx() - if err != nil { - return nil, errors.WithStack(err) - } - - defer func() { - if tx == nil { - return - } - - if rErr := tx.Rollback(); rErr != nil { - log.Errorf("error during rollback: %v", rErr) - } - }() - - if err = db.updateSCIMUser(tx, user, fields); err != nil { - return nil, err - } - - updated, err := db.scimUserByID(tx, user.ID) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, errors.WithStack(err) - } - - tx = nil - - return updated, nil -} - -func (db *PgDB) updateSCIMUser(tx *sqlx.Tx, user *model.SCIMUser, fields []string) error { - fieldSet := make(map[string]bool) - for _, v := range fields { - fieldSet[v] = true - } - - var usersFields []string - for _, v := range []string{"active", "username", "password_hash"} { - if fieldSet[v] { - usersFields = append(usersFields, v) - } - delete(fieldSet, v) - } - - var scimUsersFields []string - for v := range fieldSet { - scimUsersFields = append(scimUsersFields, v) - } - sort.Strings(scimUsersFields) - - if fs := usersFields; len(fs) > 0 { - stmt, err := tx.PrepareNamed(fmt.Sprintf(` -UPDATE users -%v -WHERE id = (SELECT user_id FROM scim.users s WHERE s.id = :id)`, SetClause(fs))) - if err == sql.ErrNoRows { - return errors.WithStack(ErrNotFound) - } else if err != nil { - return errors.WithStack(err) - } - - defer stmt.Close() - - if _, err := stmt.Exec(user); err != nil { - return errors.WithStack(err) - } - } - - if fs := scimUsersFields; len(fs) > 0 { - stmt, err := tx.PrepareNamed(fmt.Sprintf(` -UPDATE scim.users -%v -WHERE id = :id`, SetClause(fs))) - if err == sql.ErrNoRows { - return errors.WithStack(ErrNotFound) - } else if err != nil { - return errors.WithStack(err) - } - - defer stmt.Close() - - if _, err := stmt.Exec(user); err != nil { - return errors.WithStack(err) - } - } - - return nil -} - -// DeleteSessionsForSCIMUser deletes sessions belonging to a given scim user ID. -func (db *PgDB) DeleteSessionsForSCIMUser(user *model.SCIMUser) error { - _, err := db.sql.Exec(` -DELETE FROM user_sessions -WHERE user_id IN (SELECT u.id - FROM users u - JOIN scim.users su on u.id = su.user_id - WHERE su.id = $1)`, user.ID) - return err -} diff --git a/master/internal/plugin/oidc/service.go b/master/internal/plugin/oidc/service.go index 8590758094bb..246881785bfd 100644 --- a/master/internal/plugin/oidc/service.go +++ b/master/internal/plugin/oidc/service.go @@ -252,7 +252,7 @@ func (s *Service) toIDTokenClaim(userInfo *oidc.UserInfo) (*IDTokenClaims, error // lookupUser: First try finding user in our users.scim table. // If we don't find them there and the scim attribute is userName & look in the user table. func (s *Service) lookupUser(ctx context.Context, claimValue string) (*model.User, error) { - u, err := s.db.UserBySCIMAttribute(s.config.SCIMAuthenticationAttribute, claimValue) + u, err := user.UserBySCIMAttribute(ctx, s.config.SCIMAuthenticationAttribute, claimValue) if errors.Is(err, db.ErrNotFound) { if s.config.SCIMAuthenticationAttribute != "userName" { return nil, errNotProvisioned diff --git a/master/internal/plugin/scim/service.go b/master/internal/plugin/scim/service.go index f6543c54474e..7fc3533e2c4f 100644 --- a/master/internal/plugin/scim/service.go +++ b/master/internal/plugin/scim/service.go @@ -28,6 +28,7 @@ import ( "github.com/determined-ai/determined/master/internal/config" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/internal/plugin/oauth" + "github.com/determined-ai/determined/master/internal/user" "github.com/determined-ai/determined/master/pkg/check" "github.com/determined-ai/determined/master/pkg/model" ) @@ -114,7 +115,7 @@ func (s *service) GetUsers(c echo.Context) (interface{}, error) { } } - users, err := s.db.SCIMUserList(startIndex, count, username) + users, err := user.SCIMUserList(c.Request().Context(), startIndex, count, username) if err != nil { return nil, err } @@ -146,16 +147,16 @@ func (s *service) GetUser(c echo.Context) (interface{}, error) { return nil, newNotFoundError(err) } - user, err := s.db.SCIMUserByID(id) + u, err := user.SCIMUserByID(c.Request().Context(), db.Bun(), id) if err != nil { return nil, err } - if err := user.SetSCIMFields(s.locationRoot); err != nil { + if err := u.SetSCIMFields(s.locationRoot); err != nil { return nil, err } - return user, nil + return u, nil } // PostUser creates a new SCIM user. @@ -165,28 +166,28 @@ func (s *service) PostUser(c echo.Context) (interface{}, error) { return nil, newBadRequestError(err) } - var user model.SCIMUser - if err = json.Unmarshal(body, &user); err != nil { + var u model.SCIMUser + if err = json.Unmarshal(body, &u); err != nil { return nil, newBadRequestError(err) } - if err = json.Unmarshal(body, &user.RawAttributes); err != nil { + if err = json.Unmarshal(body, &u.RawAttributes); err != nil { return nil, newBadRequestError(err) } - if err = check.Validate(user); err != nil { + if err = check.Validate(u); err != nil { return nil, newBadRequestError(err) - } else if user.ID.Valid { + } else if u.ID.Valid { return nil, newBadRequestError(errors.New("ID set")) } - user.Sanitize() + u.Sanitize() - err = user.UpdatePasswordHash(user.Password) + err = u.UpdatePasswordHash(u.Password) if err != nil { return nil, errors.WithStack(err) } - added, err := s.db.AddSCIMUser(&user) + added, err := user.AddSCIMUser(c.Request().Context(), &u) if err == db.ErrDuplicateRecord { return nil, newConflictError(err) } else if err != nil { @@ -219,38 +220,32 @@ func (s *service) PutUser(c echo.Context) (interface{}, error) { return nil, newBadRequestError(err) } - var user model.SCIMUser - if err = json.Unmarshal(body, &user); err != nil { + var u model.SCIMUser + if err = json.Unmarshal(body, &u); err != nil { return nil, newBadRequestError(err) } - if err = json.Unmarshal(body, &user.RawAttributes); err != nil { + if err = json.Unmarshal(body, &u.RawAttributes); err != nil { return nil, newBadRequestError(err) } - if err = check.Validate(user); err != nil { + if err = check.Validate(u); err != nil { return nil, newBadRequestError(err) - } else if user.ID.String() != req.ID { + } else if u.ID.String() != req.ID { return nil, newBadRequestError(errors.New("ID does not match path")) } - user.Sanitize() + u.Sanitize() - err = user.UpdatePasswordHash(user.Password) + err = u.UpdatePasswordHash(u.Password) if err != nil { return nil, errors.WithStack(err) } - updated, err := s.db.SetSCIMUser(req.ID, &user) + updated, err := user.SetSCIMUser(c.Request().Context(), req.ID, &u) if err != nil { return nil, err } - if !updated.Active { - if err := s.db.DeleteSessionsForSCIMUser(updated); err != nil { - return nil, err - } - } - if err := updated.SetSCIMFields(s.locationRoot); err != nil { return nil, err } @@ -350,17 +345,11 @@ func (s *service) PatchUser(c echo.Context) (interface{}, error) { return nil, newBadRequestError(err) } - updated, err := s.db.UpdateSCIMUser(req.ID, &changes, toUpdate) + updated, err := user.UpdateUserAndDeleteSession(c.Request().Context(), req.ID, &changes, toUpdate) if err != nil { return nil, err } - if !updated.Active { - if err := s.db.DeleteSessionsForSCIMUser(updated); err != nil { - return nil, err - } - } - if err := updated.SetSCIMFields(s.locationRoot); err != nil { return nil, err } diff --git a/master/internal/user/external_users.go b/master/internal/user/external_users.go index e58b109085d2..4de60e7a5fb3 100644 --- a/master/internal/user/external_users.go +++ b/master/internal/user/external_users.go @@ -68,8 +68,8 @@ func ByExternalToken(ctx context.Context, tokenText string, scimLock.Lock() defer scimLock.Unlock() - scimUser, err := db.SingleDB().SCIMUserByAttribute("user_id", claims.UserID) - var user *model.User + scimUser, err := scimUserByAttribute(ctx, "user_id", claims.UserID) + var u *model.User if err != nil { if !errors.Is(err, db.ErrNotFound) { return nil, nil, err @@ -88,24 +88,24 @@ func ByExternalToken(ctx context.Context, tokenText string, } // Check for the temporary case where their email exists in users but no SCIM user exists - user, err = ByUsername(context.TODO(), claims.Email) + u, err = ByUsername(ctx, claims.Email) if err != nil { - if err != db.ErrNotFound { + if !errors.Is(err, db.ErrNotFound) { return nil, nil, err } // Legacy user was not found, so creating... - _, err = db.SingleDB().AddSCIMUser(scimUser) + _, err = AddSCIMUser(ctx, scimUser) if err != nil { return nil, nil, errors.WithStack(err) } - user, err = db.SingleDB().UserBySCIMAttribute("user_id", claims.UserID) + u, err = UserBySCIMAttribute(ctx, "user_id", claims.UserID) if err != nil { return nil, nil, errors.WithStack(err) } } else { // Legacy user was found, so retrofit it... - _, err = db.SingleDB().RetrofitSCIMUser(scimUser, user.ID) + _, err = retrofitSCIMUser(ctx, scimUser, u.ID) if err != nil { return nil, nil, errors.WithStack(err) } @@ -113,7 +113,7 @@ func ByExternalToken(ctx context.Context, tokenText string, } else { // Existing SCIM user was found: retrieve or update all details. - user, err = db.SingleDB().UserBySCIMAttribute("user_id", claims.UserID) + u, err = UserBySCIMAttribute(ctx, "user_id", claims.UserID) if err != nil { return nil, nil, errors.WithStack(err) } @@ -122,23 +122,23 @@ func ByExternalToken(ctx context.Context, tokenText string, scimUser.Name = model.SCIMNameFromJWT(claims) scimUser.Username = claims.Email - _, err = db.SingleDB().SetSCIMUser(scimUser.ID.String(), scimUser) + _, err = SetSCIMUser(ctx, scimUser.ID.String(), scimUser) if err != nil { return nil, nil, errors.WithStack(err) } - user.Username = claims.Email - user.Admin = isAdmin - user.Active = true + u.Username = claims.Email + u.Admin = isAdmin + u.Active = true - err = Update(context.TODO(), user, []string{"username", "admin", "active"}, nil) + err = Update(ctx, u, []string{"username", "admin", "active"}, nil) if err != nil { return nil, nil, errors.WithStack(err) } } - user = &model.User{ - ID: user.ID, + u = &model.User{ + ID: u.ID, Username: claims.Email, PasswordHash: null.NewString("", false), Admin: isAdmin, @@ -146,10 +146,10 @@ func ByExternalToken(ctx context.Context, tokenText string, } session := &model.UserSession{ - ID: model.SessionID(user.ID), - UserID: user.ID, + ID: model.SessionID(u.ID), + UserID: u.ID, Expiry: time.Unix(claims.ExpiresAt, 0), } - return user, session, nil + return u, session, nil } diff --git a/master/internal/user/postgres_scim_users.go b/master/internal/user/postgres_scim_users.go new file mode 100644 index 000000000000..af14c3a98eab --- /dev/null +++ b/master/internal/user/postgres_scim_users.go @@ -0,0 +1,274 @@ +package user + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pkg/errors" + "github.com/uptrace/bun" + + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/master/pkg/set" +) + +// retrofitSCIMUser "upgrades" an existing user to one tracked in the SCIM table. This is a +// temporary measure for SaaS clusters to migrate existing users to SCIM users. +func retrofitSCIMUser(ctx context.Context, suser *model.SCIMUser, userID model.UserID) (*model.SCIMUser, error) { + suser.UserID = userID + id, err := addSCIMUserTx(ctx, db.Bun(), suser) + if err != nil { + return nil, err + } + + suser.ID = id + + return suser, err +} + +// AddSCIMUser adds a user as well as additional SCIM-specific fields. If +// the user already exists, this function will return an error. +func AddSCIMUser(ctx context.Context, suser *model.SCIMUser) (*model.SCIMUser, error) { + if err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + userID, err := AddUserTx(ctx, tx, &model.User{ + Username: suser.Username, + Active: true, + PasswordHash: suser.PasswordHash, + Remote: true, + }) + if err != nil { + return err + } + + suser.UserID = userID + + id, err := addSCIMUserTx(ctx, tx, suser) + if err != nil { + return err + } + + suser.ID = id + return nil + }); err != nil { + return nil, fmt.Errorf("adding SCIM user: %w", err) + } + + return suser, nil +} + +func addSCIMUserTx(ctx context.Context, tx bun.IDB, user *model.SCIMUser) (model.UUID, error) { + id := model.NewUUID() + s := struct { + bun.BaseModel `bun:"table:scim.users"` + + Name model.SCIMName + ID model.UUID + ExternalID string + UserID model.UserID + Emails model.SCIMEmails + RawAttributes map[string]any + }{ + Name: user.Name, + ID: id, + ExternalID: user.ExternalID, + UserID: user.UserID, + Emails: user.Emails, + RawAttributes: user.RawAttributes, + } + + if _, err := tx.NewInsert().Model(&s).Exec(ctx); err != nil { + return model.UUID{}, errors.WithStack(err) + } + + return id, nil +} + +// SCIMUserList returns at most count SCIM users starting at startIndex +// (1-indexed). If username is set, restrict results to users with the matching +// username. +func SCIMUserList(ctx context.Context, startIndex, count int, username string) (*model.SCIMUsers, error) { + var users []*model.SCIMUser + q := db.Bun().NewSelect().TableExpr("users AS u, scim.users AS s"). + ColumnExpr("s.id, u.username, s.external_id, s.name, s.emails, u.active"). + Where("u.id = s.user_id").Order("id") + if username != "" { + q = q.Where("u.username = ?", username) + } + if err := q.Scan(ctx, &users); err != nil { + return nil, errors.WithStack(err) + } + + offset := startIndex + if offset > 0 { + // startIndex is 1-indexed according to the SCIM specification. + offset-- + } + + total := len(users) + if offset > total { + offset = total + } + if offset+count > total { + count = total - offset + } + + startIndex = offset + 1 + + return &model.SCIMUsers{ + TotalResults: total, + StartIndex: startIndex, + Resources: users[offset : offset+count], + ItemsPerPage: count, + }, nil +} + +// SCIMUserByID returns the SCIM user with the given ID. +func SCIMUserByID(ctx context.Context, tx bun.IDB, id model.UUID) (*model.SCIMUser, error) { + var suser model.SCIMUser + if err := tx.NewSelect().TableExpr("users AS u, scim.users AS s"). + ColumnExpr("s.id, u.username, s.external_id, s.name, s.emails, u.active, s.raw_attributes"). + Where("u.id = s.user_id AND s.id = ?", id).Scan(ctx, &suser); errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(db.ErrNotFound) + } else if err != nil { + return nil, errors.WithStack(err) + } + + return &suser, nil +} + +// scimUserByAttribute returns the SCIM user with the given value for the given attribute. +func scimUserByAttribute(ctx context.Context, name string, value string) (*model.SCIMUser, error) { + var suser model.SCIMUser + if err := db.Bun().NewSelect().TableExpr("users u, scim.users s"). + ColumnExpr("s.id, u.username, s.external_id, s.name, s.emails, u.active, s.raw_attributes"). + Where("u.id = s.user_id AND s.raw_attributes->>? = ?", name, value). + Scan(ctx, &suser); errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(db.ErrNotFound) + } else if err != nil { + return nil, errors.WithStack(err) + } + + return &suser, nil +} + +// UserBySCIMAttribute returns the user with the given value for the given SCIM attribute. +func UserBySCIMAttribute(ctx context.Context, name string, value string) (*model.User, error) { + var user model.User + if err := db.Bun().NewSelect().TableExpr("users AS u, scim.users AS s"). + ColumnExpr("u.id, u.username, u.active, u.password_hash, u.remote"). + Where("u.id = s.user_id AND s.raw_attributes->>?=?", name, value). + Scan(ctx, &user); errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(db.ErrNotFound) + } else if err != nil { + return nil, errors.WithStack(err) + } + + return &user, nil +} + +// SetSCIMUser updates fields on an existing SCIM user. +func SetSCIMUser(ctx context.Context, id string, user *model.SCIMUser) (*model.SCIMUser, error) { + return UpdateUserAndDeleteSession(ctx, id, user, + []string{ + "active", + "emails", + "external_id", + "name", + "username", + "password_hash", + "raw_attributes", + }) +} + +// UpdateUserAndDeleteSession updates some fields on an existing SCIM user and deletes the user session if inactive. +func UpdateUserAndDeleteSession( + ctx context.Context, + id string, + user *model.SCIMUser, + fields []string, +) (*model.SCIMUser, error) { + if userID := user.ID.String(); id != userID { + return nil, errors.Errorf("user ID %s does not match updated user ID %s", id, userID) + } + + var updated *model.SCIMUser + if err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if err := updateSCIMUser(ctx, tx, user, set.FromSlice(fields)); err != nil { + return err + } + + u, err := SCIMUserByID(ctx, tx, user.ID) + if err != nil { + return err + } + + updated = u + + if !updated.Active { + subq := tx.NewSelect().Column("u.id").TableExpr("users AS u"). + Join("JOIN scim.users su on u.id = su.user_id").Where("su.id = ?", user.ID) + if _, err := db.Bun().NewDelete().Table("user_sessions").Where("user_id IN (?)", subq).Exec(ctx); err != nil { + return fmt.Errorf("deleting user session: %w", err) + } + } + + return nil + }); err != nil { + return nil, fmt.Errorf("updating SCIM user & deleting user session if inactive: %w", err) + } + + return updated, nil +} + +func updateSCIMUser(ctx context.Context, tx bun.IDB, user *model.SCIMUser, fieldSet set.Set[string]) error { + userValues := map[string]interface{}{} + if fieldSet.Contains("active") { + userValues["active"] = user.Active + fieldSet.Remove("active") + } + + if fieldSet.Contains("username") { + userValues["username"] = user.Username + fieldSet.Remove("username") + } + + if fieldSet.Contains("password_hash") { + userValues["password_hash"] = user.PasswordHash + fieldSet.Remove("password_hash") + } + + if len(userValues) > 0 { + q := tx.NewUpdate().Table("users").Model(&userValues).Where("id = (?)", + tx.NewSelect().Column("user_id").TableExpr("scim.users AS s").Where("s.id = ?", user.ID)) + if err := execUpdateSCIMUser(ctx, tx, q); err != nil { + return err + } + } + + if len(fieldSet) > 0 { + q := tx.NewUpdate().ModelTableExpr("?", bun.Safe("scim.users")). + Column(fieldSet.ToSlice()...).Model(user).Where("id = ?", user.ID) + if err := execUpdateSCIMUser(ctx, tx, q); err != nil { + return err + } + } + + return nil +} + +func execUpdateSCIMUser(ctx context.Context, tx bun.IDB, q *bun.UpdateQuery) error { + res, err := q.Exec(ctx) + if err != nil { + return errors.WithStack(err) + } + + num, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("getting rows affected: %w", err) + } else if num == 0 { + return errors.WithStack(db.ErrNotFound) + } + + return nil +} diff --git a/master/internal/user/postgres_scim_users_intg_test.go b/master/internal/user/postgres_scim_users_intg_test.go new file mode 100644 index 000000000000..81ea3ab40492 --- /dev/null +++ b/master/internal/user/postgres_scim_users_intg_test.go @@ -0,0 +1,319 @@ +//go:build integration +// +build integration + +package user + +import ( + "context" + "fmt" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "gopkg.in/guregu/null.v3" + + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/model" +) + +// Tests for postgres_scim_users.go. +func TestAddSCIMUser(t *testing.T) { + testUUID := model.NewUUID().String() + + ctx := context.Background() + cases := []struct { + name string + users []*model.SCIMUser + errString string + }{ + {"simple-case", []*model.SCIMUser{mockSCIMUser(t)}, ""}, + {"multiples-case", []*model.SCIMUser{{ + Username: model.NewUUID().String(), + ExternalID: "multiples-external-id", + Name: model.SCIMName{GivenName: "John", FamilyName: "Multiple"}, + Emails: []model.SCIMEmail{ + {Type: "personal", SValue: "value-1", Primary: true}, + {Type: "personal", SValue: "value-2", Primary: false}, + {Type: "personal", SValue: "value-3", Primary: false}, + }, + Active: true, + PasswordHash: null.StringFrom("password"), + RawAttributes: map[string]interface{}{ + "attribute1": true, + "attribute2": "false", + "attribute3": []interface{}{"a", "b", "c"}, + }, + }}, ""}, + {"duplicate-case", []*model.SCIMUser{ + mockSCIMUserWithUsername(t, testUUID), + mockSCIMUserWithUsername(t, testUUID), + }, "duplicate key value violates"}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + for _, v := range tt.users { + addedUser, err := AddSCIMUser(ctx, v) + if tt.errString != "" && err != nil { + require.Contains(t, err.Error(), tt.errString) + continue + } + + require.NoError(t, err) + dbUser, err := SCIMUserByID(ctx, db.Bun(), v.ID) + require.NoError(t, err) + matchUsers(t, addedUser, dbUser) + + // make sure the user table is updated too + var u *model.FullUser + u, err = ByID(ctx, addedUser.UserID) + require.NoError(t, err) + require.Equal(t, dbUser.Active, u.Active) + require.Equal(t, dbUser.Username, u.Username) + require.Equal(t, dbUser.PasswordHash, u.ToUser().PasswordHash) + } + }) + } +} + +func TestSCIMUserList(t *testing.T) { + uuid1 := model.NewUUID().String() + uuid2 := model.NewUUID().String() + uuid3 := model.NewUUID().String() + + ctx := context.Background() + cases := []struct { + name string + usernameToMatch string + usernames []string + count int + startIndex int + }{ + {"simple-case", "", []string{}, 0, 1}, + {"one-user-added", uuid1, []string{uuid1}, 1, 1}, + {"two-diff-users-added", uuid2, []string{uuid2, model.NewUUID().String()}, 1, 1}, + {"two-diff-users-returned", "", []string{ + model.NewUUID().String(), + model.NewUUID().String(), model.NewUUID().String(), + }, 1, 2}, + {"out-of-bounds-index", uuid3, []string{uuid3, model.NewUUID().String()}, 2, 2}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + expectedUsers := []*model.SCIMUser{} + for idx, u := range tt.usernames { + addedUser, err := AddSCIMUser(ctx, mockSCIMUserWithUsername(t, u)) + require.NoError(t, err) + if idx+1 >= tt.startIndex && idx < tt.count { + expectedUsers = append(expectedUsers, addedUser) + } + } + + actualUsers, err := SCIMUserList(ctx, tt.startIndex, tt.count, tt.usernameToMatch) + require.NoError(t, err) + require.Equal(t, tt.startIndex, actualUsers.StartIndex) + if tt.name == "out-of-bounds-index" { + require.Empty(t, actualUsers.Resources) + } else { + require.Subset(t, usernameList(actualUsers.Resources), usernameList(expectedUsers)) + } + }) + } +} + +func TestSCIMUserByID(t *testing.T) { + ctx := context.Background() + cases := []struct { + name string + user *model.SCIMUser + errString string + }{ + {"simple-case", mockSCIMUser(t), ""}, + {"error-not-found", nil, "not found"}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + id := model.UUID{} + if tt.user != nil { + addedUser, err := AddSCIMUser(ctx, tt.user) + require.NoError(t, err) + id = addedUser.ID + } + scimUser, err := SCIMUserByID(ctx, db.Bun(), id) + if tt.errString != "" { + require.Nil(t, scimUser) + require.ErrorContains(t, err, tt.errString) + } else { + require.NotNil(t, scimUser) + require.NoError(t, err) + } + }) + } +} + +func TestUserByAttribute(t *testing.T) { + ctx := context.Background() + cases := []struct { + name string + errString string + }{ + {"simple-case", ""}, + {"error-not-found", "not found"}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + val := model.NewUUID() + user := mockSCIMUser(t) + user.RawAttributes = map[string]interface{}{"id": val} + + addedUser, err := AddSCIMUser(ctx, user) + require.NoError(t, err) + + if tt.errString != "" { + _, err := scimUserByAttribute(ctx, "user_id", "bogus-value") + require.Contains(t, err.Error(), tt.errString) + + _, err = UserBySCIMAttribute(ctx, "user_id", "bogus-value") + require.Contains(t, err.Error(), tt.errString) + } else { + // test scimUserByAttribute + scimUser, err := scimUserByAttribute(ctx, "id", fmt.Sprint(val)) + require.NoError(t, err) + require.Equal(t, addedUser.Username, scimUser.Username) + + // test userBySCIMAttribute + u, err := UserBySCIMAttribute(ctx, "id", fmt.Sprint(val)) + require.NoError(t, err) + require.Equal(t, addedUser.UserID, u.ID) + } + }) + } +} + +func TestSetSCIMUser(t *testing.T) { + ctx := context.Background() + cases := []struct { + name string + updatedUser *model.SCIMUser + errString string + matchUUID bool + }{ + {"simple-case", mockSCIMUser(t), "", true}, + {"simple-case", mockSCIMUser(t), "does not match updated user ID", false}, + {"empty-set", &model.SCIMUser{}, "duplicate key value", true}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addedUser, err := AddSCIMUser(ctx, mockSCIMUser(t)) + require.NoError(t, err) + + if tt.matchUUID { + tt.updatedUser.ID = addedUser.ID + } + + user, err := SetSCIMUser(ctx, addedUser.ID.String(), tt.updatedUser) + if err != nil { + require.Contains(t, err.Error(), tt.errString) + } else { + require.NoError(t, err) + matchUsers(t, tt.updatedUser, user) + require.Equal(t, addedUser.ID, user.ID) + } + }) + } +} + +func TestUpdateUserAndDeleteSession(t *testing.T) { + ctx := context.Background() + cases := []struct { + name string + fields []string + updatedUser *model.SCIMUser + matchID bool + errString string + }{ + {"simple-case-one-field", []string{"username"}, mockSCIMUser(t), true, ""}, + {"multiple-fields", []string{"name", "emails", "username"}, mockSCIMUser(t), true, ""}, + {"id-not-found", []string{"username"}, mockSCIMUser(t), false, "does not match updated user ID"}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + // Adding a mock test session -- to test for deletion later + // Add a user. + addedUser, err := AddSCIMUser(ctx, mockSCIMUser(t)) + require.NoError(t, err) + + var user model.User + err = db.Bun().NewSelect().Table("users").Where("id = ?", addedUser.UserID).Scan(ctx, &user) + require.NoError(t, err) + + // Add a session. + var session model.UserSession + _, err = StartSession(ctx, &user) + require.NoError(t, err) + + err = db.Bun().NewSelect().Table("user_sessions"). + Where("user_id = ?", user.ID).Scan(ctx, &session) + require.NoError(t, err) + + if tt.matchID { + tt.updatedUser.ID = addedUser.ID + } + + scimUser, err := UpdateUserAndDeleteSession(ctx, addedUser.ID.String(), tt.updatedUser, tt.fields) + if tt.errString != "" { + require.Contains(t, err.Error(), tt.errString) + } else { + require.NoError(t, err) + for _, v := range tt.fields { + switch v { + case "username": + require.Equal(t, tt.updatedUser.Username, scimUser.Username) + case "emails": + require.Equal(t, tt.updatedUser.Emails, scimUser.Emails) + case "name": + require.Equal(t, tt.updatedUser.Name, scimUser.Name) + } + } + } + + _, err = db.Bun().NewSelect().Table("user_sessions"). + Where("user_id = ?", user.ID).Exec(context.Background()) + require.ErrorAs(t, errors.New("Receive unexpected error: bun: Model(nil)"), &err) + }) + } +} + +func mockSCIMUser(t *testing.T) *model.SCIMUser { + return mockSCIMUserWithUsername(t, model.NewUUID().String()) +} + +func mockSCIMUserWithUsername(t *testing.T, username string) *model.SCIMUser { + user := &model.SCIMUser{ + Username: username, + ExternalID: fmt.Sprintf("external-id-%s", username), + Name: model.SCIMName{GivenName: "John", FamilyName: username}, + Emails: []model.SCIMEmail{{Type: "personal", SValue: fmt.Sprintf("value-%s", username), Primary: true}}, + Active: true, + PasswordHash: null.StringFrom("password"), + } + + return user +} + +func usernameList(l []*model.SCIMUser) []string { + res := []string{} + for _, v := range l { + res = append(res, v.Username) + } + return res +} + +func matchUsers(t *testing.T, a *model.SCIMUser, b *model.SCIMUser) { + // because only certain fields are written to the db + require.Equal(t, a.Username, b.Username) + require.Equal(t, a.ExternalID, b.ExternalID) + require.Equal(t, a.Name, b.Name) + require.Equal(t, a.Emails, b.Emails) + require.Equal(t, a.Active, b.Active) + require.Equal(t, a.RawAttributes, b.RawAttributes) +} diff --git a/master/pkg/model/scim_user.go b/master/pkg/model/scim_user.go index c2aaf97f9929..dac05e70bbdf 100644 --- a/master/pkg/model/scim_user.go +++ b/master/pkg/model/scim_user.go @@ -2,7 +2,6 @@ package model import ( "crypto/sha512" - "database/sql/driver" "encoding/json" "fmt" "net/url" @@ -22,11 +21,6 @@ type SCIMName struct { FamilyName string `json:"familyName"` } -// Value implements sql.Valuer. -func (e SCIMName) Value() (driver.Value, error) { - return json.Marshal(e) -} - // Scan implements sql.Scanner. func (e *SCIMName) Scan(value interface{}) error { return scanJSON(value, e) @@ -39,11 +33,6 @@ type SCIMEmail struct { Primary bool `json:"primary"` } -// Value implements sql.Valuer. -func (e SCIMEmail) Value() (driver.Value, error) { - return json.Marshal(e) -} - // Scan implements sql.Scanner. func (e *SCIMEmail) Scan(value interface{}) error { return scanJSON(value, e) @@ -52,11 +41,6 @@ func (e *SCIMEmail) Scan(value interface{}) error { // SCIMEmails is a list of emails in SCIM. type SCIMEmails []SCIMEmail -// Value implements sql.Valuer. -func (e SCIMEmails) Value() (driver.Value, error) { - return json.Marshal(e) -} - // Scan implements sql.Scanner. func (e *SCIMEmails) Scan(value interface{}) error { return scanJSON(value, e) @@ -96,21 +80,21 @@ func (s *SCIMUserSchemas) UnmarshalJSON(data []byte) error { // SCIMUser is a user in SCIM. type SCIMUser struct { - ID UUID `db:"id" json:"id"` - Username string `db:"username" json:"userName"` - ExternalID string `db:"external_id" json:"externalId"` - Name SCIMName `db:"name" json:"name"` - Emails SCIMEmails `db:"emails" json:"emails"` - Active bool `db:"active" json:"active"` + ID UUID `bun:"id" json:"id"` + Username string `bun:"username" json:"userName"` + ExternalID string `bun:"external_id" json:"externalId"` + Name SCIMName `bun:"name" json:"name"` + Emails SCIMEmails `bun:"emails" json:"emails"` + Active bool `bun:"active" json:"active"` - PasswordHash null.String `db:"password_hash" json:"password_hash,omitempty"` + PasswordHash null.String `bun:"password_hash" json:"password_hash,omitempty"` Password string `json:"password,omitempty"` Schemas SCIMUserSchemas `json:"schemas"` Meta *SCIMUserMeta `json:"meta"` - UserID UserID `db:"user_id" json:"-"` - RawAttributes map[string]interface{} `db:"raw_attributes" json:"-"` + UserID UserID `bun:"user_id" json:"-"` + RawAttributes map[string]interface{} `bun:"raw_attributes" json:"-"` } // Validate checks that external data satisfies the expected invariants.