Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ZeRO, checkpoint and pipeline code #128

Merged
merged 73 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from 64 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
bcea035
using hooks to implement ZeRO and Checkpoint
Jul 24, 2023
7b080e7
async backward
Jul 25, 2023
be5f9d7
async forward
Jul 25, 2023
dea1781
merge upstream
Jul 25, 2023
05bc553
fix
Jul 25, 2023
bdf7087
save cuda_rng_state
Jul 26, 2023
6a366e3
fix
Jul 27, 2023
25ef84f
fix
Jul 27, 2023
768f209
fix
Jul 27, 2023
324e0dd
remove __call__
Jul 31, 2023
0f4ddb5
refactor code structure
Jul 31, 2023
76c5c26
pipeline
Jul 31, 2023
16c0922
for low version
Jul 31, 2023
2d35ba0
for low torch version
Jul 31, 2023
bc48d83
for checkpoint
Jul 31, 2023
bd61071
remove unused code
Jul 31, 2023
de25455
remove duplicate code
Jul 31, 2023
fde122f
fix pipeline; checkpoint support low version
Aug 1, 2023
a897ad4
fix pipeline; checkpoint support low version
Aug 1, 2023
ca50795
merge remote
Aug 1, 2023
ec8385b
fix indent
Aug 1, 2023
9877a81
pipe support low version
Aug 2, 2023
28993b5
custom linear for zero3
Aug 2, 2023
4d43952
merge origin
Aug 3, 2023
e4eaebf
resolve conflict
Aug 3, 2023
cba7c55
resolve conflict
Aug 3, 2023
839a976
use torch.utils.checkpoint.checkpoint
Aug 3, 2023
d5bbf1a
custom hook
Aug 4, 2023
e92d0ef
optimize code structure
Aug 4, 2023
6ba753e
for hidden_state
Aug 4, 2023
b0a0da9
for input.requires_grad is False
Aug 4, 2023
f4a0e0b
fix
Aug 5, 2023
8faff0f
pipeline support return hidden_state
Aug 6, 2023
26c8c94
fix args
Aug 7, 2023
b7d1c8c
fix test
Aug 7, 2023
4303575
CheckpointBlock -> BMTBlock
Aug 8, 2023
8061b66
reset block name
Aug 8, 2023
845f210
pipeline support batch_related
Aug 8, 2023
0b14fe5
remove use_checkpoint from init_distributed
Aug 9, 2023
ae56de8
for requires_grad
Aug 10, 2023
27ae2b7
for requires_grad
Aug 10, 2023
fdc8231
fix for arg is not tensor
Aug 10, 2023
b0f7154
fix for arg is not a tensor
Aug 10, 2023
420b626
add test
Aug 10, 2023
b843489
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
Aug 10, 2023
ebc269f
merge enhance_ckp
Aug 11, 2023
2f1e766
enhance ckp
Aug 11, 2023
1e993c6
refactor code
Aug 12, 2023
24d0f59
mv linear to bmt.nn.linear
Aug 12, 2023
ff72e66
for enhance_ckp
Aug 12, 2023
1fbf3b2
fix for all input not grad
Aug 14, 2023
ace5216
fix pre_module
Aug 14, 2023
52cd4e2
fix pre_module
Aug 14, 2023
0b0bd0b
fix for all input no grad
Aug 14, 2023
05b49f8
fix for all input no grad
Aug 14, 2023
64eb672
Merge branch 'main' of https://github.com/OpenBMB/BMTrain into hook
Aug 16, 2023
88b5bd3
fix reentrant
Aug 17, 2023
9c2e47d
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
Aug 17, 2023
e93e6dc
Merge branch 'dev' into hook
Aug 18, 2023
fd49311
refactor CheckpointBlock
Aug 20, 2023
221bdc3
refactor pipe
Aug 20, 2023
76f74e5
Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook
Aug 20, 2023
9c63407
fix all input no grad
Aug 20, 2023
f72fcfc
fix hiddenstate
Aug 20, 2023
ebdf519
fix test
Aug 21, 2023
780ca20
fix
Aug 21, 2023
6df85e7
remove unused import
Aug 21, 2023
bb482d6
fix pre_module
Aug 21, 2023
1010d26
recovery some code
Aug 21, 2023
b580530
add test_no_grad.py
Aug 21, 2023
767a875
test unroll block list
Aug 21, 2023
d19a627
fix test_fp32
Aug 21, 2023
bf986a7
cross_entropy support fp32
Aug 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bmtrain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi
try:
from . import nccl
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved
except:
load_nccl_pypi()
from .global_var import config, world_size, rank
from .init import init_distributed

