[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic #4758

Merged · 6 commits · Sep 20, 2023
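At a glance, this PR makes checkpoint loading refresh the optimizer's fp32 master copies from the working (low-precision) parameters, and rewrites the checkpoint IO so it receives boosted wrappers and unwraps them internally. A rough sketch of the intended user-facing flow under the hybrid plugin; the launcher call, plugin arguments, toy model, and checkpoint path are placeholders rather than part of this PR:

```python
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # assumes torchrun-style env vars are set

plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="fp16")
booster = Booster(plugin=plugin)

model = nn.Linear(32, 32)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

booster.load_model(model, "model_ckpt")  # placeholder checkpoint path
model.update_master_params()             # injected in configure(); copies working -> fp32 master params
```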
17 changes: 16 additions & 1 deletion colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -2,7 +2,7 @@

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn import Module, Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper
@@ -152,3 +152,18 @@ def step(self, *args, **kwargs):
if p is working_param:
continue
working_param.data.copy_(p.data)

def update_master_params(self, model: Module):
# Update master params from working params
with torch.no_grad():
for p in model.parameters():
if (p is None) or (p not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[p]
master_param.data.copy_(p.data)

def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}

def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
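The three methods added above let callers push freshly loaded working weights into the fp32 master copies and look the two copies of each parameter up by `id()`. A small usage sketch, assuming the default `MixedPrecisionOptimizer` constructor arguments and a placeholder checkpoint path:

```python
import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

model = nn.Linear(8, 8).cuda().half()      # working (fp16) parameters
optim = Adam(model.parameters(), lr=1e-3)
mp_optim = MixedPrecisionOptimizer(optim)  # fp32 master copies are created inside

# Refresh the working params (e.g. from a checkpoint), then sync the masters.
state = torch.load("checkpoint.pt", map_location="cuda")  # placeholder path
model.load_state_dict(state)
mp_optim.update_master_params(model)

# id-keyed views used by checkpoint IO to pair the two copies of each param.
working_to_master = mp_optim.get_working_to_master_map()  # {id(working): master}
master_to_working = mp_optim.get_master_to_working_map()  # {id(master): working}
```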
2 changes: 1 addition & 1 deletion colossalai/booster/booster.py
@@ -139,7 +139,7 @@ def boost(

if self.plugin and not self.plugin.control_device():
# transform model for accelerator
model = self.accelerator.configure(model)
model = self.accelerator.configure_model(model)

if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
23 changes: 14 additions & 9 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -44,6 +44,7 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -53,24 +54,27 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)

def save_sharded_model(
Expand All @@ -86,6 +90,7 @@ def save_sharded_model(
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
@@ -111,7 +116,7 @@ def save_sharded_model(
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
save_config_file(model.unwrap(), checkpoint_path)
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
@@ -124,17 +129,17 @@ def load_sharded_model(
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""

assert isinstance(optimizer, GeminiOptimizer)
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -176,12 +181,12 @@ def save_sharded_optimizer(
f"index located at {save_index_file}."
)

def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
"""

assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")

@@ -383,7 +388,7 @@ def configure(

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
)

return model, optimizer, criterion, dataloader, lr_scheduler
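With the asserts added above, every Gemini checkpoint call must receive the boosted `GeminiDDP` / `GeminiOptimizer` objects rather than the raw module and optimizer, and `configure` now passes the wrapped model straight to `GeminiOptimizer`. A sketch of the expected call pattern, assuming default plugin arguments and placeholder checkpoint directories:

```python
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

booster = Booster(plugin=GeminiPlugin())

model = nn.Linear(64, 64)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)  # -> GeminiDDP, GeminiOptimizer

# These pass the new isinstance checks; calling them with the raw objects
# would now fail with "Please boost the ... before saving!".
booster.save_model(model, "model_ckpt_dir", shard=True)
booster.save_optimizer(optimizer, "optim_ckpt_dir", shard=True)
```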
22 changes: 13 additions & 9 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,6 +1,7 @@
import random
from contextlib import nullcontext
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union

import numpy as np
@@ -165,6 +166,15 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_in
init_pipeline_optimizer(optim, model)
super().__init__(optim)

def update_master_params(self, model: Module):
pass

def get_working_to_master_map(self):
return None

def get_master_to_working_map(self):
return None


class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
@@ -466,9 +476,6 @@ def configure(
max_norm=self.max_norm,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
@@ -488,10 +495,8 @@ def configure(
**self.zero_config,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
)

# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler

def execute_pipeline(
@@ -567,8 +572,7 @@ def seed_worker(worker_id):
)

def get_checkpoint_io(self) -> CheckpointIO:
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)

def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
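The `# inject update_master_params` line above binds the optimizer's method onto the wrapped model, so later checkpoint code can call `model.update_master_params()` without holding a reference to the optimizer. A standalone toy version of that pattern (the `Toy*` class is an illustration, not the ColossalAI one):

```python
from types import MethodType

import torch
import torch.nn as nn

class ToyMixedPrecisionOptimizer:
    """Keeps fp32 master copies of a model's low-precision working params."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim
        self.working_to_master_map = {
            p: p.detach().float().clone()
            for group in optim.param_groups
            for p in group["params"]
        }

    def update_master_params(self, model: nn.Module) -> None:
        with torch.no_grad():
            for p in model.parameters():
                if p in self.working_to_master_map:
                    self.working_to_master_map[p].copy_(p.data)

model = nn.Linear(4, 4).half()
wrapper = ToyMixedPrecisionOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

# The injection performed in configure(): model.update_master_params() now
# forwards to wrapper.update_master_params(model).
model.update_master_params = MethodType(wrapper.update_master_params, model)

model.load_state_dict(model.state_dict())  # stand-in for loading a checkpoint
model.update_master_params()               # masters now mirror the working params
```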
46 changes: 10 additions & 36 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -22,7 +22,6 @@
save_param_groups,
save_state_dict,
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -65,10 +64,6 @@ def forward(self, *args, **kwargs):
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)

def unwrap(self):
# TODO(ver217): this is a workaround for loading model
return self


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -79,7 +74,7 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""

assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
@@ -109,6 +104,7 @@ def save_sharded_optimizer(
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file that store state tensors
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -160,9 +156,8 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before loading!"
optimizer = optimizer.unwrap()

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -194,44 +189,23 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
load_states_into_optimizer(optimizer, state_dict, id_map)

sharded_optimizer_loading_epilogue(optimizer)

def save_unsharded_model(
self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(
model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)

def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()

def load_sharded_model(
self,
model: LowLevelZeroModel,
model: ModelWrapper,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()


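Both load methods above now keep the wrapper intact and finish by calling `model.update_master_params()`, so the zero optimizer's fp32 shards match the weights that were just loaded. A sketch of the load path from the user's side, assuming a stage-1 fp16 setup and placeholder checkpoint directories:

```python
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

booster = Booster(plugin=LowLevelZeroPlugin(stage=1, precision="fp16"))

model = nn.Linear(32, 32)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)  # -> LowLevelZeroModel, LowLevelZeroOptimizer

# load_model dispatches to the checkpoint IO above, which loads into the
# wrapper and then refreshes the master params.
booster.load_model(model, "model_ckpt_dir")          # placeholder path
booster.load_optimizer(optimizer, "optim_ckpt_dir")  # placeholder path
```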
63 changes: 51 additions & 12 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -20,24 +20,33 @@ def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()

def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
Load model from checkpoint.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

@@ -50,7 +59,7 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):

def save_sharded_model(
self,
model: nn.Module,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
@@ -60,22 +69,52 @@ def save_sharded_model(
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)

def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: str,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)

def save_sharded_optimizer(
self,
optimizer: Optimizer,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to checkpoint but only on master process.
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)

def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)


class TorchDDPModel(ModelWrapper):
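The pattern repeated throughout this file is the rewritten unwrapping contract: checkpoint IO receives the boosted `ModelWrapper`/`OptimizerWrapper`, asserts its type, and unwraps internally instead of being handed bare objects. A toy illustration of that contract (the `Toy*` names are illustrative only, not the ColossalAI classes):

```python
import torch
import torch.nn as nn

class ToyModelWrapper(nn.Module):
    """Minimal stand-in for colossalai.interface.ModelWrapper."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

def save_unsharded_model(model: ToyModelWrapper, checkpoint: str) -> None:
    # New contract: the caller passes the boosted wrapper, IO unwraps it here.
    assert isinstance(model, ToyModelWrapper), "Please boost the model before saving!"
    torch.save(model.unwrap().state_dict(), checkpoint)

wrapped = ToyModelWrapper(nn.Linear(4, 4))
save_unsharded_model(wrapped, "toy_model.pt")
```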