Skip to content

Commit

Permalink
feat: adding PACHD_ADDRESS and DEX_TOKEN to task env
Browse files Browse the repository at this point in the history
  • Loading branch information
salonig23 committed Nov 23, 2023
1 parent 8c9dfbf commit 05b6e68
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 12 deletions.
28 changes: 27 additions & 1 deletion master/internal/api_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,25 @@ func (a *apiServer) SetCommandPriority(
return &apiv1.SetCommandPriorityResponse{Command: cmd.ToV1Command()}, nil
}

func (a *apiServer) AddOIDCPachydermEnvVars(
session *model.UserSession,
) (map[string]string, error) {
envVars := make(map[string]string)

if val, ok := session.InheritedClaims["OIDCRawIDToken"]; ok {
envVars["DEX_TOKEN"] = val
}

if a.m.config.Integrations.Pachyderm.Address != "" {
envVars["PACHD_ADDRESS"] = a.m.config.Integrations.Pachyderm.Address
}
return envVars, nil
}

func (a *apiServer) LaunchCommand(
ctx context.Context, req *apiv1.LaunchCommandRequest,
) (*apiv1.LaunchCommandResponse, error) {
user, _, err := grpcutil.GetUser(ctx)
user, session, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down Expand Up @@ -402,10 +417,21 @@ func (a *apiServer) LaunchCommand(
err.Error(),
)
}

launchReq.Spec.Base.ExtraEnvVars = map[string]string{
"DET_TASK_TYPE": string(model.TaskTypeCommand),
}

OIDCPachydermEnvVars, err := a.AddOIDCPachydermEnvVars(session)
if err != nil {
return nil, err
}
if len(OIDCPachydermEnvVars) == 0 {
for k, v := range OIDCPachydermEnvVars {
launchReq.Spec.Base.ExtraEnvVars[k] = v
}
}

