Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: agent config precedence #8656

Merged
merged 3 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions agent/cmd/determined-agent/env.go

This file was deleted.

157 changes: 157 additions & 0 deletions agent/cmd/determined-agent/init.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package main

import (
"strings"

"github.com/spf13/pflag"
"github.com/spf13/viper"

"github.com/determined-ai/determined/agent/internal/options"
)

// viperKeyDelimiter is the key delimiter the agent's viper instance is created with
// (see registerAgentConfig) and the separator optionsKey.AccessPath uses when joining
// the components of a nested option.
// NOTE(review): presumably ".." was chosen over viper's default "." so that option
// values or keys containing a single dot are not split — confirm.
const viperKeyDelimiter = ".."

// v is the package-level viper instance holding the agent configuration. It is nil
// until registerAgentConfig assigns it.
var v *viper.Viper

// init runs registerAgentConfig at package initialization, so the viper instance,
// flags, environment bindings and defaults are set up before main executes.
//
//nolint:gochecknoinits
func init() {
registerAgentConfig()
}

// optionsKey is the ordered list of path components that identifies a single agent
// option, e.g. {"security", "tls", "skip-verify"}. The same components are rendered
// three different ways: as a command-line flag, as an environment variable and as a
// viper access path.
type optionsKey []string

// EnvName returns the environment variable for the option: the flag name with
// dashes turned into underscores, upper-cased, and prefixed with "DET_".
func (k optionsKey) EnvName() string {
	underscored := strings.ReplaceAll(k.FlagName(), "-", "_")
	return "DET_" + strings.ToUpper(underscored)
}

// AccessPath returns the key used to address the option inside viper: the
// components joined with viperKeyDelimiter, with every dash replaced by an
// underscore.
func (k optionsKey) AccessPath() string {
	joined := strings.Join(k, viperKeyDelimiter)
	return strings.ReplaceAll(joined, "-", "_")
}

// FlagName returns the command-line flag name for the option: the components
// joined with a single dash.
func (k optionsKey) FlagName() string {
	const sep = "-"
	return strings.Join(k, sep)
}

// registerString defines a string flag for the option on flags and wires the option
// into the package-level viper instance: the flag and the option's environment
// variable are bound to its access path, and value is stored as the viper default.
func registerString(flags *pflag.FlagSet, name optionsKey, value string, usage string) {
	flagName, key := name.FlagName(), name.AccessPath()

	flags.String(flagName, value, usage)
	// Bind errors are intentionally discarded; the flag looked up here was
	// defined on the line above.
	_ = v.BindPFlag(key, flags.Lookup(flagName))
	_ = v.BindEnv(key, name.EnvName())
	v.SetDefault(key, value)
}

// registerBool defines a boolean flag for the option on flags and wires the option
// into the package-level viper instance: the option's environment variable and the
// flag are bound to its access path, and value is stored as the viper default.
func registerBool(flags *pflag.FlagSet, name optionsKey, value bool, usage string) {
	flagName, key := name.FlagName(), name.AccessPath()

	flags.Bool(flagName, value, usage)
	// Bind errors are intentionally discarded; the flag looked up here was
	// defined on the line above.
	_ = v.BindEnv(key, name.EnvName())
	_ = v.BindPFlag(key, flags.Lookup(flagName))
	v.SetDefault(key, value)
}

// registerInt defines an integer flag for the option on flags and wires the option
// into the package-level viper instance: the option's environment variable and the
// flag are bound to its access path, and value is stored as the viper default.
func registerInt(flags *pflag.FlagSet, name optionsKey, value int, usage string) {
	flagName, key := name.FlagName(), name.AccessPath()

	flags.Int(flagName, value, usage)
	// Bind errors are intentionally discarded; the flag looked up here was
	// defined on the line above.
	_ = v.BindEnv(key, name.EnvName())
	_ = v.BindPFlag(key, flags.Lookup(flagName))
	v.SetDefault(key, value)
}

