Refactor ZeRO, checkpoint and pipeline code (#128)
* Use hooks to implement ZeRO and checkpointing
---------

Co-authored-by: zhangkaihuo <[email protected]>
zkh2016 and zhangkaihuo authored Aug 21, 2023
1 parent 75aa1a8 commit 74700e4
Showing 24 changed files with 744 additions and 857 deletions.
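
The commit message above describes moving the ZeRO and checkpoint logic into module hooks (see bmtrain/hook_func.py below). As a rough illustration of the general pattern only — the gather_params/release_params helpers here are hypothetical stand-ins, not the actual BMTrain wiring — forward pre/post hooks can gather partitioned parameters right before a layer runs and release them right after:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real ZeRO all-gather / release logic.
def gather_params(module: nn.Module) -> None:
    pass  # in BMTrain this would all-gather the partitioned parameters

def release_params(module: nn.Module) -> None:
    pass  # in BMTrain this would drop the gathered copies again

def pre_forward(module, inputs):
    gather_params(module)       # fires automatically before module.forward

def post_forward(module, inputs, output):
    release_params(module)      # fires automatically after module.forward

layer = nn.Linear(8, 8)
layer.register_forward_pre_hook(pre_forward)
layer.register_forward_hook(post_forward)
out = layer(torch.randn(2, 8))  # both hooks run around this call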
4 changes: 4 additions & 0 deletions bmtrain/__init__.py
@@ -18,3 +18,7 @@

from . import loss
from . import distributed
from . import nn
from . import optim
from . import inspect
from . import lr_scheduler
596 changes: 123 additions & 473 deletions bmtrain/block_layer.py

Large diffs are not rendered by default.

156 changes: 154 additions & 2 deletions bmtrain/checkpointing.py
@@ -1,7 +1,8 @@
import torch
from typing import Callable, TypeVar
from functools import wraps
from . import debug
from . import nccl
from .global_var import config
from .synchronize import wait_loader

class ScopedDebugTensorList:
    def __init__(self) -> None:
@@ -28,3 +29,154 @@ def __exit__(self, *args):
        self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", []))
        debug.set("_inspect_hidden_states", self.prev_hidden)
        self.prev_hidden = None

class CheckpointBlockContext:
    def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = False) -> None:
        self.block = block
        self.ctx_dict = ctx_dict
        self._param_buffer = {}
        self._grad_buffer = {}
        self._param_tensor = {}
        self._grad_tensor = {}
        self._need_release = False
        if pipe:
            self.comm = config["zero_comm"]
        else:
            self.comm = config["comm"]
    def enter(self, flag=0, requires_grad=False):
        """
        Gather the block's ZeRO-partitioned parameters into full temporary buffers (all-gather).
        """
        if self.block._ready:
            return
        self.block._ready = True
        self._need_release = True

        wait_loader()
        with torch.cuda.stream(config["load_stream"]):
            for kw, val in self.block._storage_info.items():
                assert self.block._storage_params[kw].is_cuda
                assert kw not in self._grad_buffer
                assert kw not in self._param_buffer
                local_param = self.block._storage_params[kw]

                storage_type = local_param.storage_type()
                if flag != 2:
                    self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"])
                    self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw])

                if requires_grad and local_param.requires_grad:
                    self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"])
                    self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_()
            if flag != 2:
                nccl.groupStart()
                for kw, val in self.block._storage_info.items():
                    nccl.allGather(
                        self.block._storage_params[kw].storage(),
                        self._param_buffer[kw],
                        self.comm
                    )
                nccl.groupEnd()

        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(config["load_stream"])

        # set wait stream for each storage
        for kw in self.block._storage_info.keys():
            if flag != 2:
                self._param_tensor[kw].record_stream(current_stream)
            if requires_grad and kw in self._grad_tensor:
                self._grad_tensor[kw].record_stream(current_stream)

        # update parameters in block
        for param in self.block._param_info:
            kw_name = param["kw_name"]
            offset = param["offset"]
            shape = param["shape"]

            if flag != 2:
                dtype = self._param_buffer[kw_name].dtype
                device = self._param_buffer[kw_name].device
                param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape)
            else:
                dtype = param["parameter"].data.dtype
                device = param["parameter"].data.device
                param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape)

            if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad:
                param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape)

    def __enter__(self):
        self.enter()

    def exit(self, flag=0, backward=False):
        """
        Reduce-scatter gradients back to their partitions and release the gathered parameters.
        """

        if not self._need_release:
            return
        self._need_release = False
        self.block._ready = False
        if backward:
            for kw, val in self.block._storage_info.items():
                local_param = self.block._storage_params[kw]

                # accumulate previous gradient
                if local_param.requires_grad:
                    if local_param.grad is None:
                        grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist
                        local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_()
                    else:
                        self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad

            current_stream = torch.cuda.current_stream()
            config["load_stream"].wait_stream(current_stream) # wait for backward

            with torch.cuda.stream(config["load_stream"]):
                nccl.groupStart()
                for kw, val in self.block._storage_info.items():
                    local_param = self.block._storage_params[kw]

                    # scatter gradient
                    if local_param.requires_grad:
                        nccl.reduceScatter(
                            self._grad_buffer[kw],
                            local_param.grad.storage(),
                            "sum",
                            self.comm
                        )
                nccl.groupEnd()

            # set wait stream for each storage
            for kw in self._grad_tensor.keys():
                # grads can not be freed until reduce ops finish
                self._grad_tensor[kw].record_stream(config["load_stream"])


        # Release all parameters from buffer to block storage
        for param in self.block._param_info:
            kw_name = param["kw_name"]
            dtype = self.block._storage_params[kw_name].dtype
            device = self.block._storage_params[kw_name].device
            if "begin" not in param:
                param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
                param["parameter"].grad = None
                continue
            begin = param["begin"]
            end = param["end"]
            param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end)
            if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None:
                param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
        if flag == 1:
            for i in self._param_buffer:
                self.ctx_dict[i] = self._param_buffer[i]
        self._grad_tensor = {}
        self._param_tensor = {}
        self._grad_buffer = {}
        self._param_buffer = {}


    def __exit__(self, exc_type, exc_val, exc_tb):
        # reduce scatter gradients
        self.exit()
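
For orientation, a minimal sketch of how this context might be driven by hand — the real call sites are the hook functions in bmtrain/hook_func.py below, and the block and hidden arguments here are assumed, not taken from the diff:

import torch
from bmtrain.checkpointing import CheckpointBlockContext

def run_with_gathered_params(block, hidden: torch.Tensor) -> torch.Tensor:
    # `block` is assumed to be a bmtrain CheckpointBlock living on the current GPU.
    ctx = CheckpointBlockContext(block)
    ctx.enter(flag=0, requires_grad=True)   # all-gather partitioned params into buffers
    out = block._module(hidden)             # run the wrapped module on full parameters
    ctx.exit(flag=0, backward=False)        # restore partitioned views, free the buffers
    return out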
110 changes: 110 additions & 0 deletions bmtrain/hook_func.py
@@ -0,0 +1,110 @@
import torch
from .global_var import config
from .checkpointing import CheckpointBlockContext

def zero_pre_forward(module, inputs):
    enter = True
    pipe = False
    if module._mode == "PIPE":
        enter = module._micro_idx == 0
        pipe = True
    if enter:
        zero_level = config['zero_level']
        forward_flag = 1 if zero_level == 2 else 0
        if zero_level == 2 and module._ref_count > 1:
            forward_flag = 2 # repeating forward in same layer
        if module.all_param_no_grad: # only forward
            forward_flag = 0
        module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe)
        module._forward_block_ctx.enter(forward_flag)

def zero_post_forward(module, inputs, outputs):
    forward_flag = 1 if config['zero_level'] == 2 else 0
    if module.all_param_no_grad:
        forward_flag = 0
    exit = True
    if module._mode == "PIPE":
        exit = module._micro_idx == config['micros'] - 1

    if exit:
        module._forward_block_ctx.exit(forward_flag)
        module._ref_count += 1

def zero_pre_backward(module, grad_outputs):
    backward_flag = 2 if config['zero_level'] == 2 else 0
    if module._mode != "PIPE":
        module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict)
        module._backward_block_ctx.enter(backward_flag, True)
        if not module._is_last_layer:
            module.next_module().backward_release(backward_flag)
    else:
        if module._micro_idx == config['micros'] - 1:
            module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True)
            module._backward_block_ctx.enter(backward_flag, True)

def zero_post_backward(module, grad_inputs, grad_outputs):
    backward_flag = 2 if config['zero_level'] == 2 else 0
    if module._mode != "PIPE":
        if module._is_first_layer:
            module.backward_release(backward_flag)
    else:
        if module._micro_idx == 0:
            module.backward_release(backward_flag)
        module._micro_idx -= 1

class OneStepNoGradFunc(torch.autograd.Function):
    """
    requires_grad = False for all inputs
    """
    @staticmethod
    def forward(ctx, module, placeholder, *x):
        ctx.x = x
        ctx.module = module
        ctx.rng_state = torch.cuda.get_rng_state()

        with torch.no_grad():
            out = module._module(*x)
        zero_post_forward(module, None, out)
        if not isinstance(out, torch.Tensor):
            return tuple(out)
        return out

    @staticmethod
    def backward(ctx, grads):
        zero_pre_backward(ctx.module, grads)
        with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True):
            torch.cuda.set_rng_state(ctx.rng_state)
            x = ctx.x
            with torch.enable_grad():
                out = ctx.module._module(*x)
                torch.autograd.backward(out, grads)
        zero_post_backward(ctx.module, grads, None)
        grads = []
        for _ in x:
            grads.append(None)
        return None, None, *grads


class PreHookFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, *x):
        ctx.module = module
        zero_pre_forward(module, x)
        return x

    @staticmethod
    def backward(ctx, *grads):
        zero_post_backward(ctx.module, grads, None)
        return None, *grads

class PostHookFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, *out):
        ctx.module = module
        zero_post_forward(module, None, out)
        return out

    @staticmethod
    def backward(ctx, *grads):
        zero_pre_backward(ctx.module, grads)
        return None, *grads
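
The two zero_*_forward functions above follow PyTorch's standard forward-hook signatures (hook(module, inputs) and hook(module, inputs, outputs)); the actual registration lives in block_layer.py, whose diff is not rendered here. A hedged sketch of how such hooks could be attached, assuming a module that already carries the _mode, _ref_count, _layer_dict, and all_param_no_grad attributes these functions expect:

from bmtrain.hook_func import zero_pre_forward, zero_post_forward

def attach_zero_forward_hooks(module):
    # Hypothetical wiring, not the actual CheckpointBlock implementation:
    # zero_pre_forward matches the forward-pre-hook signature (module, inputs),
    # zero_post_forward matches the forward-hook signature (module, inputs, outputs).
    module.register_forward_pre_hook(zero_pre_forward)
    module.register_forward_hook(zero_post_forward)
    return module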
9 changes: 9 additions & 0 deletions bmtrain/loss/cross_entropy.py
@@ -185,6 +185,15 @@ def __init__(self,
        self.inplace = inplace

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if input.dtype == torch.float32:
            return torch.nn.functional.cross_entropy(
                input,
                target.long(),
                weight=self.weight,
                ignore_index=self.ignore_index,
                reduction=self.reduction,
                label_smoothing=self.label_smoothing)

        if self.inplace:
            ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor
        else:
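
The new branch above routes float32 logits through the stock PyTorch loss, while half-precision logits keep using the fused kernels. A small illustration of the two paths — the FusedCrossEntropy class name and the CUDA/BMTrain setup are assumed here, not shown in this diff:

import torch
import bmtrain as bmt

bmt.init_distributed()                                    # assumed startup (e.g. under torchrun)
loss_fn = bmt.loss.FusedCrossEntropy(ignore_index=-100)   # class name assumed

target = torch.randint(0, 32000, (4,), device="cuda")
logits = torch.randn(4, 32000, device="cuda", dtype=torch.half)

loss_fused = loss_fn(logits, target)          # half precision: fused kernel path
loss_plain = loss_fn(logits.float(), target)  # float32: falls back to F.cross_entropy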
1 change: 1 addition & 0 deletions bmtrain/nn/__init__.py
@@ -0,0 +1 @@
from .linear import Linear
24 changes: 22 additions & 2 deletions example/layers/linear.py → bmtrain/nn/linear.py
@@ -2,6 +2,26 @@
import torch.nn.functional as F
import bmtrain as bmt

class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None):
        ctx.save_for_backward(x, weight, bias)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        grad_x = grad_weight = grad_bias = None
        if x.requires_grad:
            grad_x = grad_output.matmul(weight)
        if weight.requires_grad:
            dim = grad_output.dim()
            grad_weight = grad_output.reshape(-1,
                grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
        if bias is not None and bias.requires_grad:
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
        return grad_x, grad_weight, grad_bias

class Linear(bmt.DistributedModule):
    def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
        super().__init__()
@@ -15,9 +35,9 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
        self.register_parameter('bias', None)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)
        return CustomLinear.apply(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
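
With from . import nn added to bmtrain/__init__.py above, the moved layer should be reachable as bmt.nn.Linear. A rough usage sketch, assuming a BMTrain process that has already been launched in a distributed-capable way (e.g. with torchrun) so that bmt.init_distributed() succeeds:

import torch
import bmtrain as bmt

bmt.init_distributed()                           # assumed startup (e.g. under torchrun)

layer = bmt.nn.Linear(16, 4, dtype=torch.half)
x = torch.randn(2, 16, device="cuda", dtype=torch.half, requires_grad=True)

y = layer(x)                 # forward now goes through CustomLinear.apply
y.float().sum().backward()   # backward uses the hand-written matmul gradients
print(y.shape, x.grad.shape)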