Commit 219c7c7

zbv + zero
flybird11111 committed Sep 29, 2024
1 parent 87e742d commit 219c7c7
Showing 2 changed files with 25 additions and 6 deletions.
18 changes: 12 additions & 6 deletions colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -500,12 +500,18 @@ def backward_b_step(
         output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
         output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

-        optimizer.backward_by_grad(
-            tensor=output_obj_,
-            grad=output_obj_grad_,
-            inputs=input_obj_,
-            retain_graph=True,
-        )
+        try:
+            ctx = optimizer.no_sync()
+        except AttributeError:
+            ctx = model_chunk.no_sync()
+
+        with ctx:
+            optimizer.backward_by_grad(
+                tensor=output_obj_,
+                grad=output_obj_grad_,
+                inputs=input_obj_,
+                retain_graph=True,
+            )

         # Format output_obj_grad
         input_obj_grad = {}
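The hunk above wraps the B-pass backward call in a no_sync() context, so per-microbatch gradient synchronization (e.g. a ZeRO-1 reduce-scatter or a DDP all-reduce) is deferred instead of being triggered on every backward of the zero-bubble schedule. Below is a minimal sketch of that fallback pattern, assuming only that either the optimizer or the model chunk exposes a DDP/ZeRO-style no_sync() context manager; the nullcontext fallback and the helper name are hypothetical additions for illustration, not part of the diff.

import contextlib

def grad_sync_suppressed(optimizer, model_chunk):
    """Return a context manager under which gradient synchronization is skipped."""
    try:
        return optimizer.no_sync()       # e.g. a zero-stage optimizer wrapper
    except AttributeError:
        pass
    try:
        return model_chunk.no_sync()     # e.g. a DDP-wrapped model chunk
    except AttributeError:
        return contextlib.nullcontext()  # nothing to defer (hypothetical fallback)

# Usage inside a pipeline backward step (sketch):
# with grad_sync_suppressed(optimizer, model_chunk):
#     optimizer.backward_by_grad(tensor=outputs, grad=grads, inputs=inputs, retain_graph=True)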
13 changes: 13 additions & 0 deletions tests/test_shardformer/test_model/test_shard_llama.py
@@ -292,6 +292,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"enable_gradient_checkpointing": True,
"parallel_output": False,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "zbv",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"parallel_output": False,
},
],
)
def run_llama_test(test_config):
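The new config entry exercises the zbv pipeline schedule together with ZeRO stage 1 in the LLaMA shardformer test. As a rough illustration of how such a config is consumed, here is a minimal sketch that assumes the test harness forwards these keys more or less directly to HybridParallelPlugin; only keys that are unambiguously plugin arguments are shown, and the helper name is hypothetical.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

def build_zbv_zero1_booster() -> Booster:
    """Hypothetical helper; assumes it runs inside a process group set up via colossalai.launch."""
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        pp_style="zbv",       # zero-bubble "V" interleaved pipeline schedule
        num_model_chunks=2,
        num_microbatches=4,
        precision="fp16",
        zero_stage=1,         # ZeRO-1, the combination this commit enables in the schedule
        initial_scale=1,
    )
    return Booster(plugin=plugin)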
