[chat] fix gemini strategy #4698

Merged (23 commits, Sep 27, 2023)

Changes from all commits
applications/Chat/benchmarks/benchmark_opt_lora_dummy.py (2 additions, 2 deletions)

@@ -76,9 +76,9 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif args.strategy == "colossalai_gemini_cpu":
-        strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif args.strategy == "colossalai_zero2_cpu":
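The recurring change in these scripts: the old "cuda"/"cpu" placement policies, which colossalai 0.3.3's GeminiPlugin no longer accepts, are mapped onto the new "static" policy, with CPU offloading expressed through the offload fractions introduced in this PR. A minimal sketch of the mapping (import path as used by the example scripts; exact GeminiPlugin defaults in colossalai 0.3.3 are assumed):

    from coati.trainer.strategies import GeminiStrategy

    # Formerly GeminiStrategy(placement_policy="cuda", ...):
    # keep optimizer states and parameters on GPU, no offloading.
    gpu_strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)

    # Formerly GeminiStrategy(placement_policy="cpu", ...):
    # offload all optimizer states and all parameters to CPU.
    cpu_strategy = GeminiStrategy(
        placement_policy="static",
        offload_optim_frac=1.0,  # fraction of optimizer states offloaded to CPU
        offload_param_frac=1.0,  # fraction of parameters offloaded to CPU
        initial_scale=2**5,
    )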
applications/Chat/coati/models/base/actor.py (1 addition, 0 deletions)

@@ -30,3 +30,4 @@ def forward(
         """Returns model output."""
         output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
         return output
+
applications/Chat/coati/ray/utils.py (2 additions, 2 deletions)

@@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str):
     if strategy == "ddp":
         strategy_ = DDPStrategy()
     elif strategy == "colossalai_gemini":
-        strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif strategy == "colossalai_zero2":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif strategy == "colossalai_gemini_cpu":
-        strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+        strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
     elif strategy == "colossalai_zero2_cpu":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
applications/Chat/coati/trainer/strategies/base.py (2 additions, 2 deletions)

@@ -110,8 +110,8 @@ def unwrap_model(model: nn.Module) -> nn.Module:
         """
         return model

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
-        self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
+    def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
+        self.booster.save_model(model, path, shard=shard, **kwargs)

     def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
         self.booster.load_model(model, path, strict)
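With only_rank0 dropped from the strategy interface, callers now choose between a consolidated checkpoint and a sharded one via the shard flag, and which ranks actually write is left to the Booster. A sketch of the call-site migration (the directory-vs-file convention for sharded saves follows tests/test_checkpoint.py further down and is otherwise an assumption):

    # Before: strategy.save_model(model, "model.pt", only_rank0=True)
    # After: shard=False (the default) writes one consolidated checkpoint file,
    # while shard=True writes sharded checkpoint files into a directory.
    strategy.save_model(model, "model.pt")
    strategy.save_model(model, "model_ckpt_dir", shard=True)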
applications/Chat/coati/trainer/strategies/colossalai.py (7 additions, 2 deletions)

@@ -6,7 +6,6 @@
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.utils import get_current_device
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP

@@ -130,6 +129,9 @@ def __init__(
         seed: int = 42,
         shard_init: bool = False,  # only for stage 3
         placement_policy: str = "auto",
+        shard_param_frac: float = 1.0,  # only for static placement
+        offload_optim_frac: float = 0.0,  # only for static placement
+        offload_param_frac: float = 0.0,  # only for static placement
         pin_memory: bool = True,  # only for stage 3
         force_outputs_fp32: bool = False,  # only for stage 3
         search_range_m: int = 32,  # only for stage 3

@@ -160,6 +162,9 @@ def __init__(
         plugin_initializer = lambda: GeminiPlugin(
             chunk_init_device=get_current_device(),
             placement_policy=placement_policy,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
             precision="fp16",
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,

@@ -188,7 +193,7 @@ def setup_distributed(self) -> None:
         colossalai.launch_from_torch({}, seed=self.seed)

     def model_init_context(self):
-        return LazyInitContext(default_device=get_current_device())
+        return super().model_init_context()

     def unwrap_model(self, model: nn.Module) -> nn.Module:
         assert isinstance(model, GeminiDDP)
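GeminiStrategy.model_init_context now defers to the base strategy's context instead of building a LazyInitContext, so model construction in the example scripts is unchanged. A sketch of the intended usage with the new static-placement arguments (values are illustrative, not taken from this PR; GPTActor and its import path are assumed from the coati package layout used by the examples):

    from coati.models.gpt import GPTActor
    from coati.trainer.strategies import GeminiStrategy

    strategy = GeminiStrategy(
        placement_policy="static",
        shard_param_frac=1.0,    # static placement only: fraction of parameters sharded across ranks
        offload_optim_frac=0.0,  # static placement only: fraction of optimizer states offloaded to CPU
        offload_param_frac=0.0,  # static placement only: fraction of parameters offloaded to CPU
        initial_scale=2**5,
    )

    with strategy.model_init_context():
        actor = GPTActor(pretrained="gpt2")  # the examples pick the model class via --model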
applications/Chat/coati/trainer/strategies/ddp.py (5 additions, 5 deletions)

@@ -87,9 +87,9 @@ def unwrap_model(self, model: nn.Module) -> nn.Module:
         return model.unwrap()

     def save_pretrained(
-        self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+        self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
     ) -> None:
-        if not only_rank0 or dist.get_rank() == 0:
+        if dist.get_rank() == 0:
             unwrapped_model = self.unwrap_model(model)
             assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
             pretrained_model = unwrapped_model.model

@@ -98,19 +98,19 @@ def save_pretrained(
             pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
             if tokenizer is not None:
                 tokenizer.save_pretrained(path)
-            model_path = os.path.join(path, "pytorch_model.bin")
-            self.save_model(model, model_path, only_rank0=only_rank0)

+        model_path = os.path.join(path, "pytorch_model.bin")
+        self.save_model(model, model_path, shard=shard)
         def _replace_keys(model_path: str, replace_fn: Callable):
             state_dict = torch.load(model_path, map_location="cpu")
             state_dict = {replace_fn(k): v for k, v in state_dict.items()}
             torch.save(state_dict, model_path)

         # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
         # HACK: rename keys of pytorch_model.bin
         if dist.get_rank() == 0:
             _replace_keys(model_path, lambda k: k.replace("model.", "", 1))


     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         model = self.unwrap_model(model)
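Because save_model serializes the wrapped actor, the keys in pytorch_model.bin carry a leading "model." prefix; the hack above strips it so the weights line up with the config and tokenizer files written by save_pretrained. A hedged sketch of how the resulting directory would then be loaded back (plain transformers API, not part of this PR):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    save_path = "./actor_checkpoint"  # the directory passed to strategy.save_pretrained(...)
    model = AutoModelForCausalLM.from_pretrained(save_path)
    tokenizer = AutoTokenizer.from_pretrained(save_path)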
@@ -24,7 +24,7 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
@@ -24,7 +24,7 @@ def train(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda")
+        strategy = GeminiStrategy(placement_policy="static")
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
applications/Chat/examples/requirements.txt (1 addition, 1 deletion)

@@ -1,3 +1,3 @@
 pandas>=1.4.1
 sentencepiece
-colossalai>=0.3.1
+colossalai==0.3.3
applications/Chat/examples/train_prompts.py (6 additions, 2 deletions)

@@ -23,7 +23,7 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:

@@ -33,6 +33,10 @@ def main(args):
         warnings.warn("LoRA weights should be merged with the model weights")
         state_dict = torch.load(args.rm_path, map_location="cpu")

+    if args.lora_rank > 0:
+        warnings.warn("Lora is not supported yet.")
+        args.lora_rank = 0
+
     with strategy.model_init_context():
         # configure model
         if args.model == "gpt2":

@@ -193,7 +197,7 @@ def main(args):
     )

     # save model checkpoint after fitting
-    strategy.save_model(actor, args.save_path, only_rank0=True)
+    strategy.save_pretrained(actor, path=args.save_path)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(
applications/Chat/examples/train_reward_model.py (7 additions, 1 deletion)

@@ -1,4 +1,5 @@
 import argparse
+import warnings

 import torch
 import torch.distributed as dist

@@ -33,6 +34,10 @@ def train(args):
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     # configure model
+    if args.lora_rank > 0:
+        warnings.warn("Lora is not supported yet.")
+        args.lora_rank = 0
+
     with strategy.model_init_context():
         if args.model == "bloom":
             model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)

@@ -158,7 +163,8 @@ def train(args):
         use_wandb=args.use_wandb,
     )
     # save model checkpoint after fitting on only rank0
-    strategy.save_model(model, args.save_path, only_rank0=True)
+    state_dict = model.state_dict()
+    torch.save(state_dict, args.save_path)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(
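The reward model is now saved as a plain state dict via torch.save, which is the format train_prompts.py expects when it later loads the RM with torch.load(args.rm_path, map_location="cpu") in the hunk above. A minimal sketch of loading the checkpoint back (BLOOMRM as in this diff; import path and pretrained name are assumptions for illustration):

    import torch
    from coati.models.bloom import BLOOMRM  # import path assumed from the coati package layout

    reward_model = BLOOMRM(pretrained="bigscience/bloom-560m", lora_rank=0)  # placeholder pretrained name
    state_dict = torch.load("rm_ckpt.pt", map_location="cpu")  # file written by train_reward_model.py
    reward_model.load_state_dict(state_dict)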
applications/Chat/examples/train_sft.py (4 additions, 3 deletions)

@@ -40,8 +40,9 @@ def train(args):

     # configure model
     if args.lora_rank > 0:
-        warnings.warn("Gradient checkpoint is disabled when using LoRA")
-        args.grad_checkpoint = False
+        warnings.warn("Lora is not supported yet.")
+        args.lora_rank = 0
+
     with strategy.model_init_context():
         if args.model == "bloom":
             model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)

@@ -178,7 +179,7 @@ def train(args):
     )

     # save model checkpoint after fitting on only rank0
-    strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+    strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(
applications/Chat/requirements-test.txt (1 addition, 1 deletion)

@@ -1,2 +1,2 @@
 pytest
-colossalai>=0.3.1
+colossalai==0.3.3
applications/Chat/requirements.txt (1 addition, 1 deletion)

@@ -2,7 +2,7 @@ transformers>=4.20.1
 tqdm
 datasets
 loralib
-colossalai>=0.3.1
+colossalai==0.3.3
 torch<2.0.0, >=1.12.1
 langchain
 tokenizers
applications/Chat/tests/test_checkpoint.py (2 additions, 2 deletions)

@@ -57,9 +57,9 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
     rank0_dirname = rank0_dirname[0]

     model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt")
-    strategy.save_model(actor, model_path, only_rank0=not shard)
+    strategy.save_model(actor, model_path)
     optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
-    strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
+    strategy.save_optimizer(actor_optim, optim_path)
     dist.barrier()

     strategy.load_model(actor, model_path, strict=False)
applications/Chat/tests/test_train.sh (6 additions, 14 deletions)

@@ -41,6 +41,7 @@ MODELS_DIR=$BASE_DIR/examples/models_config
 MODELS=('gpt2' 'bloom' 'opt' 'llama')
 STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')

+
 export OMP_NUM_THREADS=8

 # install requirements

@@ -80,13 +81,10 @@ SKIPPED_TESTS=(
 "llama-ddp"
 "llama-colossalai_gemini"
 "llama-colossalai_zero2"
-"gpt2-colossalai_gemini"
-"opt-colossalai_gemini"
-"bloom-colossalai_gemini"
 )

 GRAD_CKPTS=('' '--grad_checkpoint')
-for lora_rank in '0' '4'; do
+for lora_rank in '0'; do
 for model in ${MODELS[@]}; do
 strategies=($(shuf -e "${STRATEGIES[@]}"))
 for strategy in ${strategies[@]}; do

@@ -135,14 +133,11 @@ SKIPPED_TESTS=(
 "llama-ddp"
 "llama-colossalai_gemini"
 "llama-colossalai_zero2"
-"gpt2-colossalai_gemini"
-"opt-colossalai_gemini"
-"bloom-colossalai_gemini"
 )

 LOSS_FNS=('log_sig' 'log_exp')
 DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
-for lora_rank in '0' '4'; do
+for lora_rank in '0'; do
 for model in ${MODELS[@]}; do
 strategies=($(shuf -e "${STRATEGIES[@]}"))
 for strategy in ${strategies[@]}; do

@@ -193,13 +188,10 @@ SKIPPED_TESTS=(
 "llama-ddp"
 "llama-colossalai_gemini"
 "llama-colossalai_zero2"
-"gpt2-colossalai_gemini"
-"opt-colossalai_gemini"
-"bloom-colossalai_gemini"
 )

 for model in ${MODELS[@]}; do
-for lora_rank in '0' '4'; do
+for lora_rank in '0'; do
 strategies=($(shuf -e "${STRATEGIES[@]}"))
 for strategy in ${strategies[@]}; do
 if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then

@@ -223,7 +215,7 @@ for model in ${MODELS[@]}; do
 --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
 --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
 $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
---save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
+--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
 passed=$?
 if [ $passed -eq 0 ]; then
 break

@@ -238,4 +230,4 @@
 rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
 done
 done
-rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
+rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts