Skip to content

Commit

Permalink
[WIP] Add --gpus support
Browse files Browse the repository at this point in the history
Signed-off-by: Tibor Vass <[email protected]>
  • Loading branch information
Tibor Vass committed Mar 14, 2019
1 parent 2178fea commit 9621d3f
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 2 deletions.
10 changes: 9 additions & 1 deletion cli/command/container/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type containerOptions struct {
labels opts.ListOpts
deviceCgroupRules opts.ListOpts
devices opts.ListOpts
gpus opts.GpuOpts
ulimits *opts.UlimitOpt
sysctls *opts.MapOpts
publish opts.ListOpts
Expand Down Expand Up @@ -166,6 +167,7 @@ func addFlags(flags *pflag.FlagSet) *containerOptions {
flags.VarP(&copts.attach, "attach", "a", "Attach to STDIN, STDOUT or STDERR")
flags.Var(&copts.deviceCgroupRules, "device-cgroup-rule", "Add a rule to the cgroup allowed devices list")
flags.Var(&copts.devices, "device", "Add a host device to the container")
flags.Var(&copts.gpus, "gpus", "Request GPU devices for the container ('all' to pass all GPUs)")
flags.VarP(&copts.env, "env", "e", "Set environment variables")
flags.Var(&copts.envFile, "env-file", "Read in a file of environment variables")
flags.StringVar(&copts.entrypoint, "entrypoint", "", "Overwrite the default ENTRYPOINT of the image")
Expand Down Expand Up @@ -527,6 +529,8 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
}
}

deviceRequests := copts.gpus.Value()

resources := container.Resources{
CgroupParent: copts.cgroupParent,
Memory: copts.memory.Value(),
Expand All @@ -545,7 +549,6 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
CPUQuota: copts.cpuQuota,
CPURealtimePeriod: copts.cpuRealtimePeriod,
CPURealtimeRuntime: copts.cpuRealtimeRuntime,
PidsLimit: copts.pidsLimit,
BlkioWeight: copts.blkioWeight,
BlkioWeightDevice: copts.blkioWeightDevice.GetList(),
BlkioDeviceReadBps: copts.deviceReadBps.GetList(),
Expand All @@ -557,6 +560,11 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
Ulimits: copts.ulimits.GetList(),
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
Devices: deviceMappings,
DeviceRequests: deviceRequests,
}

if copts.pidsLimit != 0 {
resources.PidsLimit = &copts.pidsLimit
}

config := &container.Config{
Expand Down
112 changes: 112 additions & 0 deletions opts/gpus.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package opts

import (
"encoding/csv"
"fmt"
"strconv"
"strings"

"github.com/docker/docker/api/types/container"
"github.com/pkg/errors"
)

// GpuOpts is a Value type for parsing mounts
type GpuOpts struct {
values []container.DeviceRequest
}

func parseCount(s string) (int, error) {
i := -1
var err error
if s != "all" {
i, err = strconv.Atoi(s)
if err != nil {
err = errors.Wrap(err, "count must be an integer")
}
}
return i, err
}

// Set a new mount value
func (o *GpuOpts) Set(value string) error {
csvReader := csv.NewReader(strings.NewReader(value))
fields, err := csvReader.Read()
if err != nil {
return err
}

req := container.DeviceRequest{Options: make(map[string]string), Capabilities: [][]string{{"gpu"}}}

// Set writable as the default
for _, field := range fields {
parts := strings.SplitN(field, "=", 2)
key := strings.ToLower(parts[0])

if len(parts) == 1 {
req.Count, err = parseCount(key)
if err != nil {
return err
}
continue
}

if len(parts) != 2 {
return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
}

value := parts[1]
switch key {
case "driver":
req.Driver = value
case "count":
req.Count, err = parseCount(value)
if err != nil {
return err
}
case "device":
req.DeviceIDs = strings.Split(value, ",")
case "caps":
req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")}
case "options":
r := csv.NewReader(strings.NewReader(value))
optFields, err := r.Read()
if err != nil {
return errors.Wrap(err, "error reading gpu options")
}
req.Options = make(map[string]string)
for _, optField := range optFields {
optParts := strings.SplitN(optField, "=", 2)
key := strings.ToLower(optParts[0])
var value string
if len(optParts) > 1 {
value = optParts[1]
}
req.Options[key] = value
}
default:
return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
}
}

o.values = append(o.values, req)
return nil
}

// Type returns the type of this option
func (o *GpuOpts) Type() string {
return "gpuRequest"
}

// String returns a string repr of this option
func (o *GpuOpts) String() string {
gpus := []string{}
for _, gpu := range o.values {
gpus = append(gpus, fmt.Sprintf("%v", gpu))
}
return strings.Join(gpus, ", ")
}

// Value returns the mounts
func (o *GpuOpts) Value() []container.DeviceRequest {
return o.values
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9621d3f

Please sign in to comment.