Cosmetic changes to train.py
ghstack-source-id: 75fe8612072535d205993416cfeca911099ea40d
Pull Request resolved: #398
kwen2501 committed Jun 13, 2024
1 parent e17e3b8 commit 53e64a9
Showing 1 changed file with 23 additions and 25 deletions.
train.py
@@ -129,7 +129,7 @@ def main(job_config: JobConfig):
     gc.disable()
     gc.collect(1)
 
-    # init world mesh
+    # init distributed
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
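The degrees handed to ParallelDims above must multiply out to WORLD_SIZE. A minimal sketch of that bookkeeping, assuming invented field names and the convention that -1 means "infer from what's left over" (this is not torchtitan's actual ParallelDims):

    from dataclasses import dataclass

    @dataclass
    class ParallelDimsSketch:
        dp: int   # data-parallel degree; -1 = infer from world size
        tp: int   # tensor-parallel degree
        pp: int   # pipeline-parallel degree
        world_size: int

        def __post_init__(self):
            fixed = self.tp * self.pp
            if self.dp == -1:
                # use whatever ranks remain for data parallelism
                self.dp = self.world_size // fixed
            assert self.dp * fixed == self.world_size, \
                "dp * tp * pp must equal WORLD_SIZE"

    dims = ParallelDimsSketch(dp=-1, tp=2, pp=2, world_size=8)
    assert dims.dp == 2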
@@ -142,15 +142,8 @@ def main(job_config: JobConfig):
     torch.cuda.set_device(device)
     init_distributed(job_config)
 
+    # build meshes
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
-    model_name = job_config.model.name
-
-    # build tokenizer
-    tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
-
-    # build dataloader
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"]
         dp_degree = dp_mesh.size()
@@ -161,6 +154,13 @@ def main(job_config: JobConfig):
     if parallel_dims.pp_enabled:
         pp_mesh = world_mesh["pp"]
 
+    model_name = job_config.model.name
+
+    # build tokenizer
+    tokenizer_type = model_name_to_tokenizer[model_name]
+    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+
+    # build dataloader
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
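The two hunks above reorder setup so the world mesh and its per-dimension submeshes are built first, with the tokenizer and dataloader following. A minimal standalone sketch of the same mesh-then-slice pattern, assuming a 2-D dp x pp layout on 4 GPUs (init_device_mesh is stock PyTorch; none of this is torchtitan code):

    # Run with: torchrun --nproc-per-node=4 mesh_demo.py
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # 2 data-parallel replicas x 2 pipeline stages = 4 ranks
    world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "pp"))

    dp_mesh = world_mesh["dp"]           # 1-D submesh along the dp dimension
    dp_degree = dp_mesh.size()           # 2
    dp_rank = dp_mesh.get_local_rank()   # this rank's coordinate along dp

    pp_mesh = world_mesh["pp"]           # 1-D submesh along the pp dimension

    dist.destroy_process_group()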
@@ -191,10 +191,8 @@ def loss_fn(pred, labels):
     model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
 
+    logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
     with torch.device("meta"):
-        logger.info(
-            f"Building {model_name} {job_config.model.flavor} with {model_config}"
-        )
         model = model_cls.from_model_args(model_config)
 
     # apply fp8 linear module swap
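The hunk above hoists the logging call so only model construction sits under torch.device("meta"), where no real storage is allocated until parallelisms are applied. A minimal sketch of that meta-device pattern with a stand-in model (to_empty and the initializers are stock PyTorch; the model and its sizes are made up):

    import torch
    import torch.nn as nn

    # build the module skeleton with no real memory behind the parameters
    with torch.device("meta"):
        model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 8))

    # parameters exist but hold no data yet
    assert all(p.is_meta for p in model.parameters())

    # allocate real (uninitialized) storage on the target device...
    model = model.to_empty(device="cuda")

    # ...then run the usual initializers, analogous to model.init_weights()
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.zeros_(m.bias)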
@@ -236,7 +234,6 @@ def loss_fn(pred, labels):
     else:
         # If PP is enabled, we can't rely on init_weights, because some layers are missing.
         # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
-
         # allocate sharded model on GPU and initialize weights via DTensor
         model.init_weights()
 
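The comment in this branch refers to initializing a model whose parameters are already sharded: after parallelization, each parameter is a DTensor, and init_weights runs against those shards. A standalone sketch of what a sharded tensor looks like, assuming a recent PyTorch with the public torch.distributed.tensor API and a single-rank gloo group so it runs without torchrun:

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    # single-rank bootstrap so the sketch runs in one process
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    mesh = init_device_mesh("cpu", (1,))
    weight = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])  # shard rows over the mesh

    print(type(weight))              # DTensor
    print(weight.to_local().shape)   # this rank's local shard; (8, 8) with world_size=1

    dist.destroy_process_group()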
Expand All @@ -249,21 +246,21 @@ def loss_fn(pred, labels):

# build optimizer after applying parallelisms to the model
optimizer = build_optimizer(model, job_config)
scheduler = get_lr_scheduler(optimizer, job_config)
lr_scheduler = get_lr_scheduler(optimizer, job_config)

metric_logger = build_metric_logger(
job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims)
)

train_state = TrainState()

# train loop
model.train()

# load initial checkpoint
checkpoint = CheckpointManager(
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
lr_scheduler=lr_scheduler,
dataloader=data_loader,
states={"train_state": train_state},
job_config=job_config,
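The rename above makes the local variable match the lr_scheduler kwarg that CheckpointManager already uses. For the schedule itself, a minimal sketch of a warmup-then-linear-decay scheduler built from stock LambdaLR (this is not torchtitan's get_lr_scheduler; the step counts stand in for values that would come from JobConfig):

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    warmup_steps, total_steps = 200, 10_000

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # linear warmup
        # linear decay from 1.0 down to 0 over the remaining steps
        return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    lr_scheduler = LambdaLR(optimizer, lr_lambda)

    # one scheduler step per training step, after the optimizer step
    optimizer.step()
    lr_scheduler.step()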
@@ -298,19 +295,20 @@ def loss_fn(pred, labels):
 
     data_iterator = iter(data_loader)
 
+    checkpoint.reset()
+
+    # variables used to keep info for metrics logging
+    losses_since_last_log: List[float] = []
+    ntokens_since_last_log = 0
+    data_loading_times: List[float] = []
+    time_last_log = timer()
+    gpu_memory_monitor.reset_peak_stats()
+
+    # train loop
     logger.info(f"Training starts at step {train_state.step + 1}")
     with maybe_enable_profiling(
         job_config, global_step=train_state.step
     ) as torch_profiler:
-        checkpoint.reset()
-
-        # variables used to keep info for metrics logging
-        losses_since_last_log: List[float] = []
-        ntokens_since_last_log = 0
-        data_loading_times: List[float] = []
-        time_last_log = timer()
-        gpu_memory_monitor.reset_peak_stats()
-
         while train_state.step < job_config.training.steps:
             train_state.step += 1
             if train_state.step > 1 and train_state.step % _gc_freq == 0:
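The last hunk moves checkpoint.reset() and the metrics bookkeeping out of the profiling context, so the with block wraps only the loop being measured. A minimal sketch of that shape using the stock torch.profiler (maybe_enable_profiling is torchtitan's own wrapper; the schedule numbers here are invented):

    import torch
    from torch.profiler import ProfilerActivity, profile, schedule

    model = torch.nn.Linear(64, 64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # setup and counters live outside the profiled region
    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    ) as torch_profiler:
        for step in range(8):
            loss = model(torch.randn(32, 64)).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            torch_profiler.step()  # advance the profiler schedule once per iteration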
