Skip to content

Commit

Permalink
[chat] fix gemini strategy (#4698)
Browse files Browse the repository at this point in the history
* [chat] fix gemini strategy

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* g# This is a combination of 2 commits.

[chat] fix gemini strategy

fox

* [chat] fix gemini strategy

update llama2 example

[chat] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* fix

* fix

* fix

* fix

* fix

* Update train_prompts.py
  • Loading branch information
flybird11111 authored Sep 27, 2023
1 parent bbbcac2 commit be400a0
Show file tree
Hide file tree
Showing 16 changed files with 49 additions and 40 deletions.
4 changes: 2 additions & 2 deletions applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
elif args.strategy == "colossalai_gemini_cpu":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == "colossalai_zero2_cpu":
Expand Down
1 change: 1 addition & 0 deletions applications/Chat/coati/models/base/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ def forward(
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output

4 changes: 2 additions & 2 deletions applications/Chat/coati/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str):
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
Expand Down
4 changes: 2 additions & 2 deletions applications/Chat/coati/trainer/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def unwrap_model(model: nn.Module) -> nn.Module:
"""
return model

def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
self.booster.save_model(model, path, shard=shard, **kwargs)

def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)
Expand Down
9 changes: 7 additions & 2 deletions applications/Chat/coati/trainer/strategies/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.utils import get_current_device
from colossalai.zero.gemini.gemini_ddp import GeminiDDP

Expand Down Expand Up @@ -130,6 +129,9 @@ def __init__(
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "auto",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
Expand Down Expand Up @@ -160,6 +162,9 @@ def __init__(
plugin_initializer = lambda: GeminiPlugin(
chunk_init_device=get_current_device(),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
Expand Down Expand Up @@ -188,7 +193,7 @@ def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)

def model_init_context(self):
return LazyInitContext(default_device=get_current_device())
return super().model_init_context()

def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiDDP)
Expand Down
10 changes: 5 additions & 5 deletions applications/Chat/coati/trainer/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def unwrap_model(self, model: nn.Module) -> nn.Module:
return model.unwrap()

def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
if not only_rank0 or dist.get_rank() == 0:
if dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
Expand All @@ -98,19 +98,19 @@ def save_pretrained(
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, only_rank0=only_rank0)

model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, shard=shard)
def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)

# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))


def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def train(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda")
strategy = GeminiStrategy(placement_policy="static")
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
Expand Down
2 changes: 1 addition & 1 deletion applications/Chat/examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
colossalai>=0.3.1
colossalai==0.3.3
8 changes: 6 additions & 2 deletions applications/Chat/examples/train_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
Expand All @@ -33,6 +33,10 @@ def main(args):
warnings.warn("LoRA weights should be merged with the model weights")
state_dict = torch.load(args.rm_path, map_location="cpu")

if args.lora_rank > 0:
warnings.warn("Lora is not supported yet.")
args.lora_rank = 0

with strategy.model_init_context():
# configure model
if args.model == "gpt2":
Expand Down Expand Up @@ -199,7 +203,7 @@ def main(args):
LORA_MANAGER.merge_weights = True
actor.eval()
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
strategy.save_pretrained(actor, path=args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(
Expand Down
8 changes: 7 additions & 1 deletion applications/Chat/examples/train_reward_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import warnings

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -33,6 +34,10 @@ def train(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"')

# configure model
if args.lora_rank > 0:
warnings.warn("Lora is not supported yet.")
args.lora_rank = 0

with strategy.model_init_context():
if args.model == "bloom":
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
Expand Down Expand Up @@ -165,7 +170,8 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
state_dict = model.state_dict()
torch.save(state_dict, args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(
Expand Down
7 changes: 4 additions & 3 deletions applications/Chat/examples/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def train(args):

# configure model
if args.lora_rank > 0:
warnings.warn("Gradient checkpoint is disabled when using LoRA")
args.grad_checkpoint = False
warnings.warn("Lora is not supported yet.")
args.lora_rank = 0

with strategy.model_init_context():
if args.model == "bloom":
model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
Expand Down Expand Up @@ -184,7 +185,7 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(
Expand Down
2 changes: 1 addition & 1 deletion applications/Chat/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pytest
colossalai>=0.3.1
colossalai==0.3.3
2 changes: 1 addition & 1 deletion applications/Chat/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
colossalai>=0.3.1
colossalai==0.3.3
torch<2.0.0, >=1.12.1
langchain
tokenizers
Expand Down
4 changes: 2 additions & 2 deletions applications/Chat/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
rank0_dirname = rank0_dirname[0]

model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt")
strategy.save_model(actor, model_path, only_rank0=not shard)
strategy.save_model(actor, model_path)
optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
strategy.save_optimizer(actor_optim, optim_path)
dist.barrier()

strategy.load_model(actor, model_path, strict=False)
Expand Down
20 changes: 6 additions & 14 deletions applications/Chat/tests/test_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ MODELS_DIR=$BASE_DIR/examples/models_config
MODELS=('gpt2' 'bloom' 'opt' 'llama')
STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')


export OMP_NUM_THREADS=8

# install requirements
Expand Down Expand Up @@ -80,13 +81,10 @@ SKIPPED_TESTS=(
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
)

GRAD_CKPTS=('' '--grad_checkpoint')
for lora_rank in '0' '4'; do
for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
Expand Down Expand Up @@ -135,14 +133,11 @@ SKIPPED_TESTS=(
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
)

LOSS_FNS=('log_sig' 'log_exp')
DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
for lora_rank in '0' '4'; do
for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
Expand Down Expand Up @@ -193,13 +188,10 @@ SKIPPED_TESTS=(
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
)

for model in ${MODELS[@]}; do
for lora_rank in '0' '4'; do
for lora_rank in '0'; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
Expand All @@ -223,7 +215,7 @@ for model in ${MODELS[@]}; do
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
passed=$?
if [ $passed -eq 0 ]; then
break
Expand All @@ -238,4 +230,4 @@ for model in ${MODELS[@]}; do
rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
done
done
rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts

0 comments on commit be400a0

Please sign in to comment.