
Refactor ZeRO, checkpoint and pipeline code #128

Merged: 73 commits into OpenBMB:dev on Aug 21, 2023
Conversation

@zkh2016 (Collaborator) commented Jul 24, 2023

Refactor the code to decouple the different strategies:

  1. ZeRO and checkpointing are now controlled independently; a new `use_checkpoint` switch controls whether checkpointing is enabled
  2. The checkpoint switch lives in `CheckpointBlock`, so checkpointing can be toggled per module
  3. Added `CustomLinear` to fix the issue of ZeRO-3 parameters not being released
  4. Pipeline parallelism, ZeRO-2/3, and checkpointing can be combined freely
  5. For a 10B model on 8 A100 GPUs, ZeRO-3 with checkpointing disabled improves training throughput by 15%, while per-GPU memory usage grows by 35%

    bmt.init_distributed(
        zero_level=3, # ZeRO-2 or ZeRO-3 supported
        pipe_size=2   # number of pipeline-parallel stages
    )
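For intuition on how `pipe_size` interacts with the world size: the GPUs are split into pipeline stages, and ZeRO then shards parameters within each stage's group of ranks. The following rank-to-stage mapping is a hypothetical illustration only, not BMTrain's actual topology code:

```python
def pipeline_groups(world_size, pipe_size):
    """Split `world_size` ranks into `pipe_size` contiguous pipeline stages.

    Ranks within the same stage form a group that ZeRO can shard
    parameters across. Assumes world_size is divisible by pipe_size.
    (Illustrative layout; BMTrain may order ranks differently.)
    """
    assert world_size % pipe_size == 0
    per_stage = world_size // pipe_size
    return [list(range(s * per_stage, (s + 1) * per_stage))
            for s in range(pipe_size)]

# 8 GPUs, 2 pipeline stages -> two ZeRO groups of 4 ranks each
print(pipeline_groups(8, 2))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```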

    self.transformers = bmt.TransformerBlockList([
        bmt.CheckpointBlock(
            TransformerEncoder(
                dim_model, dim_head, num_heads, dim_ff, bias, dtype
            ),
            use_checkpoint=False, # enable/disable checkpointing (enabled by default)
        )
        for _ in range(num_layers)
    ])
    self.transformers(x, mask)

TODO:

  1. Automatic replacement: torch.nn.Linear -> bmt.nn.Linear
  2. Add ZeRO-1 support
  3. Refactor the communication module
  4. Refactor class names, file names, and the directory structure
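The first TODO item, automatically swapping `torch.nn.Linear` for `bmt.nn.Linear`, could be sketched as a recursive rewrite of the module tree. The sketch below is self-contained (tiny stand-in classes replace the torch/bmt ones) and is an illustrative approach only, not this PR's implementation; with real PyTorch, `named_children` and `setattr` work the same way:

```python
class Module:
    """Minimal stand-in for torch.nn.Module."""
    def named_children(self):
        return [(n, m) for n, m in vars(self).items() if isinstance(m, Module)]

class Linear(Module):                      # stands in for torch.nn.Linear
    def __init__(self, in_f, out_f):
        self.in_f, self.out_f = in_f, out_f

class BMTLinear(Linear):                   # stands in for bmt.nn.Linear
    pass

class Block(Module):
    def __init__(self):
        self.proj = Linear(8, 8)

class Model(Module):
    def __init__(self):
        self.block = Block()
        self.head = Linear(8, 2)

def replace_linear(module):
    """Recursively swap every plain Linear child for a same-shape BMTLinear."""
    for name, child in module.named_children():
        if type(child) is Linear:          # exact type: skip already-swapped ones
            setattr(module, name, BMTLinear(child.in_f, child.out_f))
        else:
            replace_linear(child)          # descend into submodules

model = Model()
replace_linear(model)
print(type(model.head).__name__)        # BMTLinear
print(type(model.block.proj).__name__)  # BMTLinear
```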

@zkh2016 zkh2016 changed the title from "[WIP] Using hooks to implement ZeRO and Checkpoint" to "Using hooks to implement ZeRO and Checkpoint" Jul 26, 2023
zhangkaihuo added 3 commits July 27, 2023 12:19
@zkh2016 zkh2016 changed the title from "Using hooks to implement ZeRO and Checkpoint" back to "[WIP] Using hooks to implement ZeRO and Checkpoint" Jul 31, 2023
@zkh2016 zkh2016 changed the base branch from main to dev August 18, 2023 02:56
@Achazwl Achazwl self-requested a review August 18, 2023 11:07
@Achazwl Achazwl assigned Achazwl and zkh2016 and unassigned Achazwl Aug 18, 2023
Review threads on bmtrain/init.py and bmtrain/block_layer.py (resolved; several marked outdated)
    arg_list = self.pre_hook(*args)

    if self.use_checkpoint:
        out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad)
Collaborator comment on the line above:

  1. Find out why `use_reentrant` causes an error
  2. `use_reentrant=False` can save computation
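For context, the memory/compute trade-off behind `use_checkpoint` can be sketched without PyTorch: the forward pass stores only its inputs, and activations are recomputed when the backward pass needs them. This is a toy illustration of the idea, not BMTrain's hook-based implementation:

```python
class CheckpointedFn:
    """Toy activation checkpointing: save inputs, recompute on backward."""
    def __init__(self, fn):
        self.fn = fn
        self.saved_input = None
        self.recomputed = 0          # counts recomputations, for illustration

    def forward(self, x):
        self.saved_input = x         # keep only the input, not the activations
        return self.fn(x)

    def backward(self):
        self.recomputed += 1
        return self.fn(self.saved_input)  # rerun forward to rebuild activations

ckpt = CheckpointedFn(lambda x: x * 2 + 1)
y = ckpt.forward(3)      # forward pass; intermediate state is not stored
act = ckpt.backward()    # recomputes the forward pass to get the activation
print(y, act, ckpt.recomputed)  # 7 7 1
```

Disabling checkpointing (as item 5 of the PR description measures) skips the recomputation, which is where the throughput gain comes from, at the cost of keeping activations resident in memory.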

Further review threads on bmtrain/block_layer.py, bmtrain/checkpointing.py, bmtrain/pipe_layer.py, and bmtrain/__init__.py (resolved; several marked outdated)
@zkh2016 zkh2016 merged commit 74700e4 into OpenBMB:dev Aug 21, 2023
2 participants