diff --git a/cli/command/container/opts.go b/cli/command/container/opts.go index 8fe4ded9c89f..38971b9e8565 100644 --- a/cli/command/container/opts.go +++ b/cli/command/container/opts.go @@ -46,6 +46,7 @@ type containerOptions struct { labels opts.ListOpts deviceCgroupRules opts.ListOpts devices opts.ListOpts + gpus opts.GpuOpts ulimits *opts.UlimitOpt sysctls *opts.MapOpts publish opts.ListOpts @@ -166,6 +167,8 @@ func addFlags(flags *pflag.FlagSet) *containerOptions { flags.VarP(&copts.attach, "attach", "a", "Attach to STDIN, STDOUT or STDERR") flags.Var(&copts.deviceCgroupRules, "device-cgroup-rule", "Add a rule to the cgroup allowed devices list") flags.Var(&copts.devices, "device", "Add a host device to the container") + flags.Var(&copts.gpus, "gpus", "GPU devices to add to the container ('all' to pass all GPUs)") + flags.SetAnnotation("gpus", "version", []string{"1.40"}) flags.VarP(&copts.env, "env", "e", "Set environment variables") flags.Var(&copts.envFile, "env-file", "Read in a file of environment variables") flags.StringVar(&copts.entrypoint, "entrypoint", "", "Overwrite the default ENTRYPOINT of the image") @@ -527,6 +530,8 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con } } + deviceRequests := copts.gpus.Value() + resources := container.Resources{ CgroupParent: copts.cgroupParent, Memory: copts.memory.Value(), @@ -557,6 +562,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con Ulimits: copts.ulimits.GetList(), DeviceCgroupRules: copts.deviceCgroupRules.GetAll(), Devices: deviceMappings, + DeviceRequests: deviceRequests, } config := &container.Config{ diff --git a/opts/gpus.go b/opts/gpus.go new file mode 100644 index 000000000000..95ad74aaa7b7 --- /dev/null +++ b/opts/gpus.go @@ -0,0 +1,104 @@ +package opts + +import ( + "encoding/csv" + "fmt" + "strconv" + "strings" + + "github.com/docker/docker/api/types/container" + "github.com/pkg/errors" +) + +// GpuOpts is a Value type for parsing mounts +type GpuOpts struct { + values []container.DeviceRequest +} + +func parseCount(s string) (int, error) { + if s == "all" { + return -1, nil + } + i, err := strconv.Atoi(s) + return i, errors.Wrap(err, "count must be an integer") +} + +// Set a new mount value +func (o *GpuOpts) Set(value string) error { + csvReader := csv.NewReader(strings.NewReader(value)) + fields, err := csvReader.Read() + if err != nil { + return err + } + + req := container.DeviceRequest{Options: make(map[string]string), Capabilities: [][]string{{"gpu"}}} + + // Set writable as the default + for _, field := range fields { + parts := strings.SplitN(field, "=", 2) + key := parts[0] + + if len(parts) == 1 { + req.Count, err = parseCount(key) + if err != nil { + return err + } + continue + } + + value := parts[1] + switch key { + case "driver": + req.Driver = value + case "count": + req.Count, err = parseCount(value) + if err != nil { + return err + } + case "device": + req.DeviceIDs = strings.Split(value, ",") + case "capabilities": + req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")} + case "options": + r := csv.NewReader(strings.NewReader(value)) + optFields, err := r.Read() + if err != nil { + return errors.Wrap(err, "failed to read gpu options") + } + req.Options = make(map[string]string) + for _, optField := range optFields { + optParts := strings.SplitN(optField, "=", 2) + key := optParts[0] + var value string + if len(optParts) > 1 { + value = optParts[1] + } + req.Options[key] = value + } + default: + return fmt.Errorf("unexpected key '%s' in '%s'", key, field) + } + } + + o.values = append(o.values, req) + return nil +} + +// Type returns the type of this option +func (o *GpuOpts) Type() string { + return "gpu-request" +} + +// String returns a string repr of this option +func (o *GpuOpts) String() string { + gpus := []string{} + for _, gpu := range o.values { + gpus = append(gpus, fmt.Sprintf("%v", gpu)) + } + return strings.Join(gpus, ", ") +} + +// Value returns the mounts +func (o *GpuOpts) Value() []container.DeviceRequest { + return o.values +}