Refactor ZeRO, checkpoint and pipeline code #128

Merged · 73 commits · Aug 21, 2023

Commits
bcea035  using hooks to implement ZeRO and Checkpoint (Jul 24, 2023)
7b080e7  async backward (Jul 25, 2023)
be5f9d7  async forward (Jul 25, 2023)
dea1781  merge upstream (Jul 25, 2023)
05bc553  fix (Jul 25, 2023)
bdf7087  save cuda_rng_state (Jul 26, 2023)
6a366e3  fix (Jul 27, 2023)
25ef84f  fix (Jul 27, 2023)
768f209  fix (Jul 27, 2023)
324e0dd  remove __call__ (Jul 31, 2023)
0f4ddb5  refactor code structure (Jul 31, 2023)
76c5c26  pipeline (Jul 31, 2023)
16c0922  for low version (Jul 31, 2023)
2d35ba0  for low torch version (Jul 31, 2023)
bc48d83  for checkpoint (Jul 31, 2023)
bd61071  remove unused code (Jul 31, 2023)
de25455  remove duplicate code (Jul 31, 2023)
fde122f  fix pipeline; checkpoint support low version (Aug 1, 2023)
a897ad4  fix pipeline; checkpoint support low version (Aug 1, 2023)
ca50795  merge remote (Aug 1, 2023)
ec8385b  fix indent (Aug 1, 2023)
9877a81  pipe support low version (Aug 2, 2023)
28993b5  custom linear for zero3 (Aug 2, 2023)
4d43952  merge origin (Aug 3, 2023)
e4eaebf  resolve conflict (Aug 3, 2023)
cba7c55  resolve conflict (Aug 3, 2023)
839a976  use torch.utils.checkpoint.checkpoint (Aug 3, 2023)
d5bbf1a  custom hook (Aug 4, 2023)
e92d0ef  optimize code structure (Aug 4, 2023)
6ba753e  for hidden_state (Aug 4, 2023)
b0a0da9  for input.requires_grad is False (Aug 4, 2023)
f4a0e0b  fix (Aug 5, 2023)
8faff0f  pipeline support return hidden_state (Aug 6, 2023)
26c8c94  fix args (Aug 7, 2023)
b7d1c8c  fix test (Aug 7, 2023)
4303575  CheckpointBlock -> BMTBlock (Aug 8, 2023)
8061b66  reset block name (Aug 8, 2023)
845f210  pipeline support batch_related (Aug 8, 2023)
0b14fe5  remove use_checkpoint from init_distributed (Aug 9, 2023)
ae56de8  for requires_grad (Aug 10, 2023)
27ae2b7  for requires_grad (Aug 10, 2023)
fdc8231  fix for arg is not tensor (Aug 10, 2023)
b0f7154  fix for arg is not a tensor (Aug 10, 2023)
420b626  add test (Aug 10, 2023)
b843489  Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook (Aug 10, 2023)
ebc269f  merge enhance_ckp (Aug 11, 2023)
2f1e766  enhance ckp (Aug 11, 2023)
1e993c6  refactor code (Aug 12, 2023)
24d0f59  mv linear to bmt.nn.linear (Aug 12, 2023)
ff72e66  for enhance_ckp (Aug 12, 2023)
1fbf3b2  fix for all input not grad (Aug 14, 2023)
ace5216  fix pre_module (Aug 14, 2023)
52cd4e2  fix pre_module (Aug 14, 2023)
0b0bd0b  fix for all input no grad (Aug 14, 2023)
05b49f8  fix for all input no grad (Aug 14, 2023)
64eb672  Merge branch 'main' of https://github.com/OpenBMB/BMTrain into hook (Aug 16, 2023)
88b5bd3  fix reentrant (Aug 17, 2023)
9c2e47d  Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook (Aug 17, 2023)
e93e6dc  Merge branch 'dev' into hook (Aug 18, 2023)
fd49311  refactor CheckpointBlock (Aug 20, 2023)
221bdc3  refactor pipe (Aug 20, 2023)
76f74e5  Merge branch 'hook' of https://github.com/zkh2016/BMTrain into hook (Aug 20, 2023)
9c63407  fix all input no grad (Aug 20, 2023)
f72fcfc  fix hiddenstate (Aug 20, 2023)
ebdf519  fix test (Aug 21, 2023)
780ca20  fix (Aug 21, 2023)
6df85e7  remove unused import (Aug 21, 2023)
bb482d6  fix pre_module (Aug 21, 2023)
1010d26  recovery some code (Aug 21, 2023)
b580530  add test_no_grad.py (Aug 21, 2023)
767a875  test unroll block list (Aug 21, 2023)
d19a627  fix test_fp32 (Aug 21, 2023)
bf986a7  cross_entropy support fp32 (Aug 21, 2023)
76 changes: 34 additions & 42 deletions bmtrain/block_layer.py
@@ -7,7 +7,6 @@
from .synchronize import wait_loader
from .parameter import DistributedParameter, OpAllGather
from .checkpointing import (
ScopedTensorInspectorContext,
CheckpointBlockContext
)

@@ -50,32 +49,6 @@ def _get_param_kw(param : DistributedParameter):
group_name = "_g_" + param.group
return type_name + grad_name + group_name

class BMTBlockContext:
def __init__(self):
self._pre_module = None
self._first = True

def link_module(self, module):
if not self._first and module._ref_count == -1:
self._pre_module = module
module._ref_count = 1
return

if self._pre_module is None:
module._ref_count = 1
module._is_first_layer = True
else:
if module._ref_count == 0:
module._is_first_layer = False
self._pre_module.set_next_module(module)
self._pre_module._is_last_layer = False
self._pre_module = module
self._first = False

def clear(self):
self._pre_module = None
self._first = True

class CheckpointBlock(torch.nn.Module):
""" A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint.

@@ -94,7 +67,7 @@ class CheckpointBlock(torch.nn.Module):
>>> y2, ... = transformer_block(x)
>>> assert torch.allclose(y1, y2)
"""
def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_context=None):
def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True):
super().__init__()
self._module = inner_module
self._inputs = None
@@ -222,25 +195,35 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, block_co
self.use_checkpoint = use_checkpoint
self._is_first_layer = True
self._is_last_layer = True
self._pre_module = []
self._next_module = []
self._ref_count = 0
self._release_list = [True]
self._next_module = [] #save the next module of self
self._pre_module = [] #save the pre module of self
self._ref_count = 0 #incremental in forward and decreasing in backward
self._mode = "BLOCK" #BLOCK or ZERO or PIPE
self.return_hidden_states = False
self.hidden_states = []
self.block_context = block_context
if block_context is None:
self.block_context = config['block_context'][config['rank']]
self.all_input_no_grad = False
self.all_param_no_grad = False

def set_next_module(self, module):
self._next_module.append(module)
module._pre_module.append(self)
module._ref_count += 1
def set_pre_module(self, pre_module):
if pre_module is not None:
self._pre_module.append(pre_module)
pre_module._next_module.append(self)

def pre_module(self):
return self._pre_module[self._ref_count-1]

def next_module(self):
assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count)
return self._next_module[self._ref_count-1]

def backward_release(self, flag):
if self._ref_count == 1:
self._backward_block_ctx.exit(flag, True)
config['load_stream'].record_event(config['load_event'])
self._ref_count -= 1

def pre_hook(self, *args):
if self._mode != "PIPE":
self.block_context.link_module(self)
grad_tensors = []
grad_index = []
arg_list = list(args)
@@ -255,9 +238,11 @@ def pre_hook(self, *args):
arg_list[grad_index[i]] = pre_out[i]

if self._mode != "PIPE" and len(grad_tensors) == 0:
self.all_param_no_grad = True
for param in self._param_info:
if param['parameter'].requires_grad:
param['parameter'].register_hook(lambda grad: hook_func.zero_post_backward(self, grad, None))
self.all_param_no_grad = False
break
self.all_input_no_grad = True
else:
@@ -537,16 +522,23 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False)
super().__init__()

self._modules = {}
release_list = []
pre_module = None
for i, module in enumerate(modules):
if not isinstance(module, CheckpointBlock):
module = CheckpointBlock(module)

module._mode = "ZERO"
module._is_last_layer = True if i == len(modules) -1 else False
module._is_first_layer = True if i == 0 else False
module.set_pre_module(pre_module)
pre_module = module
self._is_first_layer = False
self._is_last_layer = False

self._modules[str(i)] = module
self.add_module(str(i), module)

self._modules[str(0)]._is_first_layer = True
self._modules[str(len(modules)-1)]._is_last_layer = True

self.num_hidden = num_hidden

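A condensed, runnable sketch of the linking and reference-counting scheme introduced above may make the release logic easier to follow: each block records its predecessor and successor (set_pre_module), the counter grows once per forward pass through the block, and backward_release frees the gathered parameters only on the last remaining reference. The toy Block class, on_forward and the prints below are illustrative stand-ins, not BMTrain code:

```python
class Block:
    def __init__(self, name):
        self.name = name
        self._pre_module = []    # blocks that run immediately before this one
        self._next_module = []   # blocks that run immediately after this one
        self._ref_count = 0      # grows in forward, shrinks in backward

    def set_pre_module(self, pre_module):
        # Mirrors CheckpointBlock.set_pre_module: link both directions.
        if pre_module is not None:
            self._pre_module.append(pre_module)
            pre_module._next_module.append(self)

    def on_forward(self):
        # Stand-in for the forward hook that bumps the counter.
        self._ref_count += 1

    def next_module(self):
        return self._next_module[self._ref_count - 1]

    def backward_release(self):
        # Mirrors CheckpointBlock.backward_release: free only on the last reference.
        if self._ref_count == 1:
            print(f"{self.name}: release gathered parameters")
        self._ref_count -= 1

# Chain the blocks the way TransformerBlockList.__init__ does.
blocks = [Block(f"block{i}") for i in range(3)]
for prev, cur in zip([None] + blocks[:-1], blocks):
    cur.set_pre_module(prev)

for b in blocks:                  # forward pass
    b.on_forward()
for b in reversed(blocks):        # backward pass runs in reverse order
    if b._next_module:            # release the successor whose backward just finished
        b.next_module().backward_release()
blocks[0].backward_release()      # the first block releases itself at the very end
```

Reusing a block several times in one forward pass pushes _ref_count above one, which appears to be why the real next_module() indexes _next_module with _ref_count - 1 and why the block context is only exited on the final decrement.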
21 changes: 6 additions & 15 deletions bmtrain/hook_func.py
@@ -25,20 +25,16 @@ def zero_post_forward(module, inputs, outputs):

if exit:
module._forward_block_ctx.exit(forward_flag)
if module._mode != "PIPE":
module._ref_count += 1

def zero_pre_backward(module, grad_outputs):
backward_flag = 2 if config['zero_level'] == 2 else 0
if module._mode != "PIPE":
module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag, True)
if not module._is_last_layer and len(module._next_module) > 0 and module._next_module[-1]._backward_block_ctx is not None:
if module._next_module[-1]._ref_count == 1:
module._next_module[-1]._ref_count = 0
module._next_module.pop()._backward_block_ctx.exit(backward_flag, True)
config['load_stream'].record_event(config['load_event'])
else:
module._next_module[-1]._ref_count -= 1

if not module._is_last_layer:
module.next_module().backward_release(backward_flag)
else:
if module._micro_idx == config['micros'] - 1:
module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True)
@@ -47,15 +43,10 @@ def zero_post_backward(module, grad_inputs, grad_outputs):
def zero_post_backward(module, grad_inputs, grad_outputs):
backward_flag = 2 if config['zero_level'] == 2 else 0
if module._mode != "PIPE":
if module._is_first_layer and module._ref_count == 1:
module._backward_block_ctx.exit(backward_flag, True)
module._ref_count = -1
config['load_stream'].record_event(config['load_event'])
if not module._is_first_layer and len(module._pre_module) > 0:
module._pre_module.pop()
if module._is_first_layer:
module.backward_release(backward_flag)
else:
if module._micro_idx == 0:
module._ref_count = -1 if module._is_first_layer else 0
module._backward_block_ctx.exit(backward_flag, True)
config['load_stream'].record_event(config['load_event'])

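One detail of the PIPE branch above is worth spelling out. The backward hooks fire once per micro-batch, and the two checks imply that backward visits micro-batches from the last one down to the first: the _micro_idx == config['micros'] - 1 test enters the backward block context exactly once, at the start of the backward phase, and the _micro_idx == 0 test exits it exactly once, at the end. A plain-Python dry run of that control flow (micros and the prints are placeholders, no BMTrain involved):

```python
micros = 4                                   # stand-in for config['micros']
for micro_idx in reversed(range(micros)):    # backward sees micro-batches last-to-first
    if micro_idx == micros - 1:
        print("enter backward block context (gather ZeRO-partitioned params)")
    print(f"run backward for micro-batch {micro_idx}")
    if micro_idx == 0:
        print("exit backward block context (release params)")
```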
4 changes: 0 additions & 4 deletions bmtrain/init.py
@@ -7,7 +7,6 @@
from .global_var import config
from . import nccl
from .synchronize import synchronize
from .block_layer import BMTBlockContext

def init_distributed(
init_method : str = "env://",
@@ -74,9 +73,6 @@ def init_distributed(
config["zero_level"] = zero_level
config["topology"] = topology(config)
config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank']
config["block_context"] = []
for i in range(world_size):
config["block_context"].append(BMTBlockContext())
cpus_this_worker = None

all_available_cpus = sorted(list(os.sched_getaffinity(0)))
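This hunk drops the per-rank BMTBlockContext list from init_distributed, and commit 0b14fe5 in this PR already removed use_checkpoint from it, so after the refactor the two knobs live in different places: the ZeRO level is still a global init_distributed argument, while activation checkpointing is chosen per block through CheckpointBlock(..., use_checkpoint=...). A hedged usage sketch, assuming a bmtrain install and a distributed launch (e.g. torchrun); the layer type and sizes are placeholders, and whether a plain torch.nn.Linear is an appropriate module to wrap depends on BMTrain's parameter handling:

```python
import torch
import bmtrain as bmt

bmt.init_distributed(zero_level=3)                 # ZeRO level: chosen once, globally

layers = bmt.TransformerBlockList([
    bmt.CheckpointBlock(torch.nn.Linear(1024, 1024),
                        use_checkpoint=True)       # checkpointing: chosen per block
    for _ in range(4)
])
```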