from .parameter import DistributedParameter, ParameterInitializer
from .layer import DistributedModule
from .param_init import init_parameters, grouped_parameters
from .utils import print_block, print_dict, print_rank, see_memory
from .synchronize import synchronize, sum_loss, wait_loader, gather_result
from .block_layer import CheckpointBlock, TransformerBlockList
from .wrapper import BMTrainModelWrapper
from .pipe_layer import PipelineTransformerBlockList
from . import debug
from .store import save, load

from . import benchmark
from . import optim
from . import inspect
from . import lr_scheduler
from . import loss
from . import distributed
598 changes: 126 additions & 472 deletions bmtrain/block_layer.py

Large diffs are not rendered by default.

154 changes: 154 additions & 0 deletions bmtrain/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from typing import Callable, TypeVar
from functools import wraps
from . import debug
from . import nccl
from .global_var import config
from .synchronize import wait_loader, synchronize

class ScopedDebugTensorList:
Achazwl marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self) -> None:
Expand All @@ -28,3 +31,154 @@ def __exit__(self, *args):
self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", []))
debug.set("_inspect_hidden_states", self.prev_hidden)
self.prev_hidden = None

class CheckpointBlockContext:
def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = False) -> None:
self.block = block
self.ctx_dict = ctx_dict
self._param_buffer = {}
self._grad_buffer = {}
self._param_tensor = {}
self._grad_tensor = {}
self._need_release = False
if pipe:
self.comm = config["zero_comm"]
else:
self.comm = config["comm"]
def enter(self, flag=0, requires_grad=False):
"""
gather parameters
"""
if self.block._ready:
return
self.block._ready = True
self._need_release = True

wait_loader()
with torch.cuda.stream(config["load_stream"]):
for kw, val in self.block._storage_info.items():
assert self.block._storage_params[kw].is_cuda
assert kw not in self._grad_buffer
assert kw not in self._param_buffer
local_param = self.block._storage_params[kw]

storage_type = local_param.storage_type()
if flag != 2:
self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"])
self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw])

if requires_grad and local_param.requires_grad:
self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"])
self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_()
if flag != 2:
nccl.groupStart()
for kw, val in self.block._storage_info.items():
nccl.allGather(
self.block._storage_params[kw].storage(),
self._param_buffer[kw],
self.comm
)
nccl.groupEnd()

current_stream = torch.cuda.current_stream()
current_stream.wait_stream(config["load_stream"])

# set wait stream for each storage
for kw in self.block._storage_info.keys():
if flag != 2:
self._param_tensor[kw].record_stream(current_stream)
if requires_grad and kw in self._grad_tensor:
self._grad_tensor[kw].record_stream(current_stream)

# update parameters in block
for param in self.block._param_info:
kw_name = param["kw_name"]
offset = param["offset"]
shape = param["shape"]

if flag != 2:
dtype = self._param_buffer[kw_name].dtype
device = self._param_buffer[kw_name].device
param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape)
else:
dtype = param["parameter"].data.dtype
device = param["parameter"].data.device
param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape)

if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad:
param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape)

def __enter__(self):
self.enter()

def exit(self, flag=0, backward=False):
"""
Reduce scatter gradients
"""

if not self._need_release:
return
self._need_release = False
self.block._ready = False
if backward:
for kw, val in self.block._storage_info.items():
local_param = self.block._storage_params[kw]

# accumulate previous gradient
if local_param.requires_grad:
if local_param.grad is None:
grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist
local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_()
else:
self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad

current_stream = torch.cuda.current_stream()
config["load_stream"].wait_stream(current_stream) # wait for backward

with torch.cuda.stream(config["load_stream"]):
nccl.groupStart()
for kw, val in self.block._storage_info.items():
local_param = self.block._storage_params[kw]

# scatter gradient
if local_param.requires_grad:
nccl.reduceScatter(
self._grad_buffer[kw],
local_param.grad.storage(),
"sum",
self.comm
)
nccl.groupEnd()

# set wait stream for each storage
for kw in self._grad_tensor.keys():
# grads can not be freed until reduce ops finish
self._grad_tensor[kw].record_stream(config["load_stream"])