// registerAgentConfig creates the package-level viper instance v and registers every
// agent option with it. For each option it defines a command-line flag, binds that
// flag and the matching DET_* environment variable to the option's viper access
// path, and records the default taken from options.DefaultOptions().
//
// The logging flags are defined as persistent flags on rootCmd and bound through
// runCmd's inherited flags; all other flags are defined directly on runCmd.
func registerAgentConfig() {
	v = viper.NewWithOptions(viper.KeyDelimiter(viperKeyDelimiter))
	v.SetTypeByDefaultValue(true)

	defaults := options.DefaultOptions()
	name := func(components ...string) optionsKey { return components }

	// Register flags and environment variables, and set default values for viper settings.

	// TODO(DET-8884): Configure log level through agent config file.
	rootCmd.PersistentFlags().StringP("log-level", "l", "info",
		"set the logging level (can be one of: debug, info, warn, error, or fatal)")
	// The flag is "log-color" with a default of true, so setting it turns color on;
	// the help text must say "enable", not "disable".
	rootCmd.PersistentFlags().Bool("log-color", true, "enable colored output")

	flags := runCmd.Flags()
	iFlags := runCmd.InheritedFlags()

	// Logging flags. These live on rootCmd, so they are looked up through the
	// flags runCmd inherits rather than registered with the register* helpers.
	// Bind errors are intentionally discarded here, as in the helpers.
	logLevelName := name("log", "level")
	_ = v.BindEnv(logLevelName.AccessPath(), logLevelName.EnvName())
	_ = v.BindPFlag(logLevelName.AccessPath(), iFlags.Lookup(logLevelName.FlagName()))
	v.SetDefault(logLevelName.AccessPath(), defaults.Log.Level)

	logColorName := name("log", "color")
	_ = v.BindEnv(logColorName.AccessPath(), logColorName.EnvName())
	_ = v.BindPFlag(logColorName.AccessPath(), iFlags.Lookup(logColorName.FlagName()))
	v.SetDefault(logColorName.AccessPath(), true)

	// Top-level flags.
	registerString(flags, name("config-file"), defaults.ConfigFile,
		"Path to agent configuration file")
	registerString(flags, name("master-host"), defaults.MasterHost, "Hostname of the master")
	registerInt(flags, name("master-port"), defaults.MasterPort, "Port of the master")
	registerString(flags, name("agent-id"), defaults.AgentID, "Unique ID of this Determined agent")

	// Label flags.
	registerString(flags, name("label"), defaults.Label,
		"This field has been deprecated and will be ignored, use ``resource_pool`` instead.")

	// ResourcePool flags.
	registerString(flags, name("resource-pool"), defaults.ResourcePool,
		"Resource Pool the agent belongs to")

	// Container flags.
	registerString(flags, name("container-master-host"), defaults.ContainerMasterHost,
		"Master hostname that containers started by this agent will connect to")
	registerInt(flags, name("container-master-port"), defaults.ContainerMasterPort,
		"Master port that containers started by this agent will connect to")

	// Device flags.
	registerString(flags, name("slot-type"), defaults.SlotType, "slot type to expose")
	registerString(flags, name("visible-gpus"), defaults.VisibleGPUs, "GPUs to expose as slots")

	// Security flags.
	registerBool(flags, name("security", "tls", "enabled"), defaults.Security.TLS.Enabled,
		"Whether to use TLS to connect to the master")
	registerBool(flags, name("security", "tls", "skip-verify"), defaults.Security.TLS.SkipVerify,
		"Whether to skip verifying the master certificate when TLS is on (insecure!)")
	registerString(flags, name("security", "tls", "master-cert"), defaults.Security.TLS.MasterCert,
		"CA cert file for the master")
	registerString(flags, name("security", "tls", "master-cert-name"),
		defaults.Security.TLS.MasterCertName,
		"expected address in the master TLS certificate (if different than the one used for connecting)",
	)

	// Debug flags.
	registerBool(flags, name("debug"), defaults.Debug, "Enable verbose script output")
	registerInt(flags, name("artificial-slots"), defaults.ArtificialSlots, "")
	// artificial-slots is registered but kept out of the help output.
	flags.Lookup("artificial-slots").Hidden = true
	registerString(flags, name("image-root"), defaults.ImageRoot,
		"Path to local container image cache")

	// Endpoint TLS flags.
	registerBool(flags, name("tls"), defaults.TLS, "Use TLS for the API server")
	registerString(flags, name("tls-cert"), defaults.TLSCertFile, "Path to TLS certification file")
	registerString(flags, name("tls-key"), defaults.TLSKeyFile, "Path to TLS key file")

	// Endpoint flags.
	registerBool(flags, name("api-enabled"), defaults.APIEnabled, "Enable agent API endpoints")
	registerString(flags, name("bind-ip"), defaults.BindIP,
		"IP address to listen on for API requests")
	registerInt(flags, name("bind-port"), defaults.BindPort, "Port to listen on for API requests")

	// Proxy flags.
	registerString(flags, name("http-proxy"), defaults.HTTPProxy,
		"The HTTP proxy address for the agent's containers")
	registerString(flags, name("https-proxy"), defaults.HTTPSProxy,
		"The HTTPS proxy address for the agent's containers")
	registerString(flags, name("ftp-proxy"), defaults.FTPProxy,
		"The FTP proxy address for the agent's containers")
	registerString(flags, name("no-proxy"), defaults.NoProxy,
		"Addresses that the agent's containers should not proxy")

	// Fault-tolerance flags.
	registerInt(flags, name("agent-reconnect-attempts"), defaults.AgentReconnectAttempts,
		"Max attempts agent has to reconnect")
	registerInt(flags, name("agent-reconnect-backoff"), defaults.AgentReconnectBackoff,
		"Time between agent reconnect attempts")

	// Container runtime flags.
	registerString(flags, name("container-runtime"), defaults.ContainerRuntime,
		"The container runtime to use")
}
6 changes: 4 additions & 2 deletions agent/cmd/determined-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"

