Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC-0022: Proposal for adding new restart policy to torch.distributed… #34

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

aivanou
Copy link

@aivanou aivanou commented Dec 15, 2021

No description provided.

@aivanou aivanou marked this pull request as ready for review December 15, 2021 21:39
2. those that cannot checkpoint often (for various reasons)
3. already internally withstand worker failures (e.g. parameter servers, readers)
4. can run to completion with surviving workers without replacement of failed workers

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also worth mentioning that soft restarts means you're spending less time restarting things especially as failures become more common with larger setups

about the new RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT. The survivors use this information
to reinitialize in-process and continue making progress.

<em>Note that even in the presence of soft restarts,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what's the most used backend externally, is it nccl?

log_failure(state, e)
dist.destroy_process_group()
worker_info = torch.distributed.elastic.get_worker_info()
reset(worker_info) # implemented by user
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the user need to implement this? Or can we provide a reference implementation? For example, we know we need to run

world_size = world_size - 1
local_attempts = local_attempts + 1

| Scale up event | Scale down event | Worker failure |
|----------|:-------------:|------:|
| Hard Restart Policy | Terminate workers, perform re-rendezvous | Terminate workers, perform re-rendezvous | Terminate workers, perform re-rendezvous |
| Soft restart Policy | Broadcast PAUSE, start re-rendezvous, Broadcast WORK | Broadcast PAUSE, start re-rendezvous, Broadcast WORK | Terminate local workers, decrease max_restarts, proceed with scale-down event |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I wonder if there's a more descriptive name we could use. Hard restart makes sense.

Hard restart: Restart TE, Restart failed node
Soft restart: Restart TE, Do not restart the failed node

So maybe something like restartnode, donotrestartnode and later we'd add something like updatenode (for things like updating batch sizes)

# worker rank across a single node.
local_rank: int
# worker rank across a single role
role_rank: int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: what's a role?

# number of workers participating in the current role
role_world_size: int
# global number of workers
world_size: int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we prepend global?

1. The surviving TE agents will detect that scale-down event occurred and execute a soft restart policy.
2. TE agents will broadcast a PAUSE message to the local workers and will start a new rendezvous round.
3. Existing TE agents will wait for new TE agents to join rendezvous.
4. If the total number of TE agents at the end of the wait period is between configured `min` and `max`, all TE agents will conclude that the rendezvous round is successful.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the dataclass look like for a policy it sounds like we have

@dataclass 
class Policy:
  # Min nodes
  min_nodes : int

  # Max nodes
  max_nodes : int
  
  # Waiting time in ms
  wait_period : int

  # What else?

Also question about the min and max policy. Soft restarts doesn't mean that the number of nodes decreases monotically right? Suppose you start with 5 nodes, then you're down to 3, you could still end up with 4 at the next rendezvous? EDIT: nvm I think you answered here https://github.com/pytorch/rfcs/pull/34/files#diff-2dd32724169ece310d77ca72c847ba4df4856c0e91fa4df9449ec10ab9b25aa6R256

3. The downside is that it might be an overkill for what we want to do since torch.rpc was designed to actually send over tensors P2P
4. Tt might be tricky to overlap trainer process_groups and agent<>trainer rpc groups
3. File based
1. Robust choice for 90% of the use-cases. Idea is to create a directory per (agent, worker) channel and drop files with the serialized Msg on sends. A file watcher is used on the receiver-side. One can use the dir and file name + structure to organize the channel in a logical way.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you leaning on picking this option?

@ASDen
Copy link

ASDen commented Feb 21, 2022

First, moving that direction is a great step! thanks!

Second, I would highly suggest looking into Elastic Horovod, I think they managed to get few things right that would be big additions here:

  • Graceful node removal = continuing as-is without losing a single batch!
  • Providing very helpful classes for handling full state syncing (including weights, sampler, optimizer, primitive types, etc..), this is vital for simple true graceful node removal
  • The idea of doing memory (not disk) checkpointing of model weights seems interesting (much cheaper), esp. given this should be a fault-tolerant distributed system

Also, would this RFC support dynamically adding new nodes while training ?

@facebook-github-bot
Copy link
Contributor

Hi @aivanou!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants