Cherry pick some changes from incubate branch (#8862)
* Fix sharding reshard bug and support reshard in different partition ways (#8837)

* Fix sharding reshard bug and support reshard in different partition ways

* fix some bugs

* refine codes

* revert some useless codes

* add log

* Fix TP sync (#8846)

* fix sharding reshard compatibility for wrong master weight name (#8851)

* change print to logger.info

* remove useless f-string
sneaxiy authored Aug 6, 2024
1 parent c4d1abf commit dbf395f
Showing 4 changed files with 128 additions and 18 deletions.
21 changes: 17 additions & 4 deletions paddlenlp/trainer/trainer.py
@@ -195,7 +195,17 @@
g_cpu_optimizer_state_dict = {}


def _save_func(obj, path, saved_signal_path, protocol):
def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
if isinstance(obj, dict):
for k, v in obj.items():
if k == "master_weights" and isinstance(v, dict):
for kk, vv in v.items():
if isinstance(vv, paddle.Tensor):
vv.name = name_mapping["master_weights"][kk]
else:
if k in name_mapping and isinstance(v, paddle.Tensor):
v.name = name_mapping[k]

paddle.save(obj, path, protocol)
# dump saved_signal
with open(saved_signal_path, mode="w+") as f:
@@ -228,17 +238,18 @@ def clear_async_save_task_queue():
def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
global g_cpu_optimizer_state_dict
g_cpu_optimizer_state_dict.clear()
name_mapping = {"master_weights": {}}
for k, v in optimizer_state_dict.items():
if k == "master_weights":
g_cpu_optimizer_state_dict[k] = {}
for kk, vv in v.items():
tensor_name = vv.name
g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
g_cpu_optimizer_state_dict[k][kk].name = tensor_name
name_mapping[k][kk] = vv.name
elif k == "LR_Scheduler":
g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
else:
g_cpu_optimizer_state_dict[k] = v.pin_memory()
name_mapping[k] = v.name
paddle.device.synchronize()
clear_async_save_task_queue()

@@ -248,7 +259,9 @@ def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol
def start_process():
nonlocal attempt
try:
p = ctx.Process(target=_save_func, args=(g_cpu_optimizer_state_dict, path, saved_signal_path, protocol))
p = ctx.Process(
target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
)
p.start()
return p
except Exception as e:
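In the trainer.py change above, async_save_optimizer no longer relies on renaming the pinned-memory copies in the parent process: the original tensor names are recorded in a separate name_mapping dict and re-applied inside the spawned save process immediately before paddle.save, so the serialized names still match what the reshard logic expects. Below is a minimal, hypothetical sketch of that restore step (not the repository's code; the parameter key, master-weight name, and output path are made up):

import paddle

state = {"master_weights": {"linear_0.w_0": paddle.ones([4], dtype="float32")}}
name_mapping = {"master_weights": {"linear_0.w_0": "linear_0.w_0_fp32_master_0"}}

# Re-apply the recorded names to the copies before serializing, as _save_func does.
for key, name in name_mapping["master_weights"].items():
    state["master_weights"][key].name = name

paddle.save(state, "optimizer_shard.pdopt")  # hypothetical output path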
14 changes: 11 additions & 3 deletions paddlenlp/trainer/training_args.py
@@ -1171,15 +1171,23 @@ def split_parallel_config(parallel_config):
# sync_param_name = [""] matches any parameter name.
# If sync_param, sync_grad, and sync_moment are not set, the defaults in Paddle are:
# sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].

if sync_param or sync_grad or sync_moment:
logger.info("setting sync_param_name")
strategy.sync_param_name = [""]

if sync_param:
logger.info("setting sync_param")
strategy.hybrid_configs["mp_configs"].sync_param = True
strategy.hybrid_configs["mp_configs"].sync_param_name = [""]

if sync_grad:
logger.info("setting sync_grad")
strategy.hybrid_configs["mp_configs"].sync_grad = True
strategy.hybrid_configs["mp_configs"].sync_grad_name = [""]

if sync_moment:
logger.info("setting sync_moment")
strategy.hybrid_configs["mp_configs"].sync_moment = True
strategy.hybrid_configs["mp_configs"].sync_moment_name = [""]

except:
warnings.warn(
"The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "
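For context, here is a hedged sketch of how the sync_param / sync_grad / sync_moment switches touched above map onto Paddle's hybrid-parallel mp_configs. The attribute names are taken from the diff; whether a given Paddle build actually exposes them is precisely what the surrounding try/except guards against, so treat this as an illustration under that assumption rather than a stable API:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}

try:
    mp_configs = strategy.hybrid_configs["mp_configs"]
    mp_configs.sync_param = True       # keep parameters synchronized inside the TP group
    mp_configs.sync_param_name = [""]  # the empty string matches every parameter name
except Exception:
    print("this Paddle build does not expose the mp sync options")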
61 changes: 59 additions & 2 deletions paddlenlp/trainer/utils/reshard/sharding_v2.py
@@ -14,11 +14,14 @@

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
HybridParallelOptimizer,
)
from paddle.distributed.fleet.model import PipelineParallel

from paddlenlp.utils.log import logger

from ....transformers.model_utils import unwrap_optimizer

try:
@@ -29,6 +32,9 @@
DygraphShardingOptimizerV2 = None


from paddle.distributed.communication.reduce import ReduceOp


def shard(node_model_state, model, optimizer, hcg):
assert DygraphShardingOptimizerV2 is not None
group = hcg.get_sharding_parallel_group()
@@ -137,7 +143,7 @@ def slice_tensor(tensor, begin, end):
return tensor[begin:end]


def collect_split_info(optimizer, model):
def collect_split_info(optimizer, model, only_return_lengths=False):
split_infos = {}

def gather_infos(comm_buffer):
@@ -146,7 +152,13 @@ def gather_infos(comm_buffer):
padded_size = v._padded_size
buffer_size = v._param_buffer._numel()
has_slice_grad = v._slice_grad is not None
split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)
if only_return_lengths:
if v._param_begin < v._param_end:
split_infos[k] = v._param_end - v._param_begin
else:
split_infos[k] = None
else:
split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)

if isinstance(model, PipelineParallel) and model._sharding_comm_overlap > 0:
optimizer = unwrap_optimizer(optimizer, HybridParallelOptimizer)
@@ -167,6 +179,51 @@ def gather_infos(comm_buffer):
return split_infos


def is_matched_optimizer_state_dict(opt_state_dict, optimizer, model, hcg=None, need_allgather=True):
split_infos = collect_split_info(optimizer, model, only_return_lengths=True)
master_weights = opt_state_dict.get("master_weights", None)

def get_matched_length(name):
if master_weights and name in master_weights:
tensor = master_weights[name]
else:
moment_name = name + "_moment1_0"
if moment_name not in opt_state_dict:
return None

tensor = opt_state_dict[moment_name]
if isinstance(tensor, (list, tuple)):
assert len(tensor) == 2, tensor
assert isinstance(tensor[0], str), tensor[0]
tensor = tensor[1]
shape = tensor.shape
assert len(shape) == 1, shape
length = shape[0]
return length

is_matched = 1
for k, length in split_infos.items():
matched_length = get_matched_length(k)
if length != matched_length:
is_matched = 0
break

if need_allgather:
if hcg is None:
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_sharding_parallel_group()
if group is not None and group.nranks > 1:
x = paddle.to_tensor([is_matched], dtype=paddle.int32)
paddle.distributed.stream.all_reduce(x, op=ReduceOp.MIN, group=group, sync_op=True, use_calc_stream=True)
global_is_matched = int(x.numpy()[0])
else:
global_is_matched = is_matched

global_is_matched = bool(global_is_matched)
logger.info(f"Sharding reshard checkpoint: local_match = {is_matched} , global_match = {global_is_matched}")
return global_is_matched
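
A paddle-free sketch of the consensus rule implemented by the all_reduce above: each sharding rank contributes a local 0/1 match flag, and taking the group-wide MIN means a single mismatching rank forces every rank onto the reshard path (the flag values below are illustrative):

local_flags = [1, 1, 0, 1]                  # per-rank is_matched; e.g. rank 2 sees a length mismatch
global_is_matched = bool(min(local_flags))  # what ReduceOp.MIN computes across the group
assert global_is_matched is False           # so all ranks agree to reshard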


def is_bata(name):
if "_beta1_pow_acc_" in name:
return True
50 changes: 41 additions & 9 deletions paddlenlp/trainer/utils/sharding_io.py
@@ -40,7 +40,7 @@
from paddlenlp.utils.log import logger

from . import reshard as reshard_util
from .reshard import SHARDING_STRATEGY_V1, pp_reshard
from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
@@ -204,10 +204,21 @@ def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimize
path = os.path.join(checkpoint, optimizer_name)
logger.info(f"load optimizer state from {path}")
if os.path.isfile(path):
return paddlenlp_load(path, map_location="cpu")
return self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu"))
logger.info(f"{path} not exists")
return None

def _modify_ckpt_for_compatibility(self, ckpt):
master_weights = ckpt.get("master_weights", None)
if master_weights:
for k, v in master_weights.items():
assert isinstance(v, paddle.Tensor), v
if not v.name.startswith(k):
new_name = k + "_fp32_master_0"
logger.info(f"Modify master weights {v.name} -> {new_name}")
v.name = new_name
return ckpt

def _need_reshard(self, checkpoint):
if self._need_reshard_pp(checkpoint):
return True
@@ -253,10 +264,6 @@ def _need_reshard_pp(self, checkpoint):
def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wrapped):
"""load state_dict of multiple shard from_checkpoint, Only load model state dict."""

if not self._need_reshard(checkpoint):
logger.info("do not need reshard")
return self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, self.args.optimizer_name_suffix)
logger.info("reshard optimizer state")
parallel_config = self._load_distributed_strategy(checkpoint)
sharding_meta = self._load_sharding_meta(checkpoint, 0)
pp_degree = parallel_config["pp_degree"]
@@ -276,16 +283,41 @@
cur_sharding_degree = self.args.sharding_parallel_degree
cur_sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)

if not self._need_reshard(checkpoint):
one_shard_opt_state_dict = self._load_optimizer_state_of_one_shard(
checkpoint, base_opt_name, self.args.optimizer_name_suffix
)

if sharding_strategy == SHARDING_STRATEGY_V2 and cur_sharding_strategy == SHARDING_STRATEGY_V2:
is_matched = reshard_util.sharding_v2.is_matched_optimizer_state_dict(
one_shard_opt_state_dict, self.optimizer, model_wrapped
)
else:
is_matched = True

if is_matched:
logger.info("do not need reshard")
return one_shard_opt_state_dict
else:
one_shard_opt_state_dict = None

logger.info("reshard optimizer state")

def load_model_slices():
model_state = reshard_util.NodeModelState()
for j in range(self.args.pipeline_parallel_rank, pp_degree, cur_pp_degree):
cur_sharding_meta = self._load_sharding_meta(checkpoint, j)
assert "structure_name_mapping" in cur_sharding_meta
structure_name_map = cur_sharding_meta["structure_name_mapping"]
for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree):
tmp = self._load_optimizer_state_of_one_shard(
checkpoint, base_opt_name, self.args.sharded_name_suffix(i, j)
)
sharded_name_suffix = self.args.sharded_name_suffix(i, j)
if one_shard_opt_state_dict is None:
tmp = self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, sharded_name_suffix)
else:
assert (
self.args.optimizer_name_suffix == sharded_name_suffix
), f"{self.args.optimizer_name_suffix} vs {sharded_name_suffix}"
tmp = one_shard_opt_state_dict
node_model_state_tmp = reshard_util.NodeModelState()
node_model_state_tmp.add_opts(tmp)
node_model_state_tmp.pack_keys(structure_name_map)
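Taken together, the sharding_io.py changes make the loader try the cheap path first: the rank's own shard is loaded once and, when both the checkpoint and the current run use the sharding V2 strategy, probed with is_matched_optimizer_state_dict; only a confirmed mismatch triggers the full reshard. A simplified sketch of that control flow, with hypothetical helper callables standing in for the real methods:

def load_optimizer_maybe_reshard(checkpoint, load_one_shard, is_matched, reshard):
    one_shard = load_one_shard(checkpoint)   # this rank's own shard, loaded once
    if is_matched(one_shard):                # partition lengths agree on every rank
        return one_shard                     # fast path: reuse the shard as-is
    return reshard(checkpoint)               # slow path: re-read the shards and re-split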