"github.com/determined-ai/determined/master/pkg/logger"
)

func maybeInjectRootAlias(rootCmd *cobra.Command, inject string) {
Expand All @@ -31,10 +33,10 @@ func nonRootSubCmds(rootCmd *cobra.Command) []string {
}

func main() {
rootCmd := newRootCmd()
logger.SetLogrus(*logger.DefaultConfig())
maybeInjectRootAlias(rootCmd, "run")

if err := newRootCmd().Execute(); err != nil {
if err := rootCmd.Execute(); err != nil {
log.WithError(err).Fatal("fatal error running Determined agent")
}
}
39 changes: 6 additions & 33 deletions agent/cmd/determined-agent/root.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,22 @@
package main

import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)

type cobraOpts struct {
logLevel string
noColor bool
}

var version = "dev"
var (
version = "dev"
rootCmd = newRootCmd()
runCmd = newRunCmd()
)

func newRootCmd() *cobra.Command {
opts := cobraOpts{}

cmd := &cobra.Command{
Use: "determined-agent",
Version: version,
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
if err := bindEnv("DET_", cmd); err != nil {
return err
}
level, err := log.ParseLevel(opts.logLevel)
if err != nil {
return err
}
log.SetLevel(level)
log.SetFormatter(&log.TextFormatter{
FullTimestamp: true,
ForceColors: true,
DisableColors: opts.noColor,
})
return nil
},
}

// TODO(DET-8884): Configure log level through agent config file.
cmd.PersistentFlags().StringVarP(&opts.logLevel, "log-level", "l", "trace",
"set the logging level (can be one of: debug, info, warn, error, or fatal)")
cmd.PersistentFlags().BoolVar(&opts.noColor, "no-color", false, "disable colored output")

cmd.AddCommand(newCompletionCmd())
cmd.AddCommand(newVersionCmd())
cmd.AddCommand(newRunCmd())
cmd.AddCommand(newCompletionCmd(), newVersionCmd(), runCmd)

return cmd
}
Loading
Loading