[feature] support no master weights option for low level zero plugin (hpcaitech#4816)

* [feature] support no master weights for low level zero plugin

* [feature] support no master weights for low level zero plugin, remove data copy when no master weights

* remove data copy and typecasting when no master weights

* not load weights to cpu when using no master weights

* fix grad: use fp16 grad when no master weights

* only do not update working param when no master weights

* fix: only do not update working param when no master weights

* fix: passing params in dict format in hybrid plugin

* fix: remove extra params (tp_process_group) in hybrid_parallel_plugin
KKZ20 authored Oct 13, 2023
1 parent 77a9328 commit a0684e7
Showing 3 changed files with 42 additions and 28 deletions.
34 changes: 17 additions & 17 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -464,23 +464,23 @@ def __init__(
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
optimizer,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
clip_grad_norm,
verbose,
reduce_bucket_size,
communication_dtype,
overlap_communication,
partition_grad,
cpu_offload,
dp_process_group,
forced_dtype,
optimizer=optimizer,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
clip_grad_norm=clip_grad_norm,
verbose=verbose,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
)

def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
6 changes: 4 additions & 2 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -262,6 +262,7 @@ def __init__(
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
) -> None:
super().__init__()
@@ -272,18 +273,19 @@ def __init__(
self.precision = precision
self.zero_optim_kwargs = dict(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
clip_grad_norm=max_norm,
reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.verbose = verbose

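The new flag is surfaced through LowLevelZeroPlugin. A minimal usage sketch, assuming a torchrun launch and placeholder model/optimizer (only the master_weights keyword comes from this commit):

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# run under torchrun; config={} matches the launch API of this release
colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# keep parameters in fp16 end to end: no fp32 master copy is created
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)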
30 changes: 21 additions & 9 deletions colossalai/zero/low_level/low_level_optim.py
@@ -75,6 +75,7 @@ def __init__(
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]["params"][0].dtype
@@ -106,6 +107,9 @@ def __init__(
# gradient clipping
self._clip_grad_norm = clip_grad_norm

# master weights copy
self._master_weights = master_weights

if forced_dtype:
for group in self.optim.param_groups:
group_params = group["params"]
@@ -135,7 +139,6 @@ def __init__(
self._working_param_groups[group_id] = group_params

master_param_current_rank = self._create_master_param_current_rank(group_params)

self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

# need to replace the params in the `params` field in the optimizer
@@ -200,11 +203,18 @@ def _create_master_param_current_rank(self, param_list):
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
# reset working params' ptr when no master weights
if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // self._world_size)

splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
else:
splited_param_current_rank = splited_params[self._local_rank]
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)

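To make the control flow in this hunk concrete, a standalone sketch of the sharding path (illustrative names, not the optimizer's internals): pad the flattened parameter to a multiple of the world size, take this rank's slice, and only materialize an fp32 master copy when master_weights is True.

import torch

def shard_param(param: torch.Tensor, world_size: int, local_rank: int, master_weights: bool) -> torch.Tensor:
    numel = param.numel()
    padding_size = (world_size - numel % world_size) % world_size
    with torch.no_grad():
        if padding_size > 0:
            padded = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
            if not master_weights:
                # reuse the padded buffer so the working param shares storage with its shard
                param.data = padded[:numel].view(param.shape)
        else:
            padded = param.data.view(-1)
        shards = padded.split(padded.numel() // world_size)
        if master_weights:
            return shards[local_rank].detach().float()  # fp32 master copy on this rank
        return shards[local_rank]                       # the fp16 shard doubles as the master

p = torch.randn(10, dtype=torch.float16)
print(shard_param(p, world_size=4, local_rank=0, master_weights=False).dtype)  # torch.float16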
@@ -402,9 +412,7 @@ def step(self, closure=None):
# and should not be updated
real_working_params = dict()
real_master_params = dict()

grad_index = 0 if self._partition_grads else self._local_rank

for group_id in range(self.num_param_groups):
master_params = self._master_param_groups_of_current_rank[group_id]
real_working_params[group_id] = []
@@ -417,7 +425,12 @@
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
real_working_params[group_id].append(working_param)
grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
# no need to copy fp32 grad if master_weights is False
grad = (
grads[grad_index].to(splited_param.dtype).to(splited_param.device)
if self._master_weights
else grads[grad_index]
)
splited_param.grad = grad
grad_partition_groups.append(grad)
real_master_params[group_id].append(splited_param)
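The gradient handling above, in isolation: with master weights the fp16 gradient is cast onto the fp32 shard; without them it is attached as-is, avoiding the copy and the cast (a sketch with made-up names, not the optimizer's API).

import torch

def attach_grad(shard: torch.Tensor, raw_grad: torch.Tensor, master_weights: bool) -> torch.Tensor:
    # cast only when an fp32 master shard exists; otherwise reuse the fp16 gradient directly
    grad = raw_grad.to(shard.dtype).to(shard.device) if master_weights else raw_grad
    shard.grad = grad
    return grad

fp16_grad = torch.randn(3, dtype=torch.float16)
fp32_shard = torch.randn(3)                             # master-weights case: fp32 shard
print(attach_grad(fp32_shard, fp16_grad, True).dtype)   # torch.float32
fp16_shard = torch.randn(3, dtype=torch.float16)        # no-master-weights case
print(attach_grad(fp16_shard, fp16_grad, False).dtype)  # torch.float16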
@@ -445,17 +458,16 @@
release_param_grad(self._master_param_groups_of_current_rank[group_id])

# update working partition updated by the current rank
dtype = real_working_params[0][0].dtype
# dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
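After the optimizer step, each rank contributes its updated shard and the full working parameter is reassembled; the change above gathers in the working dtype (self._dtype) rather than the removed local dtype variable. A single-process sketch of that reassembly, assuming a gloo backend and using fp32 only so it runs anywhere:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
world_size = dist.get_world_size()

working_param = torch.randn(10)   # full working parameter
shard = working_param.view(-1).split(working_param.numel() // world_size)[dist.get_rank()].clone()
shard += 0.01                     # stand-in for the optimizer update

# every rank contributes its shard; concatenating the pieces restores the full tensor
gathered = [torch.zeros_like(shard) for _ in range(world_size)]
dist.all_gather(gathered, shard)
working_param.data.copy_(torch.cat(gathered)[: working_param.numel()].reshape_as(working_param))

dist.destroy_process_group()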