// Launch a command.
cmd, err := command.DefaultCmdService.LaunchGenericCommand(
model.TaskTypeCommand,
Expand Down
8 changes: 7 additions & 1 deletion master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,7 @@ func (a *apiServer) ContinueExperiment(
func (a *apiServer) CreateExperiment(
ctx context.Context, req *apiv1.CreateExperimentRequest,
) (*apiv1.CreateExperimentResponse, error) {
user, _, err := grpcutil.GetUser(ctx)
user, session, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down Expand Up @@ -1669,6 +1669,12 @@ func (a *apiServer) CreateExperiment(
if err != nil {
return nil, err
}

taskSpec.ExtraEnvVars, err = a.AddOIDCPachydermEnvVars(session)
if err != nil {
return nil, err
}

if err = experiment.AuthZProvider.Get().CanCreateExperiment(ctx, *user, p); err != nil {
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}
Expand Down
14 changes: 13 additions & 1 deletion master/internal/api_notebook.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (a *apiServer) isNTSCPermittedToLaunch(
func (a *apiServer) LaunchNotebook(
ctx context.Context, req *apiv1.LaunchNotebookRequest,
) (*apiv1.LaunchNotebookResponse, error) {
user, _, err := grpcutil.GetUser(ctx)
user, session, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down Expand Up @@ -277,12 +277,24 @@ func (a *apiServer) LaunchNotebook(
if err != nil {
return nil, status.Errorf(codes.Internal, "cannot marshal notebook config: %s", err.Error())
}

launchReq.Spec.Base.ExtraEnvVars = map[string]string{
"NOTEBOOK_PORT": strconv.Itoa(port),
"NOTEBOOK_CONFIG": string(configBytes),
"NOTEBOOK_IDLE_TYPE": launchReq.Spec.Config.NotebookIdleType,
"DET_TASK_TYPE": string(model.TaskTypeNotebook),
}

OIDCPachydermEnvVars, err := a.AddOIDCPachydermEnvVars(session)
if err != nil {
return nil, err
}
if len(OIDCPachydermEnvVars) == 0 {
for k, v := range OIDCPachydermEnvVars {
launchReq.Spec.Base.ExtraEnvVars[k] = v
}
}

launchReq.Spec.Base.ExtraProxyPorts = append(launchReq.Spec.Base.ExtraProxyPorts,
expconf.ProxyPort{
RawProxyPort: port,
Expand Down
12 changes: 11 additions & 1 deletion master/internal/api_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (a *apiServer) SetShellPriority(
func (a *apiServer) LaunchShell(
ctx context.Context, req *apiv1.LaunchShellRequest,
) (*apiv1.LaunchShellResponse, error) {
user, _, err := grpcutil.GetUser(ctx)
user, session, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down Expand Up @@ -246,6 +246,16 @@ func (a *apiServer) LaunchShell(

launchReq.Spec.Base.ExtraEnvVars = map[string]string{"DET_TASK_TYPE": string(model.TaskTypeShell)}

OIDCPachydermEnvVars, err := a.AddOIDCPachydermEnvVars(session)
if err != nil {
return nil, err
}
if len(OIDCPachydermEnvVars) == 0 {
for k, v := range OIDCPachydermEnvVars {
launchReq.Spec.Base.ExtraEnvVars[k] = v
}
}

var passphrase *string
if len(req.Data) > 0 {
var data map[string]interface{}
Expand Down
12 changes: 11 additions & 1 deletion master/internal/api_tensorboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (a *apiServer) LaunchTensorboard(
return nil, status.Error(codes.InvalidArgument, err.Error())
}

user, _, err := grpcutil.GetUser(ctx)
user, session, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}
Expand Down Expand Up @@ -281,6 +281,16 @@ func (a *apiServer) LaunchTensorboard(
"DET_TASK_TYPE": string(model.TaskTypeTensorboard),
}

OIDCPachydermEnvVars, err := a.AddOIDCPachydermEnvVars(session)
if err != nil {
return nil, err
}
if len(OIDCPachydermEnvVars) == 0 {
for k, v := range OIDCPachydermEnvVars {
uniqEnvVars[k] = v
}
}

if launchReq.Spec.Config.Debug {
uniqEnvVars["DET_DEBUG"] = "true"
}
Expand Down
14 changes: 13 additions & 1 deletion master/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ type WebhooksConfig struct {
SigningKey string `json:"signing_key"`
}

// IntegrationsConfig stores configs related to integrations like pachyderm.
type IntegrationsConfig struct {
Pachyderm PachydermConfig `json:"pachyderm"`
}

// PachydermConfig stores fields needed to integrate Pachyderm with determined.
type PachydermConfig struct {
Address string `json:"address"`
}

// DefaultConfig returns the default configuration of the master.
func DefaultConfig() *Config {
return &Config{
Expand Down Expand Up @@ -151,7 +161,9 @@ type Config struct {
ResourceConfig

// Internal contains "hidden" useful debugging configurations.
InternalConfig InternalConfig `json:"__internal"`
InternalConfig InternalConfig `json:"__internal"`
OIDC OIDCConfig `json:"oidc"`
Integrations IntegrationsConfig `json:"integrations"`
}

// GetMasterConfig returns reference to the master config singleton.
Expand Down
8 changes: 8 additions & 0 deletions master/internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ resource_pools:
max_agent_starting_period: 30s
task_container_defaults:
dtrain_network_interface: if0
integrations:
pachyderm:
address: foo
`
expected := Config{
Log: logger.Config{
Expand Down Expand Up @@ -92,6 +95,11 @@ resource_pools:
},
},
},
Integrations: IntegrationsConfig{
Pachyderm: PachydermConfig{
Address: "foo",
},
},
}

unmarshaled := Config{}
Expand Down
29 changes: 29 additions & 0 deletions master/internal/config/oidc_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package config

import (
"net/url"
)

// OIDCConfig holds the parameters for the OIDC provider.
type OIDCConfig struct {
Enabled bool `json:"enabled"`
Provider string `json:"provider"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
IDPSSOURL string `json:"idp_sso_url"`
IDPRecipientURL string `json:"idp_recipient_url"`
AuthenticationClaim string `json:"authentication_claim"`
SCIMAuthenticationAttribute string `json:"scim_authentication_attribute"`
AutoProvisionUsers bool `json:"auto_provision_users"`
GroupsClaimName string `json:"groups_claim_name"`
}

// Validate implements the check.Validatable interface.
func (c OIDCConfig) Validate() []error {
if !c.Enabled {
return nil
}

_, err := url.Parse(c.IDPRecipientURL)
return []error{err}
}
16 changes: 15 additions & 1 deletion master/internal/user/postgres_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,27 @@ const SessionDuration = 7 * 24 * time.Hour
// PersonalGroupPostfix is the system postfix appended to the username of all personal groups.
const PersonalGroupPostfix = "DeterminedPersonalGroup"

// UserSessionOption is the return type for WithInheritedClaims helper function.
type UserSessionOption func(f *model.UserSession)

// WithInheritedClaims function will add the specified inherited claims to the user session.
func WithInheritedClaims(claims map[string]string) UserSessionOption {
return func(s *model.UserSession) {
s.InheritedClaims = claims
}
}

// StartSession creates a row in the user_sessions table.
func StartSession(ctx context.Context, user *model.User) (string, error) {
func StartSession(ctx context.Context, user *model.User, opts ...UserSessionOption) (string, error) {
userSession := &model.UserSession{
UserID: user.ID,
Expiry: time.Now().Add(SessionDuration),
}

for _, opt := range opts {
opt(userSession)
}

err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
_, err := db.Bun().NewInsert().
Model(userSession).
Expand Down
23 changes: 23 additions & 0 deletions master/internal/user/postgres_users_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"

"github.com/google/uuid"
"github.com/o1egl/paseto"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun/schema"
Expand Down Expand Up @@ -167,6 +168,28 @@ func TestUserStartSession(t *testing.T) {
require.NoError(t, err)
}

func TestUserStartSessionTokenHasClaims(t *testing.T) {
user, err := addTestUser(nil)
require.NoError(t, err)

// Add a session with inherited claims.
claims := map[string]string{"test_key": "test_val"}
token, err := StartSession(context.TODO(), user, WithInheritedClaims(claims))
require.NoError(t, err)
require.NotNil(t, token)

var restoredSession model.UserSession
v2 := paseto.NewV2()
err = v2.Verify(token, db.GetTokenKeys().PublicKey, &restoredSession, nil)
require.NoError(t, err)
require.Equal(t, restoredSession.InheritedClaims, claims)

exists, err := db.Bun().NewSelect().Table("user_sessions").
Where("user_id = ?", user.ID).Exists(context.TODO())
require.True(t, exists)
require.NoError(t, err)
}

func TestDeleteSessionByToken(t *testing.T) {
userID, _, token, err := addTestSession()
require.NoError(t, err)
Expand Down
9 changes: 5 additions & 4 deletions master/pkg/model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ type User struct {

// UserSession corresponds to a row in the "user_sessions" DB table.
type UserSession struct {
bun.BaseModel `bun:"table:user_sessions"`
ID SessionID `db:"id" json:"id"`
UserID UserID `db:"user_id" json:"user_id"`
Expiry time.Time `db:"expiry" json:"expiry"`
bun.BaseModel `bun:"table:user_sessions"`
ID SessionID `db:"id" json:"id"`
UserID UserID `db:"user_id" json:"user_id"`
Expiry time.Time `db:"expiry" json:"expiry"`
InheritedClaims map[string]string `bun:"-"` // InheritedClaims contains the OIDC raw ID token when OIDC is enabled
}

// A FullUser is a User joined with any other user relations.
Expand Down
8 changes: 7 additions & 1 deletion master/pkg/tasks/task_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ func (s TrialSpec) ToTaskSpec() TaskSpec {
envVars["DET_LATEST_CHECKPOINT"] = s.LatestCheckpoint.UUID.String()
}

res.ExtraEnvVars = envVars
if res.ExtraEnvVars != nil {
for k, v := range envVars {
res.ExtraEnvVars[k] = v
}
} else {
res.ExtraEnvVars = envVars
}

if shm := s.ExperimentConfig.Resources().ShmSize(); shm != nil {
res.ShmSize = int64(*shm)
Expand Down

0 comments on commit 05b6e68

Please sign in to comment.