# Release all parameters from buffer to block_storge
for param in self.block._param_info:
kw_name = param["kw_name"]
dtype = self.block._storage_params[kw_name].dtype
device = self.block._storage_params[kw_name].device
if "begin" not in param:
param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
param["parameter"].grad = None
continue
begin = param["begin"]
end = param["end"]
param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end)
if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None:
param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
if flag == 1:
for i in self._param_buffer:
self.ctx_dict[i] = self._param_buffer[i]
self._grad_tensor = {}
self._param_tensor = {}
self._grad_buffer = {}
self._param_buffer = {}


def __exit__(self, exc_type, exc_val, exc_tb):
# reduce scatter gradients
self.exit()
111 changes: 111 additions & 0 deletions bmtrain/hook_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import torch
from .global_var import config
from .checkpointing import CheckpointBlockContext
from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations

def zero_pre_forward(module, inputs):
enter = True
pipe = False
if module._mode == "PIPE":
enter = module._micro_idx == 0
pipe = True
if enter:
zero_level = config['zero_level']
forward_flag = 1 if zero_level == 2 else 0
if zero_level == 2 and module._ref_count > 1:
forward_flag = 2 # repeating forward in same layer
if module.all_param_no_grad: #only forward
forward_flag = 0
module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe)
module._forward_block_ctx.enter(forward_flag)

def zero_post_forward(module, inputs, outputs):
forward_flag = 1 if config['zero_level'] == 2 else 0
if module.all_param_no_grad:
forward_flag = 0
exit = True
if module._mode == "PIPE":
exit = module._micro_idx == config['micros'] - 1

if exit:
module._forward_block_ctx.exit(forward_flag)
module._ref_count += 1

def zero_pre_backward(module, grad_outputs):
backward_flag = 2 if config['zero_level'] == 2 else 0
if module._mode != "PIPE":
module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag, True)
if not module._is_last_layer:
module.next_module().backward_release(backward_flag)
else:
if module._micro_idx == config['micros'] - 1:
module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True)
module._backward_block_ctx.enter(backward_flag, True)

def zero_post_backward(module, grad_inputs, grad_outputs):
backward_flag = 2 if config['zero_level'] == 2 else 0
if module._mode != "PIPE":
if module._is_first_layer:
module.backward_release(backward_flag)
else:
if module._micro_idx == 0:
module.backward_release(backward_flag)
module._micro_idx -= 1

class OneStepNoGradFunc(torch.autograd.Function):
"""
requires_grad = False for all inputs
"""
@staticmethod
def forward(ctx, module, placeholder, *x):
ctx.x = x
ctx.module = module
ctx.rng_state = torch.cuda.get_rng_state()

with torch.no_grad():
out = module._module(*x)
zero_post_forward(module, None, out)
if not isinstance(out, torch.Tensor):
return tuple(out)
return out

@staticmethod
def backward(ctx, grads):
zero_pre_backward(ctx.module, grads)
with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True):
torch.cuda.set_rng_state(ctx.rng_state)
x = ctx.x
with torch.enable_grad():
out = ctx.module._module(*x)
torch.autograd.backward(out, grads)
zero_post_backward(ctx.module, grads, None)
grads = []
for _ in x:
grads.append(None)
return None, None, *grads


class PreHookFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, module, *x):
ctx.module = module
zero_pre_forward(module, x)
return x

@staticmethod
def backward(ctx, *grads):
zero_post_backward(ctx.module, grads, None)
return None, *grads

class PostHookFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, module, *out):
ctx.module = module
zero_post_forward(module, None, out)
return out

@staticmethod
def backward(ctx, *grads):
zero_pre_backward(ctx.module, grads)
return None, *grads
1 change: 1 addition & 0 deletions bmtrain/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .linear import Linear
24 changes: 22 additions & 2 deletions example/layers/linear.py → bmtrain/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,26 @@
import torch.nn.functional as F
import bmtrain as bmt

class CustomLinear(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias=None):
ctx.save_for_backward(x, weight, bias)
return F.linear(x, weight, bias)

@staticmethod
def backward(ctx, grad_output):
x, weight, bias = ctx.saved_tensors
grad_x = grad_weight = grad_bias = None
if x.requires_grad:
grad_x = grad_output.matmul(weight)
if weight.requires_grad:
dim = grad_output.dim()
grad_weight = grad_output.reshape(-1,
grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
if bias is not None and bias.requires_grad:
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
return grad_x, grad_weight, grad_bias

class Linear(bmt.DistributedModule):
def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
super().__init__()
Expand All @@ -15,9 +35,9 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp
self.register_parameter('bias', None)

def forward(self, input):
return F.linear(input, self.weight, self.bias)
return CustomLinear.apply(input, self.weight, self.bias)

def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
)
Loading