From 53e64a938eab4886a49c6ae3caed055915c58352 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 13 Jun 2024 13:53:18 -0700
Subject: [PATCH] Cosmetic changes to train.py

ghstack-source-id: 75fe8612072535d205993416cfeca911099ea40d
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/398
---
 train.py | 48 +++++++++++++++++++++++------------------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/train.py b/train.py
index adbd975f..3d07199b 100644
--- a/train.py
+++ b/train.py
@@ -129,7 +129,7 @@ def main(job_config: JobConfig):
     gc.disable()
     gc.collect(1)
 
-    # init world mesh
+    # init distributed
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
@@ -142,15 +142,8 @@ def main(job_config: JobConfig):
     torch.cuda.set_device(device)
     init_distributed(job_config)
 
+    # build meshes
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
-    model_name = job_config.model.name
-
-    # build tokenizer
-    tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
-
-    # build dataloader
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"]
         dp_degree = dp_mesh.size()
@@ -161,6 +154,13 @@ def main(job_config: JobConfig):
     if parallel_dims.pp_enabled:
         pp_mesh = world_mesh["pp"]
 
+    model_name = job_config.model.name
+
+    # build tokenizer
+    tokenizer_type = model_name_to_tokenizer[model_name]
+    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+
+    # build dataloader
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -191,10 +191,8 @@ def loss_fn(pred, labels):
     model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
 
+    logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
     with torch.device("meta"):
-        logger.info(
-            f"Building {model_name} {job_config.model.flavor} with {model_config}"
-        )
         model = model_cls.from_model_args(model_config)
 
     # apply fp8 linear module swap
@@ -236,7 +234,6 @@ def loss_fn(pred, labels):
     else:
         # If PP is enabled, we can't rely on init_weights, because some layers are missing.
         # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
-        # allocate sharded model on GPU and initialize weights via DTensor
         model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
@@ -249,7 +246,7 @@ def loss_fn(pred, labels):
 
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
-    scheduler = get_lr_scheduler(optimizer, job_config)
+    lr_scheduler = get_lr_scheduler(optimizer, job_config)
 
     metric_logger = build_metric_logger(
         job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims)
@@ -257,13 +254,13 @@ def loss_fn(pred, labels):
 
     train_state = TrainState()
 
-    # train loop
     model.train()
 
+    # load initial checkpoint
     checkpoint = CheckpointManager(
         model=model,
         optimizer=optimizer,
-        lr_scheduler=scheduler,
+        lr_scheduler=lr_scheduler,
         dataloader=data_loader,
         states={"train_state": train_state},
         job_config=job_config,
@@ -298,19 +295,20 @@ def loss_fn(pred, labels):
 
     data_iterator = iter(data_loader)
 
+    checkpoint.reset()
+
+    # variables used to keep info for metrics logging
+    losses_since_last_log: List[float] = []
+    ntokens_since_last_log = 0
+    data_loading_times: List[float] = []
+    time_last_log = timer()
+    gpu_memory_monitor.reset_peak_stats()
+
+    # train loop
     logger.info(f"Training starts at step {train_state.step + 1}")
     with maybe_enable_profiling(
         job_config, global_step=train_state.step
     ) as torch_profiler:
-        checkpoint.reset()
-
-        # variables used to keep info for metrics logging
-        losses_since_last_log: List[float] = []
-        ntokens_since_last_log = 0
-        data_loading_times: List[float] = []
-        time_last_log = timer()
-        gpu_memory_monitor.reset_peak_stats()
-
         while train_state.step < job_config.training.steps:
             train_state.step += 1
             if train_state.step > 1 and train_state.step % _gc_freq == 0: