From 219c7c7d9be6d425ced468a612b597c901b2623e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Sun, 29 Sep 2024 03:31:13 +0000
Subject: [PATCH] zbv + zero

---
 colossalai/pipeline/schedule/zero_bubble_pp.py | 18 ++++++++++++------
 .../test_model/test_shard_llama.py             | 13 +++++++++++++
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 5c25c5bfaa80..cb5a47fa89aa 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -500,12 +500,18 @@ def backward_b_step(
         output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
         output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
 
-        optimizer.backward_by_grad(
-            tensor=output_obj_,
-            grad=output_obj_grad_,
-            inputs=input_obj_,
-            retain_graph=True,
-        )
+        try:
+            ctx = optimizer.no_sync()
+        except AttributeError:
+            ctx = model_chunk.no_sync()
+
+        with ctx:
+            optimizer.backward_by_grad(
+                tensor=output_obj_,
+                grad=output_obj_grad_,
+                inputs=input_obj_,
+                retain_graph=True,
+            )
 
         # Format output_obj_grad
         input_obj_grad = {}
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index f3b4db1cefc1..e773da75e221 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -292,6 +292,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_gradient_checkpointing": True,
             "parallel_output": False,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
     ],
 )
 def run_llama_test(test_config):
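
For context, here is a minimal, self-contained sketch of the sync-suppression pattern the patch introduces in `backward_b_step`: try the optimizer's `no_sync()` context first (the ZeRO-wrapped case) and fall back to the model chunk's `no_sync()` when the optimizer does not expose one, so the B-pass computes input gradients without firing a gradient synchronization. The helper name `backward_b_no_sync` is hypothetical, plain `torch.autograd.backward` stands in for ColossalAI's `optimizer.backward_by_grad`, and the extra `nullcontext` fallback is an assumption added so the demo also runs with plain (non-DDP) modules.

```python
# A minimal sketch (not ColossalAI's implementation) of the "no_sync during
# the B-pass" pattern used in the patch above.
from contextlib import nullcontext

import torch


def backward_b_no_sync(optimizer, model_chunk, output_obj, output_obj_grad, input_obj):
    """Backward for the B-pass without triggering gradient synchronization.

    Mirrors the patch: try the optimizer's no_sync() first (ZeRO wrapper),
    then fall back to the model chunk's no_sync() (DDP-style). The final
    nullcontext fallback is an addition so plain nn.Modules also work here.
    """
    try:
        ctx = optimizer.no_sync()
    except AttributeError:
        ctx = getattr(model_chunk, "no_sync", nullcontext)()

    with ctx:
        # Gradients w.r.t. the pipeline inputs only; keep the graph alive so a
        # later W-pass can still compute parameter gradients.
        torch.autograd.backward(
            tensors=output_obj,
            grad_tensors=output_obj_grad,
            inputs=input_obj,
            retain_graph=True,
        )


if __name__ == "__main__":
    # Smoke test with a plain module/optimizer: neither exposes no_sync(),
    # so the helper falls back to nullcontext and still runs.
    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(2, 4, requires_grad=True)
    y = model(x)
    backward_b_no_sync(opt, model, [y], [torch.ones_like(y)], [x])
    print(x.grad.shape)  # torch.Size([2, 4]); model.weight.grad stays None
```

The patch leaves the motivation implicit; a plausible reading is that under ZeRO the B-pass only produces input gradients for the previous stage, so gradient synchronization should be deferred rather than triggered on every `backward_by_grad` call, and the try/except keeps the same code path working whether or not the optimizer is a ZeRO wrapper exposing `no_sync()`.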