[FIX DDP] fix ddp (PaddlePaddle#8549)
* enable trainer tests.
ZHUI committed Jun 7, 2024
1 parent 02fd721 commit 257a5bc
Showing 5 changed files with 88 additions and 86 deletions.
12 changes: 2 additions & 10 deletions paddlenlp/trainer/trainer.py
@@ -1771,16 +1771,8 @@ def _wrap_model(self, model, training=True):
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1

         # Multi-gpu training
-        if (
-            self.args.world_size > 1
-            and not self.args.use_hybrid_parallel
-            or not (
-                in_pipeline_parallel_mode
-                or in_sharding_parallel_mode
-                or in_tensor_parallel_mode
-                or in_sep_parallel_mode
-            )
-        ):
+        if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
+            # MoE uses DDP to broadcast parameters.
             model = paddle.DataParallel(model)
             # Distributed training (should be after fp16 initialization)

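The rewritten condition is not just a stylistic simplification: because `and` binds tighter than `or`, the old predicate also wrapped the model in DataParallel whenever no hybrid-parallel mode was active, including single-process runs. A minimal sketch of the difference, using standalone functions and hypothetical flag values in place of the Trainer's attributes:

def old_should_wrap(world_size, use_hybrid_parallel, in_pp, in_sharding, in_tp, in_sep):
    # Old predicate: `A and B or not C` parses as `(A and B) or (not C)`,
    # so it is True for any run with no hybrid-parallel mode active.
    return (world_size > 1 and not use_hybrid_parallel) or not (
        in_pp or in_sharding or in_tp or in_sep
    )


def new_should_wrap(world_size, use_hybrid_parallel):
    # New predicate: wrap only multi-process, non-hybrid (pure data-parallel) runs.
    return world_size > 1 and not use_hybrid_parallel


# Single-process run without any hybrid parallelism:
print(old_should_wrap(1, False, False, False, False, False))  # True  (wrapped needlessly)
print(new_should_wrap(1, False))                               # False (left unwrapped)
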
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
@@ -1406,7 +1406,7 @@ def is_segment_parallel_supported():
         if world_size > 1:
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
                 if self.unified_checkpoint:
-                    self.use_hybrid_parallel = True
+                    # DP use hybrid group
                     strategy = fleet.DistributedStrategy()
                     fleet.init(is_collective=True, strategy=strategy)
                 else:
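
Read together with the trainer.py hunk above, the apparent intent for a `unified_checkpoint` run without hybrid parallelism is: initialize the collective fleet context with a default strategy (instead of forcing `use_hybrid_parallel = True`) and let the Trainer wrap the model in plain DDP, which also broadcasts parameters across ranks. A rough sketch of that path, assuming a multi-GPU job started with `paddle.distributed.launch` and a toy model standing in for the real one:

import paddle
import paddle.nn as nn
from paddle.distributed import fleet

# training_args.py side: collective init with a default strategy.
strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)

# trainer.py side: world_size > 1 and no hybrid parallelism -> plain DDP.
model = nn.Linear(8, 8)  # placeholder for the model built by the Trainer
model = paddle.DataParallel(model)
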
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -10,7 +10,7 @@ exclude = ['.flake8']

 [tool.pytest.ini_options]
 minversion = "6.0"
-addopts = "-ra -q --ignore model_zoo/gpt-3/"
+addopts = "-ra -q --dist loadgroup"
 pythonpath = ["."]
 testpaths = [
     "tests/data",
@@ -28,7 +28,7 @@ testpaths = [
     "tests/prompt",
     # "tests/taskflow", TODO (paddle 2.5.1 breaks this test suite, debug later)
     "tests/utils",
-    "model_zoo",
+    # "model_zoo",
 ]
 python_files = [
     "test.py",
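
`--dist loadgroup` is a pytest-xdist scheduling mode: tests that carry the same `xdist_group` mark are dispatched to the same worker as a unit, which keeps distributed-training tests that need the same GPUs from running concurrently. A small illustration with hypothetical test names, assuming pytest-xdist is installed and the suite is launched with `pytest -n <workers>`:

import pytest


# With --dist loadgroup, both tests below land on the same xdist worker and
# run one after the other, because they share a group name.
@pytest.mark.xdist_group(name="n1c8")
def test_checkpoint_save():
    ...


@pytest.mark.xdist_group(name="n1c8")
def test_checkpoint_resume():
    ...
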
47 changes: 25 additions & 22 deletions tests/trainer/test_lora_unified_checkpoint.py
@@ -149,7 +149,7 @@ def __test__(cls):

     def setUp(self):
         """
-        1. update runfrist and rerun to run defined different config
+        1. update runfirst and rerun to run defined different config
         2. update need_allclose to True if you want to check the result
         3. update rtol to the relative value you want to check
         """
@@ -169,7 +169,7 @@ def setUp(self):

         self.run_lora_file = "llm/finetune_generation.py"

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)

     def rerun(self, train_args):
@@ -181,7 +181,7 @@ def testTP4PP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4PP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -196,7 +196,7 @@ def testTP2Sharding4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP2Sharding4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -213,7 +213,7 @@ def testTP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -227,7 +227,7 @@ def testTP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -242,7 +242,7 @@ def testTP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -257,7 +257,7 @@ def testTP2PP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP2PP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -272,7 +272,7 @@ def testPP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -287,7 +287,7 @@ def testPP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -302,7 +302,7 @@ def testPP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -317,7 +317,7 @@ def testSharding8S1(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding8S1"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -332,7 +332,7 @@ def testSharding8S2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding8S2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -347,7 +347,7 @@ def testSharding4S1DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding4S1DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -362,7 +362,7 @@ def testSharding4S2DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding4S2DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -377,7 +377,7 @@ def testSharding2S1DP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding2S1DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -392,7 +392,7 @@ def testSharding2S2DP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding2S2DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -407,7 +407,7 @@ def testDP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["DP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -416,27 +416,29 @@ def testDP8(self):
             np.testing.assert_allclose(res[0], res[1], self.rtol)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()
         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)

     def rerun(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()

         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)

@@ -445,14 +447,15 @@ def rerun(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()

         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)

@@ -469,7 +472,7 @@ def setUp(self):
         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n2c4(self.run_lora_file, **train_args)

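Aside from the rename of `runfrist` to `runfirst`, the file adds an unconditional `@pytest.mark.skipif(True, reason="Skip for None CE")` to three of the checkpoint classes, so those cases are always skipped for now. Every test in the file follows the same train / resume / compare shape; a condensed, self-contained sketch of that pattern (the launcher here is a stub, other names mirror the diff):

import numpy as np


class CheckpointRoundTripSketch:
    """Condensed shape of the tests above; the launcher is a placeholder."""

    need_allclose = True
    rtol = 1e-7

    def run_n1c8(self, script, **train_args):
        # Stand-in for the real single-node, eight-card launcher; returns a metric.
        return 0.0

    def runfirst(self, train_args):
        # First run: train and write the checkpoint.
        return self.run_n1c8("llm/finetune_generation.py", **train_args)

    def rerun(self, train_args):
        # Second run: resume from the checkpoint written by runfirst.
        return self.run_n1c8("llm/finetune_generation.py", **train_args)

    def check(self, first_metric, resumed_metric):
        # Resuming must reproduce the first run's result within rtol.
        if self.need_allclose:
            np.testing.assert_allclose(first_metric, resumed_metric, self.rtol)
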