Shorten nccl comm timeout and enable flight recorder dumping (#103)
Timeout
-------

Whether during iterative debugging or a long training run, it's convenient
to find out about a failure as soon as possible. The default timeout is far
too long and leads to wasted cluster time and developer frustration.
  
The timeout can be adjusted via the command line or in the .toml config if it
needs to be larger for a particular model.
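
For example, a per-model override might look like this; the [comm] table name
and layout are assumptions inferred from the --comm.* flags added below, and
100 is just an illustrative value:

[comm]
timeout_seconds = 100  # raise the 5-second default for a slow model

The equivalent command-line override would be --comm.timeout_seconds 100.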

Another useful pattern is to set a large timeout for initialization and then
tighten it after the first iteration; we can add that later if desired.

Ideally we would pass the timeout to the device mesh constructor, but that's
not supported yet. We could also change the timeouts of the existing process
groups after creating them, but that's more code and isn't necessary unless we
want to change the timeouts at runtime; a sketch of that approach is below.
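
For reference, here is a minimal sketch of what tightening the timeouts after
the first iteration could look like. It is not part of this change, and
_set_pg_timeout is a private PyTorch helper whose availability and signature
may vary across versions:

from datetime import timedelta

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _set_pg_timeout


def tighten_pg_timeouts(world_mesh, timeout_seconds=5):
    # Make sure every rank has cleared the slow first iteration before
    # lowering the limit, so a straggler doesn't trip the new timeout.
    dist.barrier()
    torch.cuda.synchronize()
    for mesh_dim in range(world_mesh.ndim):
        group = world_mesh.get_group(mesh_dim)
        _set_pg_timeout(timedelta(seconds=timeout_seconds), group)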

Dumps
-----

Dumping on timeout should be a safe default for everyone. It has the side
effect of requiring a dump path, which defaults to ~/pgnccl_dump but can be
overridden via the DUMP_PATH env var. (In this change, the trace file prefix
is set to <job.dump_folder>/comm_trace/rank_ by init_distributed.)

The raw content of the dump is a pickle intended to be consumed through
scripts/tools that are still under development, so it may not be obvious how
to use these files for now. As the tooling matures, we should provide
reference docs and probably print pointers to them in the logs when we
perform a dump.
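
In the meantime, here is a minimal sketch for peeking at a dump. The schema is
not documented here, so it only unpickles the file and prints its top-level
structure:

import pickle
import sys

# Path to one rank's dump, e.g. <job.dump_folder>/comm_trace/rank_0
with open(sys.argv[1], "rb") as f:
    trace = pickle.load(f)

# The payload should be plain Python containers; just show its shape.
if isinstance(trace, dict):
    for key, value in trace.items():
        size = len(value) if hasattr(value, "__len__") else ""
        print(key, type(value).__name__, size)
else:
    print(type(trace).__name__)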


Test plan:
Tested locally by adding a 10-second sleep on rank 0 inside the training loop
and validating that all 8 ranks dumped a trace.
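
A rough sketch of that repro (the exact placement inside the training loop is
an assumption):

import time

import torch.distributed as dist

# Inside the training loop: stall rank 0 for longer than the 5-second default
# comm.timeout_seconds so the remaining ranks time out and dump traces.
if dist.get_rank() == 0:
    time.sleep(10)
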
wconstab authored Mar 15, 2024
1 parent af56ae0 commit 073909b
Showing 3 changed files with 62 additions and 1 deletion.
14 changes: 14 additions & 0 deletions torchtrain/config_manager.py
@@ -263,4 +263,18 @@ def init_args_from_command_line(
default="2", # 2 = checkpoint every other layer
help="['int', 'op'] = selective activation checkpointing options, 'int' for every nth layer, or 'op' for op level ac.",
)

# communications library settings
parser.add_argument(
"--comm.timeout_seconds",
type=int,
default=5,
help="Timeout for async communication operations",
)
parser.add_argument(
"--comm.trace_buf_size",
type=int,
default=20000,
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
)
return parser.parse_args(args_list)
40 changes: 40 additions & 0 deletions torchtrain/parallelisms/__init__.py
@@ -1,9 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property

import torch

from torch.distributed.device_mesh import init_device_mesh
from torchtrain.logging_utils import logger
from torchtrain.parallelisms.parallelize_llama import parallelize_llama
@@ -12,6 +17,41 @@
"llama": parallelize_llama,
}

TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
SKIP_CLEANUP = "3"


def _warn_overwrite_env(env, val):
    if env in os.environ:
        logger.warning(
            f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
        )
    os.environ[env] = val


def init_distributed(job_config):
    # FlightRecorder is incompatible with =1 mode where the watchdog aborts work; we must use =3 (skipcleanup)
    # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
    # This could be done only when flight recorder is enabled, but it's nice to be consistent to avoid subtle
    # behavior differences.
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
    if job_config.comm.trace_buf_size > 0:
        # dump on timeout by default if trace buffer is enabled
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    torch.distributed.init_process_group(
        "nccl", timeout=timedelta(seconds=job_config.comm.timeout_seconds)
    )


@dataclass
class ParallelDims:
9 changes: 8 additions & 1 deletion train.py
@@ -25,7 +25,11 @@
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor
from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
from torchtrain.parallelisms import (
    init_distributed,
    models_parallelize_fns,
    ParallelDims,
)
from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import Color, dist_max, dist_mean

@@ -100,6 +104,9 @@ def main(job_config: JobConfig):
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_distributed(job_config)

    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    model_name = job_config.model.name
