[FIX DDP] fix ddp #8549

Merged: 5 commits, Jun 7, 2024
13 changes: 2 additions & 11 deletions paddlenlp/trainer/trainer.py
@@ -1795,17 +1795,8 @@ def _wrap_model(self, model, training=True):
in_cp_parallel_mode = self.args.context_parallel_degree > 1

# Multi-gpu training
Contributor review comment on this hunk: Does this change need to be merged into 2.8?

if (
self.args.world_size > 1
and not self.args.use_hybrid_parallel
or not (
in_pipeline_parallel_mode
or in_sharding_parallel_mode
or in_tensor_parallel_mode
or in_sep_parallel_mode
or in_cp_parallel_mode
)
):
if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
# MoE uses DDP to broadcast parameters.
model = paddle.DataParallel(model)
# Distributed training (should be after fp16 initialization)

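To make the behavioural change in this hunk easier to see, here is a minimal, self-contained sketch comparing the old and new guards around paddle.DataParallel. It is not the repository's code: the SimpleNamespace stand-in for self.args and the single any_hybrid_mode_active flag (collapsing the five in_*_parallel_mode booleans) are illustrative assumptions.

from types import SimpleNamespace

def old_should_wrap(args, any_hybrid_mode_active):
    # Pre-PR guard: `and` binds tighter than `or`, so this reads as
    # (world_size > 1 and not use_hybrid_parallel) or (no hybrid mode active).
    return (
        args.world_size > 1
        and not args.use_hybrid_parallel
        or not any_hybrid_mode_active
    )

def new_should_wrap(args, any_hybrid_mode_active):
    # Post-PR guard: wrap only for plain multi-process data parallelism,
    # i.e. more than one process and hybrid parallel disabled.
    return args.world_size > 1 and not args.use_hybrid_parallel

# One input on which the two guards disagree: a single-process run with no
# hybrid mode active used to satisfy the old guard through its `or not (...)` arm.
single_process = SimpleNamespace(world_size=1, use_hybrid_parallel=False)
print(old_should_wrap(single_process, any_hybrid_mode_active=False))  # True
print(new_should_wrap(single_process, any_hybrid_mode_active=False))  # False

Under the new guard, the individual parallel-mode flags no longer influence the wrapping decision; only world_size and use_hybrid_parallel do.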
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
@@ -1529,7 +1529,7 @@ def is_segment_parallel_supported():
if world_size > 1:
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
if self.unified_checkpoint:
self.use_hybrid_parallel = True
# DP uses the hybrid group
strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)
else:
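For context, a hedged sketch of the initialization path this hunk keeps (the flag values are illustrative, not the repository's code): with unified_checkpoint enabled, fleet is still initialized so the collective communication group exists, but use_hybrid_parallel is no longer forced to True, which, combined with the trainer.py change above, lets such runs be wrapped with paddle.DataParallel.

import paddle
import paddle.distributed.fleet as fleet

# Illustrative flags mirroring self.unified_checkpoint / self.use_hybrid_parallel.
unified_checkpoint = True
use_hybrid_parallel = False

# Intended to run under `python -m paddle.distributed.launch`; on a single
# process the guard below simply does nothing.
world_size = paddle.distributed.get_world_size()
if world_size > 1 and not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
    if unified_checkpoint:
        # DP still uses the hybrid communication group for checkpoint I/O,
        # but the run itself stays a plain data-parallel (DDP) run.
        strategy = fleet.DistributedStrategy()
        fleet.init(is_collective=True, strategy=strategy)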
5 changes: 1 addition & 4 deletions pyproject.toml
@@ -11,9 +11,6 @@ exclude = ['.flake8']
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q --dist loadgroup"
retries = 0
retry_delay = 0.5
timeout = 200
pythonpath = ["."]
testpaths = [
"tests/data",
@@ -25,7 +22,7 @@ testpaths = [
"tests/layers",
"tests/metrics",
"tests/ops",
# "tests/trainer",
"tests/trainer",
"tests/transformers",
"tests/peft",
"tests/prompt",
6 changes: 3 additions & 3 deletions scripts/unit_test/ci_unit.sh
@@ -30,10 +30,10 @@ install_requirements() {
python -m pip install -r requirements-dev.txt
python -m pip install -r tests/requirements.txt
python -m pip install -r paddlenlp/experimental/autonlp/requirements.txt
python -m pip uninstall paddlepaddle -y
python -m pip uninstall paddlepaddle paddlepaddle_gpu -y
python -m pip install --no-cache-dir ${paddle}

python setup.py bdist_wheel
python setup.py bdist_wheel > /dev/null
python -m pip install dist/p****.whl
cd csrc/
python setup_cuda.py install
@@ -51,4 +51,4 @@ set_env() {

install_requirements
set_env
pytest -v -n 8 --durations 20
pytest -v -n 8 --timeout 200 --durations 20 --cov paddlenlp --cov-report xml:coverage.xml
47 changes: 25 additions & 22 deletions tests/trainer/test_lora_unified_checkpoint.py
@@ -151,7 +151,7 @@ def __test__(cls):

def setUp(self):
"""
1. update runfrist and rerun to run defined different config
1. update runfirst and rerun to run defined different config
2. update need_allclose to True if you want to check the result
3. update rtol to the relative value you want to check
"""
@@ -171,7 +171,7 @@ def setUp(self):

self.run_lora_file = "llm/finetune_generation.py"

def runfrist(self, train_args):
def runfirst(self, train_args):
self.run_n1c8(self.run_lora_file, **train_args)

def rerun(self, train_args):
@@ -183,7 +183,7 @@ def testTP4PP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP4PP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -198,7 +198,7 @@ def testTP2Sharding4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP2Sharding4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -216,7 +216,7 @@ def testTP8(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP8"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -230,7 +230,7 @@ def testTP4DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP4DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -245,7 +245,7 @@ def testTP4Sharding2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP4Sharding2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -260,7 +260,7 @@ def testTP2PP4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["TP2PP4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -275,7 +275,7 @@ def testPP8(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["PP8"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -290,7 +290,7 @@ def testPP4DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["PP4DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -305,7 +305,7 @@ def testPP4Sharding2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["PP4Sharding2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -320,7 +320,7 @@ def testSharding8S1(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding8S1"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -335,7 +335,7 @@ def testSharding8S2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding8S2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -350,7 +350,7 @@ def testSharding4S1DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding4S1DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -365,7 +365,7 @@ def testSharding4S2DP2(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding4S2DP2"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -380,7 +380,7 @@ def testSharding2S1DP4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding2S1DP4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -395,7 +395,7 @@ def testSharding2S2DP4(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["Sharding2S2DP4"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -410,7 +410,7 @@ def testDP8(self):
remove_ckpt(lora_arguments["output_dir"])

train_args = self.configs["DP8"]
self.runfrist(train_args)
self.runfirst(train_args)
self.rerun(train_args)

if self.need_allclose:
@@ -419,27 +419,29 @@ def testDP8(self):
np.testing.assert_allclose(res[0], res[1], self.rtol)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
self.run_n2c4(self.run_lora_file, **train_args)

def rerun(self, train_args):
self.run_n2c4(self.run_lora_file, **train_args)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()

self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
self.run_n1c8(self.run_lora_file, **train_args)

@@ -448,14 +450,15 @@ def rerun(self, train_args):
self.run_n1c8(self.run_lora_file, **train_args)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()

self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
self.run_n1c8(self.run_lora_file, **train_args)

@@ -472,7 +475,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
self.run_n2c4(self.run_lora_file, **train_args)
