feat: decouple agent information from workloads starting tasks
This is a step toward adding a generic resource provider interface.
As part of this change, Tasks now receive a single `TaskAssigned` message
when they are assigned rather than one `Assigned` message per container.
They also start containers by sending specs back to the cluster rather
than directly to agents.

DET-3178  #Done.
aaron276h committed Jun 5, 2020
1 parent ed94d86 commit 675d0b7
Showing 12 changed files with 105 additions and 69 deletions.
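
Before the per-file diff, here is a rough, self-contained Go sketch of the message flow this commit introduces. It uses simplified stand-in types rather than the real actor-based implementation; the genuine messages are scheduler.TaskAssigned, scheduler.StartTask, and scheduler.StartTaskOnAgent, while names such as cluster.assign and the slice returned by receiveStartTask are invented purely for illustration.

// sketch.go - a minimal, hypothetical sketch of the new message flow using
// simplified stand-in types; the real types live in master/internal/scheduler.
package main

import "fmt"

// TaskSpec stands in for tasks.TaskSpec: what the workload wants to run.
type TaskSpec struct{ Name string }

// TaskAssigned is sent to the task actor once per scheduling decision; it
// carries only the number of containers, no agent details.
type TaskAssigned struct{ NumContainers int }

// StartTask is the task actor's reply to the cluster, carrying the spec.
type StartTask struct{ Spec TaskSpec }

// StartTaskOnAgent is what the cluster then sends to each assigned agent.
type StartTaskOnAgent struct {
    AgentID string
    Spec    TaskSpec
}

// cluster keeps the per-container assignments private: task actors never see agents.
type cluster struct{ assignments []string }

// assign picks agents for n containers and tells the task only the count.
func (c *cluster) assign(n int) TaskAssigned {
    for i := 0; i < n; i++ {
        c.assignments = append(c.assignments, fmt.Sprintf("agent-%d", i))
    }
    return TaskAssigned{NumContainers: n}
}

// receiveStartTask fans the task's spec out to the agents chosen earlier.
func (c *cluster) receiveStartTask(msg StartTask) []StartTaskOnAgent {
    out := make([]StartTaskOnAgent, 0, len(c.assignments))
    for _, agentID := range c.assignments {
        out = append(out, StartTaskOnAgent{AgentID: agentID, Spec: msg.Spec})
    }
    return out
}

func main() {
    c := &cluster{}

    // 1. The cluster assigns the task and reports only the container count.
    assigned := c.assign(2)

    // 2. The task actor answers with a spec; it does not address agents directly.
    start := StartTask{Spec: TaskSpec{Name: fmt.Sprintf("trial, %d containers", assigned.NumContainers)}}

    // 3. The cluster converts the spec into one StartTaskOnAgent per assignment.
    for _, m := range c.receiveStartTask(start) {
        fmt.Printf("to %s: start %q\n", m.AgentID, m.Spec.Name)
    }
}

The point of the indirection is that only the cluster knows which agents hold the assignments; the task actor learns only how many containers it received, as the diff below shows.
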
2 changes: 1 addition & 1 deletion master/internal/agent/agent.go
@@ -64,7 +64,7 @@ func (a *agent) Receive(ctx *actor.Context) error {
         }
     case aproto.SignalContainer:
         ctx.Ask(a.socket, ws.WriteMessage{Message: aproto.AgentMessage{SignalContainer: &msg}})
-    case scheduler.StartTask:
+    case scheduler.StartTaskOnAgent:
         start := ws.WriteMessage{Message: aproto.AgentMessage{StartContainer: &msg.StartContainer}}
         ctx.Ask(a.socket, start)
         ctx.Tell(a.slots, msg.StartContainer)

16 changes: 9 additions & 7 deletions master/internal/checkpoint_gc.go
@@ -32,7 +32,7 @@ func (t *checkpointGCTask) Receive(ctx *actor.Context) error {
         },
     })

-    case scheduler.Assigned:
+    case scheduler.TaskAssigned:
         config := t.experiment.Config.CheckpointStorage

         checkpoints, err := t.db.ExperimentCheckpointsToGCRaw(t.experiment.ID,
@@ -43,12 +43,14 @@ func (t *checkpointGCTask) Receive(ctx *actor.Context) error {

         ctx.Log().Info("starting checkpoint garbage collection")

-        msg.StartTask(tasks.TaskSpec{
-            GCCheckpoints: &tasks.GCCheckpoints{
-                AgentUserGroup:   t.agentUserGroup,
-                ExperimentID:     t.experiment.ID,
-                ExperimentConfig: t.experiment.Config,
-                ToDelete:         checkpoints,
+        ctx.Tell(t.cluster, scheduler.StartTask{
+            Spec: tasks.TaskSpec{
+                GCCheckpoints: &tasks.GCCheckpoints{
+                    AgentUserGroup:   t.agentUserGroup,
+                    ExperimentID:     t.experiment.ID,
+                    ExperimentConfig: t.experiment.Config,
+                    ToDelete:         checkpoints,
+                },
             },
         })

18 changes: 10 additions & 8 deletions master/internal/command/command.go
@@ -111,15 +111,17 @@ func (c *command) Receive(ctx *actor.Context) error {
             c.exit(ctx, exitStatus)
         }

-    case scheduler.Assigned:
-        msg.StartTask(tasks.TaskSpec{
-            StartCommand: &tasks.StartCommand{
-                AgentUserGroup:  c.agentUserGroup,
-                Config:          c.config,
-                UserFiles:       c.userFiles,
-                AdditionalFiles: c.additionalFiles,
+    case scheduler.TaskAssigned:
+        ctx.Tell(c.cluster, scheduler.StartTask{
+            Spec: tasks.TaskSpec{
+                StartCommand: &tasks.StartCommand{
+                    AgentUserGroup:  c.agentUserGroup,
+                    Config:          c.config,
+                    UserFiles:       c.userFiles,
+                    AdditionalFiles: c.additionalFiles,
+                },
+                HarnessPath: c.harnessPath,
             },
-            HarnessPath: c.harnessPath,
         })
         ctx.Tell(c.eventStream, event{Snapshot: newSummary(c), AssignedEvent: &msg})

2 changes: 1 addition & 1 deletion master/internal/command/events.go
@@ -28,7 +28,7 @@ type event struct {

     ScheduledEvent *scheduler.TaskID `json:"scheduled_event"`
     // AssignedEvent is triggered when the parent was assigned to an agent.
-    AssignedEvent *scheduler.Assigned `json:"assigned_event"`
+    AssignedEvent *scheduler.TaskAssigned `json:"assigned_event"`
     // ContainerStartedEvent is triggered when the container started on an agent.
     ContainerStartedEvent *scheduler.ContainerStarted `json:"container_started_event"`
     // ServiceReadyEvent is triggered when the service running in the container is ready to serve.

4 changes: 2 additions & 2 deletions master/internal/scheduler/agent.go
@@ -42,8 +42,8 @@ type (

 // Incoming agent actor messages; agent actors must accept these messages.
 type (
-    // StartTask notifies the agent to start the task with the provided task spec.
-    StartTask struct {
+    // StartTaskOnAgent notifies the agent to start the task with the provided task spec.
+    StartTaskOnAgent struct {
         Task *actor.Ref
         agent.StartContainer
     }

19 changes: 4 additions & 15 deletions master/internal/scheduler/assignment.go
@@ -8,10 +8,9 @@ import (
     image "github.com/determined-ai/determined/master/pkg/tasks"
 )

-// Assigned is a message that tells the task actor that it has been assigned to the provided
-// agent.
+// assignment contains information for tasks that have been assigned but not yet started.
 // TODO: Expose assignment information (e.g. device type, num slots) to task actors.
-type Assigned struct {
+type assignment struct {
     task      *Task
     container *container
     agent     *agentState
@@ -23,14 +22,14 @@
 }

 // StartTask notifies the agent that the task is ready to start with the provided task spec.
-func (a *Assigned) StartTask(spec image.TaskSpec) TaskSummary {
+func (a *assignment) StartTask(spec image.TaskSpec) TaskSummary {
     handler := a.agent.handler
     spec.ClusterID = a.clusterID
     spec.TaskID = string(a.task.ID)
     spec.HarnessPath = a.harnessPath
     spec.TaskContainerDefaults = a.taskContainerDefaults
     spec.Devices = a.devices
-    handler.System().Tell(handler, StartTask{
+    handler.System().Tell(handler, StartTaskOnAgent{
         Task: a.task.handler,
         StartContainer: agent.StartContainer{
             Container: cproto.Container{
@@ -44,13 +43,3 @@ func (a *Assigned) StartTask(spec image.TaskSpec) TaskSummary {
     })
     return newTaskSummary(a.task)
 }
-
-// IsLeader returns true if this assignment corresponds to the leader container of the task.
-func (a *Assigned) IsLeader() bool {
-    return a.container.IsLeader()
-}
-
-// NumContainers returns the number of containers to which the task has been assigned.
-func (a *Assigned) NumContainers() int {
-    return a.numContainers
-}

41 changes: 37 additions & 4 deletions master/internal/scheduler/cluster.go
@@ -43,6 +43,8 @@ type Cluster struct {
     tasksByID          map[TaskID]*Task
     tasksByContainerID map[ContainerID]*Task

+    assigmentByHandler map[*actor.Ref][]assignment
+
     provisioner     *actor.Ref
     provisionerView *FilterableView

@@ -78,6 +80,8 @@ func NewCluster(
         tasksByID:          make(map[TaskID]*Task),
         tasksByContainerID: make(map[ContainerID]*Task),

+        assigmentByHandler: make(map[*actor.Ref][]assignment),
+
         proxy:           proxy,
         provisioner:     provisioner,
         provisionerView: newProvisionerView(provisionerSlotsPerInstance),
@@ -95,7 +99,7 @@
     agent.containers[container.id] = container
     task.containers[container.id] = container
     c.tasksByContainerID[container.id] = task
-    assigned := Assigned{
+    c.assigmentByHandler[task.handler] = append(c.assigmentByHandler[task.handler], assignment{
         task:      task,
         agent:     agent,
         container: container,
@@ -104,8 +108,7 @@
         devices:               agent.assignFreeDevices(slots, container.id),
         harnessPath:           c.harnessPath,
         taskContainerDefaults: c.taskContainerDefaults,
-    }
-    task.handler.System().Tell(task.handler, assigned)
+    })
 }

 // assignTask allocates cluster data structures and sends the appropriate actor
@@ -114,10 +117,19 @@
 func (c *Cluster) assignTask(task *Task) bool {
     fits := findFits(task, c.agents, c.fittingMethod)

+    if len(fits) == 0 {
+        return false
+    }
+
+    c.assigmentByHandler[task.handler] = make([]assignment, 0, len(fits))
+
     for _, fit := range fits {
         c.assignContainer(task, fit.Agent, fit.Slots, len(fits))
     }
-    return len(fits) > 0
+
+    task.handler.System().Tell(task.handler, TaskAssigned{numContainers: len(fits)})
+
+    return true
 }

 // terminateTask sends the appropriate actor messages to terminate a task and
@@ -246,6 +258,9 @@ func (c *Cluster) Receive(ctx *actor.Context) error {
             c.receiveContainerTerminated(ctx, cid, *msg.ContainerStopped, false)
         }

+    case StartTask:
+        c.receiveStartTask(ctx, msg)
+
     case taskStopped:
         c.receiveTaskStopped(ctx, msg)

@@ -341,6 +356,24 @@ func (c *Cluster) receiveAddTask(ctx *actor.Context, msg AddTask) {
     }
 }

+func (c *Cluster) receiveStartTask(ctx *actor.Context, msg StartTask) {
+    task := c.tasksByHandler[ctx.Sender()]
+    if task == nil {
+        ctx.Log().WithField("address", ctx.Sender().Address()).Errorf("unknown task trying to start")
+        return
+    }
+
+    assignments := c.assigmentByHandler[ctx.Sender()]
+    if len(assignments) == 0 {
+        ctx.Log().WithField("name", task.name).Error("task is trying to start without any assignments")
+        return
+    }
+
+    for _, a := range assignments {
+        a.StartTask(msg.Spec)
+    }
+}
+
 func (c *Cluster) receiveContainerStartedOnAgent(ctx *actor.Context, msg ContainerStartedOnAgent) {
     task := c.tasksByContainerID[msg.ContainerID]
     if task == nil {

6 changes: 3 additions & 3 deletions master/internal/scheduler/cluster_test.go
@@ -20,7 +20,7 @@ var errMock = errors.New("mock error")
 type mockActor struct {
     system             *actor.System
     cluster            *actor.Ref
-    onAssigned         func(Assigned) error
+    onAssigned         func(TaskAssigned) error
     onContainerStarted func(ContainerStarted) error
     onTaskTerminated   func(TaskTerminated) error
 }
@@ -48,14 +48,14 @@ func (h *mockActor) Receive(ctx *actor.Context) error {
     case ThrowPanic:
         panic(errMock)

-    case Assigned:
+    case TaskAssigned:
         if h.onAssigned != nil {
             return h.onAssigned(msg)
         }

         h.system.Tell(h.cluster, ContainerStateChanged{
             Container: cproto.Container{
-                ID:    cproto.ID(msg.container.id),
+                ID:    cproto.ID("random-container-name"),
                 State: cproto.Running,
             },
             ContainerStarted: &agent.ContainerStarted{

11 changes: 11 additions & 0 deletions master/internal/scheduler/resource_provider.go
@@ -56,3 +56,14 @@ type ContainerStateChanged struct {
     ContainerStarted *agent.ContainerStarted
     ContainerStopped *agent.ContainerStopped
 }
+
+// TaskAssigned is a message that tells the task actor that it has been assigned to run
+// with a specified number of containers.
+type TaskAssigned struct {
+    numContainers int
+}
+
+// NumContainers returns the number of containers to which the task has been assigned.
+func (t *TaskAssigned) NumContainers() int {
+    return t.numContainers
+}

8 changes: 4 additions & 4 deletions master/internal/scheduler/scheduler_test.go
@@ -79,9 +79,9 @@ func newMockTask(
 }

 func (t *mockTask) Receive(ctx *actor.Context) error {
-    switch msg := ctx.Message().(type) {
-    case Assigned:
-        msg.StartTask(tasks.TaskSpec{})
+    switch ctx.Message().(type) {
+    case TaskAssigned:
+        ctx.Respond(StartTask{Spec: tasks.TaskSpec{}})
     case getSlots:
         ctx.Respond(t.slotsNeeded)
     case getGroup:
@@ -115,7 +115,7 @@ func newMockAgent(

 func (m mockAgent) Receive(ctx *actor.Context) error {
     switch msg := ctx.Message().(type) {
-    case StartTask:
+    case StartTaskOnAgent:
         if ctx.ExpectingResponse() {
             ctx.Respond(newTask(&Task{
                 handler: msg.Task,

5 changes: 5 additions & 0 deletions master/internal/scheduler/task.go
@@ -6,6 +6,7 @@ import (
     "github.com/google/uuid"

     "github.com/determined-ai/determined/master/pkg/actor"
+    "github.com/determined-ai/determined/master/pkg/tasks"
 )

 // Task-related cluster level messages.
@@ -20,6 +21,10 @@ type (
         Label               string
         FittingRequirements FittingRequirements
     }
+    // StartTask signals that a scheduled task should be launched.
+    StartTask struct {
+        Spec tasks.TaskSpec
+    }
     // taskStopped notifies that the task actor is stopped.
     taskStopped struct {
         Ref *actor.Ref

42 changes: 18 additions & 24 deletions master/internal/trial.go
@@ -16,7 +16,6 @@ import (
     "github.com/determined-ai/determined/master/pkg/actor/api"
     "github.com/determined-ai/determined/master/pkg/agent"
     "github.com/determined-ai/determined/master/pkg/archive"
-    "github.com/determined-ai/determined/master/pkg/check"
     "github.com/determined-ai/determined/master/pkg/container"
     "github.com/determined-ai/determined/master/pkg/etc"
     "github.com/determined-ai/determined/master/pkg/model"
@@ -321,7 +320,7 @@ func (t *trial) runningReceive(ctx *actor.Context) error {
     case actor.ChildFailed:
         ctx.Tell(t.cluster, scheduler.TerminateTask{TaskID: t.task.ID, Forcible: true})

-    case scheduler.Assigned:
+    case scheduler.TaskAssigned:
         if err := t.processAssigned(ctx, msg); err != nil {
             return err
         }
@@ -371,7 +370,7 @@ func (t *trial) processID(ctx *actor.Context, id int) {
     ctx.AddLabel("trial-id", id)
 }

-func (t *trial) processAssigned(ctx *actor.Context, msg scheduler.Assigned) error {
+func (t *trial) processAssigned(ctx *actor.Context, msg scheduler.TaskAssigned) error {
     if len(t.privateKey) == 0 {
         generatedKeys, err := ssh.GenerateKey(nil)
         if err != nil {
@@ -403,17 +402,10 @@
         return errors.Wrap(err, "error getting workload from sequencer")
     }

-    if t.numContainers == 0 {
-        t.numContainers = msg.NumContainers()
-    } else {
-        check.Panic(check.Equal(t.numContainers, msg.NumContainers(),
-            "got inconsistent numbers of containers"))
-    }
+    t.numContainers = msg.NumContainers()

-    if msg.IsLeader() {
-        if err = saveWorkload(t.db, w); err != nil {
-            ctx.Log().WithError(err).Error("failed to save workload to the database")
-        }
+    if err = saveWorkload(t.db, w); err != nil {
+        ctx.Log().WithError(err).Error("failed to save workload to the database")
     }

     ctx.Log().Infof("starting trial container: %v", w)
@@ -452,17 +444,19 @@
         ),
     }

-    msg.StartTask(tasks.TaskSpec{
-        StartContainer: &tasks.StartContainer{
-            ExperimentConfig:    t.experiment.Config,
-            ModelDefinition:     t.modelDefinition,
-            HParams:             t.create.Hparams,
-            TrialSeed:           t.create.TrialSeed,
-            LatestCheckpoint:    t.sequencer.LatestCheckpoint(),
-            InitialWorkload:     w,
-            WorkloadManagerType: t.sequencer.WorkloadManagerType(),
-            AdditionalFiles:     additionalFiles,
-            AgentUserGroup:      t.agentUserGroup,
+    ctx.Tell(t.cluster, scheduler.StartTask{
+        Spec: tasks.TaskSpec{
+            StartContainer: &tasks.StartContainer{
+                ExperimentConfig:    t.experiment.Config,
+                ModelDefinition:     t.modelDefinition,
+                HParams:             t.create.Hparams,
+                TrialSeed:           t.create.TrialSeed,
+                LatestCheckpoint:    t.sequencer.LatestCheckpoint(),
+                InitialWorkload:     w,
+                WorkloadManagerType: t.sequencer.WorkloadManagerType(),
+                AdditionalFiles:     additionalFiles,
+                AgentUserGroup:      t.agentUserGroup,
+            },
         },
     })

