diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 26c3760d..ae243e65 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -18,3 +18,7 @@ from . import loss from . import distributed +from . import nn +from . import optim +from . import inspect +from . import lr_scheduler diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8647b7ff..8e52c68a 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -4,279 +4,18 @@ from .global_var import config import torch from . import nccl -from .synchronize import wait_loader from .parameter import DistributedParameter, OpAllGather -from .checkpointing import ScopedTensorInspectorContext -from . import debug -import copy -import inspect +from .checkpointing import ( + CheckpointBlockContext +) +from . import debug -# the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter -class OpCheckpointBlock(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len_args, *args): - ctx.block = block - ctx.preserve_rng_state = preserve_rng_state - - ctx.cuda_rng_state = torch.cuda.get_rng_state() if preserve_rng_state else None - tensors = [] - others = [] - for arg in args: - if torch.is_tensor(arg): - tensors.append(arg) - others.append(None) - else: - tensors.append(None) - others.append(arg) - - ctx.nontensor_inputs = others - ctx.len_args = len_args - ctx.save_for_backward(*tensors) - ctx.param_dict={} - if config['zero_level'] == 2: - flag = 1 - else: - flag = 0 - with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block, ctx.param_dict, flag): - inp_args = args[:len_args] - inp_kwargs = {} - for k, v in zip(args[len_args::2], args[len_args + 1::2]): - inp_kwargs[k] = v - outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs) - for it in inspector.hidden_states: - debug.append("_inspect_hidden_states", it) - ctx.inspect_list = inspector.hidden_states - - if not isinstance(outputs, list) and not isinstance(outputs, tuple): - outputs = [outputs] - len_outputs = 0 - else: - outputs = list(outputs) - len_outputs = len(outputs) - return tuple([len_outputs] + outputs + [hidden_state["tensor"] for hidden_state in inspector.hidden_states]) - - @staticmethod - def backward(ctx, _, *grads): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). 
Please use .backward() and do not pass its `inputs`" - " argument.") - - all_inputs = [] - input_reqires_grad = [] - len_args = ctx.len_args - for tensor, other in zip(ctx.saved_tensors, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_reqires_grad.append(False) - else: - input_reqires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - - - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=ctx.preserve_rng_state): - if ctx.preserve_rng_state: - torch.cuda.set_rng_state(ctx.cuda_rng_state) - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - with torch.enable_grad(), CheckpointBlockContext(ctx.block, ctx.param_dict, flag): - inp_args = all_inputs[:len_args] - inp_kwargs = {} - for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]): - inp_kwargs[k] = v - with ScopedTensorInspectorContext() as inspector: - outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs) - if not isinstance(outputs, tuple): - outputs = (outputs,) - - assert len(outputs) + len(inspector.hidden_states) == len(grads) - - outputs_with_grad = [] - grad_of_output = [] - for i, output in enumerate(outputs): - if torch.is_tensor(output) and output.requires_grad: - outputs_with_grad.append(output) - grad_of_output.append(grads[i]) - - # calculate gradients for inputs, also for parameters - torch.autograd.backward( - outputs_with_grad + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - grad_of_output + list(grads[len(outputs):]), - ) - assert len(ctx.inspect_list) == len(inspector.hidden_states), "Backward step changed" - for i, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.inspect_list[i]["name"], "Backward step changed" - assert it["shape"] == ctx.inspect_list[i]["shape"], "Backward step changed" - assert it["group"] == ctx.inspect_list[i]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.inspect_list[i]["tensor"] = it["tensor"] - ctx.inspect_list[i]["requires_grad"] = it["requires_grad"] - - grads = [] - for inp, requires_grad in zip(all_inputs, input_reqires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None) + tuple(grads) - -class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int = 0, pipe = False) -> None: - self.block = block - self.ctx_dict = ctx_dict - self._param_buffer = {} - self._grad_buffer = {} - self._param_tensor = {} - self._grad_tensor = {} - self.flag = flag - self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] - def enter(self): - """ - gather parameters - """ - if self.block._ready: - return - self.block._ready = True - self._need_release = True - - wait_loader() - requires_grad = torch.is_grad_enabled() - with torch.cuda.stream(config["load_stream"]): - for kw, val in self.block._storage_info.items(): - assert self.block._storage_params[kw].is_cuda - assert kw not in self._grad_buffer - assert kw not in self._param_buffer - local_param = self.block._storage_params[kw] - - storage_type = local_param.storage_type() - if self.flag != 2: - self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) - - if 
requires_grad and local_param.requires_grad: - self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() - if self.flag != 2: - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - nccl.allGather( - self.block._storage_params[kw].storage(), - self._param_buffer[kw], - self.comm - ) - nccl.groupEnd() - - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["load_stream"]) - - # set wait stream for each storage - for kw in self.block._storage_info.keys(): - if self.flag != 2: - self._param_tensor[kw].record_stream(current_stream) - if requires_grad and kw in self._grad_tensor: - self._grad_tensor[kw].record_stream(current_stream) - - # update parameters in block - for param in self.block._param_info: - kw_name = param["kw_name"] - offset = param["offset"] - shape = param["shape"] - - if self.flag != 2: - dtype = self._param_buffer[kw_name].dtype - device = self._param_buffer[kw_name].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) - else: - dtype = param["parameter"].data.dtype - device = param["parameter"].data.device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) - - if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) - - def __enter__(self): - self.enter() - - def exit(self): - """ - Reduce scatter gradients - """ +from . import hook_func - if not self._need_release: - return - self._need_release = False - self.block._ready = False - requires_grad = torch.is_grad_enabled() - if requires_grad: - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # accumulate previous gradient - if local_param.requires_grad: - if local_param.grad is None: - grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist - local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() - else: - self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad - - current_stream = torch.cuda.current_stream() - config["load_stream"].wait_stream(current_stream) # wait for backward - - with torch.cuda.stream(config["load_stream"]): - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # scatter gradient - if local_param.requires_grad: - nccl.reduceScatter( - self._grad_buffer[kw], - local_param.grad.storage(), - "sum", - self.comm - ) - nccl.groupEnd() - - # set wait stream for each storage - for kw in self._grad_tensor.keys(): - # grads can not be freed until reduce ops finish - self._grad_tensor[kw].record_stream(config["load_stream"]) - - # Release all parameters from buffer to block_storge - for param in self.block._param_info: - kw_name = param["kw_name"] - dtype = self.block._storage_params[kw_name].dtype - device = self.block._storage_params[kw_name].device - if "begin" not in param: - param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None - continue - begin = param["begin"] - end = param["end"] - param["parameter"].data = torch.tensor([], 
dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) - if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if self.flag == 1: - for i in self._param_buffer: - self.ctx_dict[i] = self._param_buffer[i] - self._grad_tensor = {} - self._param_tensor = {} - self._grad_buffer = {} - self._param_buffer = {} - def __exit__(self, exc_type, exc_val, exc_tb): - # reduce scatter gradients - self.exit() +import copy +import inspect +from torch.utils.checkpoint import checkpoint def storage_type_cuda(storage_type): STORAGE_MAP = { @@ -310,7 +49,7 @@ def _get_param_kw(param : DistributedParameter): return type_name + grad_name + group_name class CheckpointBlock(torch.nn.Module): - """ Checkpoint a model or part of the model. + """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. Checkpoint block is used to save the occupation of GPU memory in training. @@ -318,6 +57,7 @@ class CheckpointBlock(torch.nn.Module): Args: model (torch.nn.Module): The model to be checkpointed. All kinds of modules are supported. + use_checkpoint (boolean): use checkpoint or not. Default True. Examples: >>> transformer_block = TransformerBlock(...) @@ -326,9 +66,13 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): super().__init__() self._module = inner_module + self._inputs = None + self._layer_dict = {} + self._forward_block_ctx = None + self._backward_block_ctx = None # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} @@ -440,23 +184,88 @@ def __init__(self, inner_module : torch.nn.Module): del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) - # clear parameter data, but keep the dtype and device setattr(param, "_in_checkpoint_block", True) for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - - def __call__(self, *args, **kwargs): - # gather here - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - all_inputs = list(args) - for kw, val in kwargs.items(): - all_inputs.append(kw) - all_inputs.append(val) - outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) - len_output = outputs[0] - return outputs[1:1+len_output] if len_output > 0 else outputs[1] + + self.use_checkpoint = use_checkpoint + self._is_first_layer = True + self._is_last_layer = True + self._release_list = [True] + self._next_module = [] #save the next module of self + self._pre_module = [] #save the pre module of self + self._ref_count = 0 #incremental in forward and decreasing in backward + self._mode = "BLOCK" #BLOCK or ZERO or PIPE + self.all_input_no_grad = False + self.all_param_no_grad = False + + def set_pre_module(self, pre_module): + if pre_module is not None: + self._pre_module.append(pre_module) + pre_module._next_module.append(self) + + def pre_module(self): + assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count) + return self._pre_module[self._ref_count-1] + + def next_module(self): + assert len(self._next_module) == self._ref_count, "{} != 
{}".format(len(self._next_module), self._ref_count) + return self._next_module[self._ref_count-1] + + def backward_release(self, flag): + if self._ref_count == 1: + self._backward_block_ctx.exit(flag, True) + config['load_stream'].record_event(config['load_event']) + self._ref_count -= 1 + + def pre_hook(self, *args): + grad_tensors = [] + grad_index = [] + arg_list = list(args) + for i, arg in enumerate(args): + if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad: + grad_tensors.append(arg) + grad_index.append(i) + grad_tensors = tuple(grad_tensors) + + pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) + for i in range(len(grad_index)): + arg_list[grad_index[i]] = pre_out[i] + + if self._mode != "PIPE" and len(grad_tensors) == 0: + self.all_param_no_grad = True + for param in self._param_info: + if param['parameter'].requires_grad: + self.all_param_no_grad = False + break + self.all_input_no_grad = True + else: + self.all_input_no_grad = False + return arg_list + + def post_hook(self, out): + tuple_out = (out, ) if isinstance(out, torch.Tensor) else out + post_out = hook_func.PostHookFunc.apply(self, *tuple_out) + if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): + return post_out[0] + post_out = tuple(post_out) + return post_out + + def forward(self, *args): + arg_list = self.pre_hook(*args) + + if self.all_input_no_grad and not self.all_param_no_grad: + placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) + return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) + + if self.use_checkpoint: + out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad) + else: + out = self._module(*arg_list) + + return self.post_hook(out) def __getattr__(self,name:str): if name=="_module": @@ -619,6 +428,7 @@ def init_parameters(self): param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] del tmp_tensor + def _named_members(self, get_members_fn, prefix='', recurse=True, **kwargs): r"""Helper method for yielding various names + members of modules.""" @@ -685,192 +495,6 @@ def eval(self): def __repr__(self): return self._module.__repr__() -class OpTransformerBlockList(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args): - tensors = [] - others = [] - for arg in args[num_hidden:]: - if torch.is_tensor(arg): - tensors.append(arg) - others.append(None) - else: - tensors.append(None) - others.append(arg) - hidden_states = args[:num_hidden] - - ctx.nontensor_inputs = others - ctx.self = self - ctx.save_list = copy.deepcopy(save_list) - ctx.num_save_needed = save_list[-1][1]+1 - ctx.layers_dict = [{} for _ in range(len(self))] - layer_inputs = [] - layer_inspector = [] - cuda_rng_state = [] - for i in range(len(self)): - with torch.no_grad(): - if save_list[i][0] == i: - layer_inputs += [hidden_state.detach() for hidden_state in hidden_states] - cuda_rng_state.append( torch.cuda.get_rng_state() ) - if config['zero_level']==2: - flag = 1 - else: - flag = 0 - block_ctx = CheckpointBlockContext(self._modules[str(i)], ctx.layers_dict[i], flag) - # gather parameter on load stream - block_ctx.enter() - # call inner module directly - with ScopedTensorInspectorContext() as inspector: - hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:]) - if not isinstance(hidden_states, tuple): - hidden_states = (hidden_states,) - 
block_ctx.exit() - for it in inspector.hidden_states: - debug.append("_inspect_hidden_states", it) - layer_inspector.append(inspector.hidden_states) - - ctx.layer_inspector = layer_inspector - ctx.cuda_rng_state = cuda_rng_state - ctx.num_hidden = num_hidden - - ctx.save_for_backward(*layer_inputs, *tensors) - - if self.return_hidden_states: - middle_hiddens = layer_inputs - for mid in middle_hiddens: - mid.requires_grad_() - middle_hiddens = [ - torch.stack(middle_hiddens[i::num_hidden], dim=0) - for i in range(num_hidden) - ] - else: - middle_hiddens = [None] * num_hidden - return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens]) - - - @staticmethod - def backward(ctx, *grads): - grad_hidden_states = grads[:ctx.num_hidden] - grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden] - grad_inspectors = grads[2*ctx.num_hidden:] - def exit_prev(prev_ctx, prev_grad): - if prev_ctx is not None: - if prev_grad: - with torch.enable_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - else: - with torch.no_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument.") - all_inputs = [] - input_requires_grad = [] - - layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden] - save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:] - for tensor, other in zip(save_args, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_requires_grad.append(False) - else: - # detach for tensor inputs - input_requires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - # overlap load and scatter here - prev_ctx = None - prev_grad = False - for i in reversed(range(len(ctx.self))): - if ctx.save_list[i][0] != i: - with torch.no_grad(): - st = ctx.save_list[i][0] - for j in range(st, i): - torch.cuda.set_rng_state(ctx.cuda_rng_state[j]) - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag) - block_ctx.enter() - exit_prev(prev_ctx, prev_grad) - outputs = ctx.self._modules[str(j)]._module._call_impl( - layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden], - *all_inputs - ) - if not isinstance(outputs, tuple): - outputs = (outputs,) - prev_ctx = block_ctx - prev_grad = False - for k, output in enumerate(outputs): - layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output) - ctx.save_list[j+1][0] = j+1 - - torch.cuda.set_rng_state(ctx.cuda_rng_state[i]) - ipts = [ - layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_() - for k in range(ctx.num_hidden) - ] - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)], ctx.layers_dict[i], flag) - block_ctx.enter() - exit_prev(prev_ctx, prev_grad) - prev_ctx = block_ctx - prev_grad = True - - with ScopedTensorInspectorContext() as inspector: - outputs = 
ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs) - if not isinstance(outputs, tuple): - outputs = (outputs,) - - assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed" - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.layer_inspector[i][j]["name"], "Backward step changed" - assert it["shape"] == ctx.layer_inspector[i][j]["shape"], "Backward step changed" - assert it["group"] == ctx.layer_inspector[i][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.layer_inspector[i][j]["tensor"] = it["tensor"] - ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"] - if len(inspector.hidden_states) > 0: - torch.autograd.backward( - list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):], - ) - grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)] - else: - torch.autograd.backward( - outputs, - grad_hidden_states, - ) - grad_hidden_states = [ipt.grad for ipt in ipts] - for k in range(ctx.num_hidden): - if grad_middles[k] is not None: - grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i] - grad_hidden_states = tuple(grad_hidden_states) - - exit_prev(prev_ctx, prev_grad) - - grads = [] - for inp, requires_grad in zip(all_inputs, input_requires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads) - class TransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -896,12 +520,23 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} + pre_module = None for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) + + module._mode = "ZERO" + module.set_pre_module(pre_module) + pre_module = module + self._is_first_layer = False + self._is_last_layer = False + self._modules[str(i)] = module self.add_module(str(i), module) + self._modules[str(0)]._is_first_layer = True + self._modules[str(len(modules)-1)]._is_last_layer = True + self.num_hidden = num_hidden if sqrt: @@ -928,6 +563,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) def __len__(self) -> int: return len(self._modules) + def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: @@ -935,9 +571,23 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args) + hidden_states = [] + for i in range(len(self)): + if return_hidden_states: + for hidden_state in args[:self.num_hidden]: + hidden_states.append(hidden_state) + outputs = self._modules[str(i)]._call_impl(*args) + if not isinstance(outputs, tuple): + outputs = (outputs, ) + args = outputs + args[self.num_hidden:] + + if return_hidden_states: + hidden_states = [ + torch.stack(hidden_states[i::self.num_hidden], dim=0) + for i in range(self.num_hidden) + ] + if return_hidden_states: - return tuple(outputs[:2*self.num_hidden]) + return outputs + tuple(hidden_states) else: - return 
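# Illustrative sketch: TransformerBlockList.forward above collects every layer's inputs into one
# flat list and regroups them per hidden stream with strided slicing before stacking. The sizes
# below (3 layers, num_hidden=2, batch 4, dim 8) are assumptions.
import torch

num_layers, num_hidden = 3, 2
# flat list ordered layer by layer: [h0_l0, h1_l0, h0_l1, h1_l1, h0_l2, h1_l2]
flat = [torch.randn(4, 8) for _ in range(num_layers * num_hidden)]

stacked = [torch.stack(flat[i::num_hidden], dim=0) for i in range(num_hidden)]
print([t.shape for t in stacked])  # [torch.Size([3, 4, 8]), torch.Size([3, 4, 8])]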
tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] \ No newline at end of file + return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index ac6a8d4f..b2c9ec07 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -1,7 +1,8 @@ import torch -from typing import Callable, TypeVar -from functools import wraps from . import debug +from . import nccl +from .global_var import config +from .synchronize import wait_loader class ScopedDebugTensorList: def __init__(self) -> None: @@ -28,3 +29,154 @@ def __exit__(self, *args): self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", [])) debug.set("_inspect_hidden_states", self.prev_hidden) self.prev_hidden = None + +class CheckpointBlockContext: + def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = False) -> None: + self.block = block + self.ctx_dict = ctx_dict + self._param_buffer = {} + self._grad_buffer = {} + self._param_tensor = {} + self._grad_tensor = {} + self._need_release = False + if pipe: + self.comm = config["zero_comm"] + else: + self.comm = config["comm"] + def enter(self, flag=0, requires_grad=False): + """ + gather parameters + """ + if self.block._ready: + return + self.block._ready = True + self._need_release = True + + wait_loader() + with torch.cuda.stream(config["load_stream"]): + for kw, val in self.block._storage_info.items(): + assert self.block._storage_params[kw].is_cuda + assert kw not in self._grad_buffer + assert kw not in self._param_buffer + local_param = self.block._storage_params[kw] + + storage_type = local_param.storage_type() + if flag != 2: + self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) + + if requires_grad and local_param.requires_grad: + self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() + if flag != 2: + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + nccl.allGather( + self.block._storage_params[kw].storage(), + self._param_buffer[kw], + self.comm + ) + nccl.groupEnd() + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["load_stream"]) + + # set wait stream for each storage + for kw in self.block._storage_info.keys(): + if flag != 2: + self._param_tensor[kw].record_stream(current_stream) + if requires_grad and kw in self._grad_tensor: + self._grad_tensor[kw].record_stream(current_stream) + + # update parameters in block + for param in self.block._param_info: + kw_name = param["kw_name"] + offset = param["offset"] + shape = param["shape"] + + if flag != 2: + dtype = self._param_buffer[kw_name].dtype + device = self._param_buffer[kw_name].device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + else: + dtype = param["parameter"].data.dtype + device = param["parameter"].data.device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) + + if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: + param["parameter"].grad = torch.tensor([], dtype=dtype, 
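# Illustrative sketch: CheckpointBlockContext.enter above re-points each parameter's .data into a
# flat gathered buffer via Tensor.set_(storage, offset, shape), so parameters become views that
# share memory with that buffer. A plain CPU storage stands in for the NCCL all-gather result;
# the offsets and shapes are assumptions.
import torch

flat = torch.zeros(12)          # stands in for the gathered parameter buffer
storage = flat.storage()

w = torch.tensor([]).set_(storage, 0, (2, 3))  # "parameter" at offset 0, shape (2, 3)
b = torch.tensor([]).set_(storage, 6, (6,))    # "parameter" at offset 6, shape (6,)

flat[:] = torch.arange(12, dtype=torch.float32)
print(w)  # the views observe the update because they share storage with `flat`
print(b)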
device=device).set_(self._grad_buffer[kw_name], offset, shape) + + def __enter__(self): + self.enter() + + def exit(self, flag=0, backward=False): + """ + Reduce scatter gradients + """ + + if not self._need_release: + return + self._need_release = False + self.block._ready = False + if backward: + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # accumulate previous gradient + if local_param.requires_grad: + if local_param.grad is None: + grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist + local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() + else: + self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad + + current_stream = torch.cuda.current_stream() + config["load_stream"].wait_stream(current_stream) # wait for backward + + with torch.cuda.stream(config["load_stream"]): + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # scatter gradient + if local_param.requires_grad: + nccl.reduceScatter( + self._grad_buffer[kw], + local_param.grad.storage(), + "sum", + self.comm + ) + nccl.groupEnd() + + # set wait stream for each storage + for kw in self._grad_tensor.keys(): + # grads can not be freed until reduce ops finish + self._grad_tensor[kw].record_stream(config["load_stream"]) + + + # Release all parameters from buffer to block_storge + for param in self.block._param_info: + kw_name = param["kw_name"] + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device + if "begin" not in param: + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) + param["parameter"].grad = None + continue + begin = param["begin"] + end = param["end"] + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) + if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if flag == 1: + for i in self._param_buffer: + self.ctx_dict[i] = self._param_buffer[i] + self._grad_tensor = {} + self._param_tensor = {} + self._grad_buffer = {} + self._param_buffer = {} + + + def __exit__(self, exc_type, exc_val, exc_tb): + # reduce scatter gradients + self.exit() diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py new file mode 100644 index 00000000..6a56300e --- /dev/null +++ b/bmtrain/hook_func.py @@ -0,0 +1,110 @@ +import torch +from .global_var import config +from .checkpointing import CheckpointBlockContext + +def zero_pre_forward(module, inputs): + enter = True + pipe = False + if module._mode == "PIPE": + enter = module._micro_idx == 0 + pipe = True + if enter: + zero_level = config['zero_level'] + forward_flag = 1 if zero_level == 2 else 0 + if zero_level == 2 and module._ref_count > 1: + forward_flag = 2 # repeating forward in same layer + if module.all_param_no_grad: #only forward + forward_flag = 0 + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe) + module._forward_block_ctx.enter(forward_flag) + +def zero_post_forward(module, inputs, outputs): + forward_flag = 1 if config['zero_level'] == 2 else 0 + if module.all_param_no_grad: + forward_flag = 0 + exit = True + if module._mode == "PIPE": + exit = module._micro_idx 
== config['micros'] - 1 + + if exit: + module._forward_block_ctx.exit(forward_flag) + module._ref_count += 1 + +def zero_pre_backward(module, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 + if module._mode != "PIPE": + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) + module._backward_block_ctx.enter(backward_flag, True) + if not module._is_last_layer: + module.next_module().backward_release(backward_flag) + else: + if module._micro_idx == config['micros'] - 1: + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True) + module._backward_block_ctx.enter(backward_flag, True) + +def zero_post_backward(module, grad_inputs, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 + if module._mode != "PIPE": + if module._is_first_layer: + module.backward_release(backward_flag) + else: + if module._micro_idx == 0: + module.backward_release(backward_flag) + module._micro_idx -= 1 + +class OneStepNoGradFunc(torch.autograd.Function): + """ + requires_grad = False for all inputs + """ + @staticmethod + def forward(ctx, module, placeholder, *x): + ctx.x = x + ctx.module = module + ctx.rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + out = module._module(*x) + zero_post_forward(module, None, out) + if not isinstance(out, torch.Tensor): + return tuple(out) + return out + + @staticmethod + def backward(ctx, grads): + zero_pre_backward(ctx.module, grads) + with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + torch.cuda.set_rng_state(ctx.rng_state) + x = ctx.x + with torch.enable_grad(): + out = ctx.module._module(*x) + torch.autograd.backward(out, grads) + zero_post_backward(ctx.module, grads, None) + grads = [] + for _ in x: + grads.append(None) + return None, None, *grads + + +class PreHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *x): + ctx.module = module + zero_pre_forward(module, x) + return x + + @staticmethod + def backward(ctx, *grads): + zero_post_backward(ctx.module, grads, None) + return None, *grads + +class PostHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *out): + ctx.module = module + zero_post_forward(module, None, out) + return out + + @staticmethod + def backward(ctx, *grads): + zero_pre_backward(ctx.module, grads) + return None, *grads diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 160ef421..31223640 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -185,6 +185,15 @@ def __init__(self, self.inplace = inplace def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + if self.inplace: ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor else: diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py new file mode 100644 index 00000000..67f9fdee --- /dev/null +++ b/bmtrain/nn/__init__.py @@ -0,0 +1 @@ +from .linear import Linear diff --git a/example/layers/linear.py b/bmtrain/nn/linear.py similarity index 50% rename from example/layers/linear.py rename to bmtrain/nn/linear.py index 0aa0ab00..faf0770e 100644 --- a/example/layers/linear.py +++ b/bmtrain/nn/linear.py @@ -2,6 +2,26 @@ import torch.nn.functional as F 
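# Illustrative sketch of the PreHookFunc/PostHookFunc pattern above: an identity
# torch.autograd.Function lets code run at a fixed point of both the forward pass and the
# matching backward pass, which is how the ZeRO enter/exit hooks are triggered. The print
# callbacks and the "block-0" tag are assumptions for demonstration.
import torch

class ForwardBackwardProbe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tag, *xs):
        ctx.tag = tag
        print("forward reached:", tag)       # e.g. gather parameters here
        return xs if len(xs) > 1 else xs[0]

    @staticmethod
    def backward(ctx, *grads):
        print("backward reached:", ctx.tag)  # e.g. reduce-scatter gradients here
        return (None,) + grads               # identity: gradients pass through unchanged

x = torch.randn(3, requires_grad=True)
y = ForwardBackwardProbe.apply("block-0", x) * 2
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])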
import bmtrain as bmt +class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias=None): + ctx.save_for_backward(x, weight, bias) + return F.linear(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + grad_x = grad_weight = grad_bias = None + if x.requires_grad: + grad_x = grad_output.matmul(weight) + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_x, grad_weight, grad_bias + class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: super().__init__() @@ -15,9 +35,9 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return F.linear(input, self.weight, self.bias) + return CustomLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None - ) \ No newline at end of file + ) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 69c299bc..0a34ac46 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -5,310 +5,161 @@ from typing import Dict, Iterable, Iterator, Tuple, Union, List import torch -from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from .global_var import config from . import nccl -from .checkpointing import ScopedTensorInspectorContext +from .checkpointing import ( + CheckpointBlockContext +) from . 
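# Illustrative sketch: the hand-written backward of CustomLinear above can be compared against
# finite-difference gradients with torch.autograd.gradcheck. Assumes the patched bmtrain package
# is importable; the double-precision shapes are arbitrary.
import torch
from bmtrain.nn.linear import CustomLinear

x = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)
w = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
b = torch.randn(3, dtype=torch.float64, requires_grad=True)

assert torch.autograd.gradcheck(CustomLinear.apply, (x, w, b))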
import debug -from .block_layer import CheckpointBlockContext, CheckpointBlock, round_up, _get_param_kw +from .block_layer import CheckpointBlock, round_up, _get_param_kw -class OpMicroForward(torch.autograd.Function): +class PipePreFunction(torch.autograd.Function): @staticmethod - def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', micro_idx, block_ctx_list, layers_dict, save_list, hidden_state, *args): - with PipeContext(self, hidden_state) as pipe_input: - hidden_state = pipe_input[0].detach() - tensors = [arg if torch.is_tensor(arg) else None for arg in args] - others = [arg if not torch.is_tensor(arg) else None for arg in args] - ctx.nontensor_inputs = others - ctx.self = self - ctx.micro_idx = micro_idx - ctx.block_ctx_list = block_ctx_list - ctx.layers_dict = layers_dict - ctx.save_list = copy.deepcopy(save_list) - ctx.num_save_needed = save_list[-1][1]+1 - layer_inputs = [] - layer_inspector = [] - cuda_rng_state = [] - for idx,layer_id in enumerate(self.layer_ids): - with torch.no_grad(): - if save_list[idx][0] == idx: - layer_inputs.append(hidden_state.detach()) - cuda_rng_state.append( torch.cuda.get_rng_state() ) - # gather parameter on load stream - if ctx.micro_idx == 0: - block_ctx_list[idx] = CheckpointBlockContext(self._modules[str(layer_id)], ctx.layers_dict[idx], 1, pipe=True) - block_ctx_list[idx].enter() - # call inner module directly - with ScopedTensorInspectorContext() as inspector: - hidden_state = self._modules[str(layer_id)]._module._call_impl(hidden_state, *args) - if ctx.micro_idx == config["micros"]-1: - block_ctx_list[idx].exit() - for ith, it in enumerate(inspector.hidden_states): - it["inside_pipe"] = { - "stage_id": self.stage_id, - "stages": self.stages, - "st": (layer_id==self.layer_ids[0] and ith==0), - "ed": (layer_id==self.layer_ids[-1] and ith==len(inspector.hidden_states)-1), - } - debug.append("_inspect_hidden_states", it) - layer_inspector.append(inspector.hidden_states) - - ctx.layer_inspector = layer_inspector - ctx.cuda_rng_state = cuda_rng_state - - ctx.save_for_backward(*layer_inputs, *tensors) - pipe_input[0] = hidden_state - if self.return_hidden_states: - middle_hiddens = layer_inputs - for mid in middle_hiddens: - mid.requires_grad_() - middle_hiddens = torch.stack(middle_hiddens, dim=0) - else: - middle_hiddens = None - return tuple([pipe_input[0], middle_hiddens] + [hidden_state["tensor"] for hidden_states in ctx.layer_inspector for hidden_state in hidden_states]) + def forward(ctx, hidden_state, *args): + hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) + hidden_state_list.requires_grad_() - @staticmethod - def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, *grad_inspector): - def exit_prev(prev_ctx, prev_grad): - if prev_ctx is not None: - if prev_grad: - with torch.enable_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - else: - with torch.no_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). 
Please use .backward() and do not pass its `inputs`" - " argument.") - all_inputs = [] - input_requires_grad = [] - - layer_inputs = ctx.saved_tensors[:ctx.num_save_needed] - save_args = ctx.saved_tensors[ctx.num_save_needed:] - for tensor, other in zip(save_args, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_requires_grad.append(False) - else: - # detach for tensor inputs - input_requires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - with PipeContext(ctx.self, grad_hidden_state, backward=True) as pipe_input: - grad_hidden_state = pipe_input[0] - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - # overlap load and scatter here - prev_ctx = None - prev_grad = False - for idx, layer_id in list(enumerate(ctx.self.layer_ids))[::-1]: - torch.cuda.set_rng_state(ctx.cuda_rng_state[idx]) - ipt = layer_inputs[ctx.save_list[idx][1]].requires_grad_() - if ctx.micro_idx == 0: - ctx.block_ctx_list[idx] = CheckpointBlockContext(ctx.self._modules[str(layer_id)], ctx.layers_dict[idx], 2, pipe=True) - ctx.block_ctx_list[idx].enter() - if ctx.micro_idx == config["micros"]-1: - exit_prev(prev_ctx, prev_grad) - prev_ctx = ctx.block_ctx_list[idx] - prev_grad = True - - with ScopedTensorInspectorContext() as inspector: - output = ctx.self._modules[str(layer_id)]._module._call_impl(ipt, *all_inputs) - - assert len(ctx.layer_inspector[idx]) == len(inspector.hidden_states), "Backward step changed" - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.layer_inspector[idx][j]["name"], "Backward step changed" - assert it["shape"] == ctx.layer_inspector[idx][j]["shape"], "Backward step changed" - assert it["group"] == ctx.layer_inspector[idx][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.layer_inspector[idx][j]["tensor"] = it["tensor"] - ctx.layer_inspector[idx][j]["requires_grad"] = it["requires_grad"] - if len(inspector.hidden_states) > 0: - torch.autograd.backward( - [output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - [grad_hidden_state] + list(grad_inspector[-len(inspector.hidden_states):]), - ) - grad_inspector = grad_inspector[:-len(inspector.hidden_states)] - else: - torch.autograd.backward( - [output], - [grad_hidden_state], - ) - grad_hidden_state = ipt.grad - if grad_middle is not None: - grad_hidden_state = grad_hidden_state + grad_middle[idx] - if ctx.micro_idx == config["micros"]-1: - exit_prev(prev_ctx, prev_grad) - for inspector_hiddens in ctx.layer_inspector: - for it in inspector_hiddens: - debug.append("_inspect_hidden_states", it) - - pipe_input[0] = grad_hidden_state - grads = [] - for inp, requires_grad in zip(all_inputs, input_requires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None, None, None, pipe_input[0]) + tuple(grads) - -class OpPipeTransformerBlockList(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', save_list, hidden_state, *args): - num_micros = config["micros"] - ctx.self = self - ctx.num_micros = num_micros - block_ctx = [None for _ in range(len(self))] - layers_dict = [{} for _ in range(len(self))] - args_list = [[] for _ in range(num_micros)] batch_related = args[-1] batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))] batch_related_rule 
= [] args = args[:-1] + batch_size = hidden_state.shape[0] - assert (batch_size * config["pipe_size"]) % num_micros == 0, f'The batch size {(batch_size * config["pipe_size"])} must be divisible by the number of micro_batch {num_micros}' + num_micros = config["micros"] + args_list = [[] for _ in range(num_micros)] input_requires_grad = [] - inspector_hiddens = [] - ctx.inspector_hiddens_sep = [0] - ctx.micro_inspector = [] - with torch.enable_grad(): - for arg in args: - if torch.is_tensor(arg): - arg_all = all_gather(arg, config['pipe_comm']) - if arg.shape[0] == batch_size: - batch_related_rule.append(True) - arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) - arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] - else: - batch_related_rule.append(False) - # assert num_micros % self.stages == 0, "batch unrelated only support num_micros % stages == 0" - # arg_all = [arg_all[i // (num_micros // self.stages)].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - input_requires_grad.append(arg.requires_grad) + for arg in args: + if torch.is_tensor(arg): + arg_all = all_gather(arg, config['pipe_comm']) + if arg.dim() == hidden_state.dim() and arg.shape[0] == batch_size: + batch_related_rule.append(True) + arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) + arg_all = [tensor.requires_grad_(arg.requires_grad) for tensor in arg_all] else: batch_related_rule.append(False) - arg_all = [arg for _ in range(num_micros)] - input_requires_grad.append(False) - for i in range(num_micros): - args_list[i].append(arg_all[i]) - outputs = [] - if self.return_hidden_states: - middles = [] - hidden_state_list = all_gather(hidden_state, config["pipe_comm"]).flatten(0, 1).detach().requires_grad_() - ctx.hidden_state_list = hidden_state_list - hidden_state_list = hidden_state_list.chunk(num_micros, dim=0) - for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - with ScopedTensorInspectorContext() as inspector: - micro_outputs = OpMicroForward.apply(placeholder, self, micro_idx, block_ctx, layers_dict, save_list, hidden_state, *arg) - output, middle = micro_outputs[:2] - outputs.append(output) - if self.return_hidden_states: - middles.append(middle) - for it in inspector.hidden_states: - inspector_hiddens.append(it["tensor"]) - it["tensor"] = it["tensor"].clone() - debug.append("_inspect_hidden_states", it) - ctx.inspector_hiddens_sep.append(len(inspector_hiddens)) - ctx.micro_inspector.append(inspector.hidden_states) + arg_all = [arg_all[0].requires_grad_(arg.requires_grad) for i in range(num_micros)] + input_requires_grad.append(arg.requires_grad) + else: + batch_related_rule.append(False) + arg_all = [arg for _ in range(num_micros)] + input_requires_grad.append(False) + for i in range(num_micros): + args_list[i].append(arg_all[i]) + ctx.input_requires_grad = input_requires_grad + ctx.args_list = args_list if len(batch_related) == 0: ctx.batch_related = batch_related_rule else: ctx.batch_related = batch_related_origin - ctx.args_list = args_list - ctx.input_requires_grad = input_requires_grad - ctx.output_list = outputs - if self.return_hidden_states: - ctx.middle_list = middles - - with torch.enable_grad(): - last_hidden = torch.cat(outputs, dim=0) - last_hidden_shape = last_hidden.shape - last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, 
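# Illustrative sketch: PipePreFunction above gathers the per-stage batch across the pipe group and
# slices it into micro-batches with flatten(0, 1).chunk(num_micros, dim=0). torch.stack stands in
# for the all_gather; pipe_size, num_micros and the shapes are assumptions.
import torch

pipe_size, num_micros = 2, 4
local_batch = torch.randn(8, 16)                   # per-stage hidden_state
gathered = torch.stack([local_batch] * pipe_size)  # (pipe_size, 8, 16), as after all_gather

micro_batches = gathered.flatten(0, 1).chunk(num_micros, dim=0)
print(len(micro_batches), micro_batches[0].shape)  # 4 torch.Size([4, 16])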
config["pipe_comm"]) - last_hidden = last_hidden.chunk(self.stages, dim=0) - last_hidden = last_hidden[self.stage_id].clone() - - if self.return_hidden_states: - middle_hiddens = [] - with torch.enable_grad(): - for stage_id in range(self.stages): - if self.stage_id == stage_id: - middle_hidden = torch.cat(middles, dim=1) # [(layers, micro_batch, ...), ] -> (layers, full_batch, ...) - else: - middle_shape = (self.get_part_len_by_stage_id(stage_id),)+last_hidden_shape - middle_hidden = torch.zeros(middle_shape, device=last_hidden.device, dtype=last_hidden.dtype) - middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) - middle_hidden = middle_hidden.chunk(self.stages, dim=1) - middle_hidden = middle_hidden[self.stage_id].clone() - middle_hiddens.append(middle_hidden) - middle_hiddens = torch.cat(middle_hiddens, dim=0) - else: - middle_hiddens = None - - ctx.save_for_backward(*inspector_hiddens) - return tuple([last_hidden, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.micro_inspector for it in inspector_hiddens]) - + return hidden_state_list, args_list @staticmethod - def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, *grad_inspectors): - inspector_hiddens = ctx.saved_tensors - ipt = ctx.hidden_state_list - args_list = ctx.args_list - input_requires_grad = ctx.input_requires_grad - grad_hidden_state_list = all_gather(grad_hidden_state, config["pipe_comm"]).flatten(start_dim=0, end_dim=1).chunk(ctx.num_micros, dim=0) - if ctx.self.return_hidden_states: - for stage_id in range(ctx.self.stages): - layer_range = ctx.self.get_range_by_stage_id(stage_id) - grad_middle_state = grad_middle[layer_range] - grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]).flatten(start_dim=0, end_dim=1).transpose(0, 1).chunk(ctx.num_micros, dim=1) # (layer, micro_batch, ...) 
- if ctx.self.stage_id == stage_id: - grad_middle_state_list = grad_middle_state - - for m in range(ctx.num_micros): - outputs = [ctx.output_list[m]] - grad_outputs = [grad_hidden_state_list[m]] - if ctx.self.return_hidden_states: - outputs.append(ctx.middle_list[m]) - grad_outputs.append(grad_middle_state_list[m]) - outputs += list(inspector_hiddens[ctx.inspector_hiddens_sep[m]:ctx.inspector_hiddens_sep[m+1]]) - grad_outputs += list(grad_inspectors[ctx.inspector_hiddens_sep[m]:ctx.inspector_hiddens_sep[m+1]]) - with ScopedTensorInspectorContext() as inspector: - torch.autograd.backward( - outputs, - grad_outputs, - ) - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.micro_inspector[m][j]["name"], "Backward step changed" - assert it["shape"] == ctx.micro_inspector[m][j]["shape"], "Backward step changed" - assert it["group"] == ctx.micro_inspector[m][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.micro_inspector[m][j]["tensor"] = it["tensor"] - ctx.micro_inspector[m][j]["requires_grad"] = it["requires_grad"] - - grads = [] - for idx,requires_grad in enumerate(input_requires_grad): + def backward(ctx, grads, arg_grads): + grads = broadcast(grads, 0, config['pipe_comm']) + topo = config['topology'] + arg_grads = [] + num_micros = config['micros'] + for idx,requires_grad in enumerate(ctx.input_requires_grad): if requires_grad: - grad = torch.cat([args_list[m][idx].grad for m in range(ctx.num_micros)], dim=0) + grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0) grad = all_reduce(grad, "sum", config["pipe_comm"]) - split_size = ctx.self.stages if ctx.batch_related[idx] else ctx.num_micros + split_size = topo.stages if ctx.batch_related[idx] else num_micros grad = grad.chunk(split_size) if ctx.batch_related[idx]: - grads.append(grad[ctx.self.stage_id]) + arg_grads.append(grad[topo.stage_id]) else: - grads.append(grad[0]) + arg_grads.append(grad[0]) else: - grads.append(None) - grad = broadcast(ipt.grad, 0, config["pipe_comm"]).chunk(ctx.self.stages) - grad = grad[ctx.self.stage_id] + arg_grads.append(None) + arg_grads.append(None) #for append(batch_related) + return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads + +class PipePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, backward_stage_ranges=None, last_hidden_shape=None, return_hidden_states=False): + topo = config['topology'] + ctx.return_hidden_states = return_hidden_states + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(topo.stages, dim=0) + output = last_hidden[topo.stage_id] + output.requires_grad_() + + if return_hidden_states: + ctx.stage_id = topo.stage_id + ctx.stages = topo.stages + ctx.backward_stage_ranges = backward_stage_ranges + middle_hiddens = [] + for stage_id in range(ctx.stages): + if ctx.stage_id == stage_id: + middle_hidden = hidden_states + else: + middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape + middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype) + middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) + middle_hidden = middle_hidden.chunk(ctx.stages, dim=1) + middle_hidden = middle_hidden[ctx.stage_id].clone() + middle_hiddens.append(middle_hidden) + middle_hiddens = torch.cat(middle_hiddens, dim=0) + middle_hiddens.requires_grad_() + return output, middle_hiddens + else: + 
return output + + @staticmethod + def backward(ctx, grads, grad_middle=None): + grad_list = all_gather(grads, config["pipe_comm"]) + grad_list = grad_list.flatten(start_dim=0, end_dim=1) + + if ctx.return_hidden_states: + for stage_id in range(ctx.stages): + layer_range = ctx.backward_stage_ranges[stage_id] + grad_middle_state = grad_middle[layer_range] + grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]) + grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1) + if ctx.stage_id == stage_id: + grad_hidden_state_list = grad_middle_state + return grad_list, grad_hidden_state_list, None, None, None, None + else: + return grad_list + +class StagePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_first_stage: + input = recv_activations(stage_id - 1, config['pipe_comm']) + input.requires_grad_() + return input + return input + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_first_stage: + send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + return grad_outputs, None + +class StagePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, outputs, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_last_stage: + send_data = outputs[0] if isinstance(outputs, tuple) else outputs + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + return outputs + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_last_stage: + pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm']) + return pre_grad_inputs, None + return grad_outputs, None - return (None, None, None, grad) + tuple(grads) + (None,) class PipelineTransformerBlockList(torch.nn.Module): r""" @@ -331,9 +182,9 @@ class PipelineTransformerBlockList(torch.nn.Module): """ _modules: Dict[str, CheckpointBlock] - def __init__(self, modules: Iterable[CheckpointBlock]) -> None: + def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: super().__init__() - + self.num_hidden = num_hidden self._modules = {} rank = config['rank'] topo = config['topology'] @@ -345,18 +196,37 @@ def __init__(self, modules: Iterable[CheckpointBlock]) -> None: for idx, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) + + module._mode = "PIPE" + module.stage_id = self.stage_id + module.stages = self.stages + self._modules[str(idx)] = module self.layer_ids = self.get_range_by_stage_id(self.stage_id) + + pre_module = None + for i,layer_id in enumerate(self.layer_ids): + module = self._modules[str(layer_id)] + module.set_pre_module(pre_module) + pre_module = module + + module._is_first_stage = True if self.stage_id == 0 else False + module._is_last_stage = True if self.stage_id == self.stages-1 else False + module._is_first_layer = False + module._is_last_layer = False + self._modules[str(self.layer_ids[0])]._is_first_layer = True + self._modules[str(self.layer_ids[-1])]._is_last_layer = True + self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = pipe_group[self.pipe_idx, 
self.stage_id - 1] if self.stage_id > 0 else -1 # self.micro_batches = config['num_micro_batches'] - + self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: - return len(self._modules) + return len(self._modules) def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) @@ -366,15 +236,47 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): self.return_hidden_states = return_hidden_states - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - args = list(args) - args.append(batch_related) - outputs = OpPipeTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) - hidden_state, middle_states = outputs[:2] + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + args = args + (batch_related, ) + hidden_state.requires_grad_() + hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args) + + hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) + outputs = [] + hidden_states = [] + + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): + micro_hidden_states = [] + + hidden_state = StagePreFunction.apply(hidden_state, self.stage_id) + + for idx,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)]._micro_idx = micro_idx + if return_hidden_states: + micro_hidden_states.append(hidden_state) + hidden_state = self._modules[str(layer_id)](hidden_state, *arg) + hidden_state = StagePostFunction.apply(hidden_state, self.stage_id) + + outputs.append(hidden_state) + if return_hidden_states: + hidden_states.append(torch.stack(micro_hidden_states, dim=0)) + + last_hidden = torch.cat(outputs, dim=0) + last_hidden_shape = last_hidden.shape + if return_hidden_states: - return hidden_state, middle_states + hidden_states = torch.cat(hidden_states, dim=1) + forward_stage_ranges = [] + backward_stage_ranges = [] + for stage_id in range(self.stages): + forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id)) + backward_stage_ranges.append(self.get_range_by_stage_id(stage_id)) + outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states) + return outputs, hidden_states else: - return hidden_state + outputs = PipePostFunction.apply(last_hidden) + return outputs def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] @@ -486,32 +388,3 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for n, parameter in module._module.named_parameters(): destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']) -class PipeContext: - def __init__(self, module, hidden_state, backward=False): - self.module = module - self.stage_id = module.stage_id - self.stages = module.stages - self.next_rank = module.next_rank - self.prev_rank = module.prev_rank - self.hidden_state = [hidden_state] - self.backward = backward - self.send_buffer = {} - def enter(self): - if self.backward: - if self.stage_id != self.stages -1: - self.hidden_state[0] = recv_activations(self.stage_id + 1, config['pipe_comm']) - else: - if self.stage_id != 0: - self.hidden_state[0] = recv_activations(self.stage_id - 1, config['pipe_comm']) - return self.hidden_state - def exit(self): - if self.backward: - if self.stage_id != 0: 
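# Illustrative sketch: each pipeline stage owns a contiguous slice of layers, selected through
# get_range_by_stage_id above. The even-split rule below is an assumption and not necessarily the
# exact partitioning formula used by PipelineTransformerBlockList.
def layer_range(num_layers: int, stages: int, stage_id: int):
    base, rem = divmod(num_layers, stages)
    lens = [base + (1 if s < rem else 0) for s in range(stages)]
    start = sum(lens[:stage_id])
    return list(range(start, start + lens[stage_id]))

print([layer_range(10, 4, s) for s in range(4)])  # [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]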
- send_activations(self.hidden_state[0], self.stage_id - 1, config['pipe_comm']) - else: - if self.stage_id != self.stages - 1: - send_activations(self.hidden_state[0], self.stage_id + 1, config['pipe_comm']) - def __enter__(self): - return self.enter() - def __exit__(self, exc_type, exc_val, exc_tb): - self.exit() \ No newline at end of file diff --git a/example/layers/__init__.py b/example/layers/__init__.py index 425d0a1b..ef4617c0 100644 --- a/example/layers/__init__.py +++ b/example/layers/__init__.py @@ -1,6 +1,5 @@ -from .linear import Linear from .embedding import Embedding from .feedforward import Feedforward from .layernorm import Layernorm from .attention import Attention -from .transformer import TransformerEncoder \ No newline at end of file +from .transformer import TransformerEncoder diff --git a/example/layers/attention.py b/example/layers/attention.py index 4a0eec11..243df3ea 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,7 +1,7 @@ from typing import Optional import torch import bmtrain as bmt -from layers import Linear +from bmtrain.nn import Linear import math class Attention(bmt.DistributedModule): diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 3fe935bf..99d2dc3b 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,6 +1,6 @@ import torch import bmtrain as bmt -from layers import Linear +from bmtrain.nn import Linear class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: diff --git a/example/train.py b/example/train.py index 7bc92400..1a744e20 100644 --- a/example/train.py +++ b/example/train.py @@ -2,6 +2,8 @@ import bmtrain as bmt from models import GPT import time +from bmtrain import optim +from bmtrain import inspect def main(): bmt.init_distributed( @@ -51,10 +53,10 @@ def main(): break loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) - optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) - optim_manager = bmt.optim.OptimManager(loss_scale=2**20) + optim_manager = optim.OptimManager(loss_scale=2**20) optim_manager.add_optimizer(optimizer, lr_scheduler) bmt.synchronize() @@ -66,7 +68,7 @@ def main(): # load data st = time.time() - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) logits = model( enc_input, @@ -87,13 +89,13 @@ def main(): # print parameters of the model if iteration % 100 == 0: bmt.print_rank( - bmt.inspect.format_summary( + inspect.format_summary( inspector.get_summary() ) ) bmt.print_rank( - bmt.inspect.format_summary( - bmt.inspect.inspect_model(model, "*") + inspect.format_summary( + inspect.inspect_model(model, "*") ) ) diff --git a/tests/test_all.py b/tests/test_all.py index b614d3eb..6682aa93 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -24,6 +24,7 @@ ("send_recv", 4), ("nccl_backward", 4), + ("no_grad", 1), ("training", 4), ]) diff --git a/tests/test_has_inf_nan.py b/tests/test_has_inf_nan.py index b1b9b4a9..fda85515 100644 --- a/tests/test_has_inf_nan.py +++ b/tests/test_has_inf_nan.py @@ -1,12 +1,12 @@ from utils import * import torch -import bmtrain.optim._cuda as G +import bmtrain.loss._function as F import 
random def check(x, v): out = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] - G.f_has_inf_nan(x, out) + F.has_inf_nan(x, out) assert_eq(out.item(), v) def test_main(): @@ -29,4 +29,4 @@ def test_main(): check(x, 1) if __name__ == "__main__": - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_inspector_hidden.py b/tests/test_inspector_hidden.py index 731884ad..c39de5fb 100644 --- a/tests/test_inspector_hidden.py +++ b/tests/test_inspector_hidden.py @@ -7,6 +7,7 @@ from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -48,7 +49,7 @@ def __init__(self, dim : int): def forward(self, x): x = self.m1(x) - bmt.inspect.record_tensor(x, "hidden") + inspect.record_tensor(x, "hidden") x = self.m2(x) return x @@ -160,10 +161,10 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len): bmt.init_parameters(m) m = cls(pre, [m for m in ms], post) ret = "" - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: logits = m(inp) - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" @@ -171,32 +172,32 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len): for i in range(len(ms)//2): loss = loss + (inspector.summary[i]['tensor'] * middle_weight[i]).sum() - with bmt.inspect.inspect_tensor(): + with inspect.inspect_tensor(): loss.backward() - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) + "\n" - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: logits = m(inp) - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" loss = (logits * last_weight).sum() - with bmt.inspect.inspect_tensor(): + with inspect.inspect_tensor(): loss.backward() - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) + "\n" - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" @@ -237,4 +238,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed(pipe_size=2) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index f0d5c559..688cdfe5 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -3,10 +3,10 @@ import bmtrain as bmt import random import torch -from bmtrain import config from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -143,8 +143,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid loss = (logits * last_weight).sum() loss.backward() ret += f"========================only last========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += 
inspect.format_summary( + inspect.inspect_model(m, '*') ) if only_middle: logits, hidden_states = m(inp, return_hidden_states=True) @@ -154,8 +154,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ]) loss.backward() ret += f"========================only middle========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if mix_test: logits, hidden_states = m(inp, return_hidden_states=True) @@ -165,8 +165,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ]) + (logits * last_weight).sum() loss.backward() ret += f"========================mix========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) return ret + "\n" # replace for matching None grad with zero_grad @@ -209,4 +209,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed(pipe_size=4) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_model_wrapper.py b/tests/test_model_wrapper.py index 409107e3..6f913d3c 100644 --- a/tests/test_model_wrapper.py +++ b/tests/test_model_wrapper.py @@ -164,7 +164,7 @@ def forward(self, out = input_emb for layer in self.transformers: - out = layer(out, position_bias=None, mask=mask_2d) + out = layer(out, mask_2d) out = self.layernorm(out) logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model) @@ -218,4 +218,4 @@ def test_main(): if __name__ == '__main__': bmt.init_distributed(seed=0) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_no_grad.py b/tests/test_no_grad.py new file mode 100644 index 00000000..3629921b --- /dev/null +++ b/tests/test_no_grad.py @@ -0,0 +1,46 @@ +import torch +import bmtrain as bmt + +class Layer(torch.nn.Module): + def __init__(self): + super(Layer, self).__init__() + self.linear = bmt.nn.Linear(32, 32) + self.count = 0 + def forward(self, x): + self.count += 1 + return self.linear(x) + +def test_no_grad(): + x = torch.randn(32, 32, device='cuda') + + layer1 = bmt.CheckpointBlock(Layer()) + layer2 = bmt.CheckpointBlock(Layer()) + layer1.linear.weight.requires_grad_(False) + layer1.linear.bias.requires_grad_(False) + y = layer1(x) + assert y.requires_grad == False + y = layer2(y) + y.sum().backward() + assert layer1.count == 1 + assert layer2.count == 2 + +def test_all_input_no_grad(): + linear1 = bmt.nn.Linear(32, 32) + linear2 = bmt.nn.Linear(32, 32) + + x = torch.randn(32,32, device='cuda') + + linear1 = bmt.CheckpointBlock(linear1) + linear2 = bmt.CheckpointBlock(linear2) + y = linear1(x) + y = linear2(y) + y.sum().backward() + assert linear1.weight.grad is not None + assert linear1.bias.grad is not None + assert x.grad is None + +if __name__ == '__main__': + bmt.init_distributed() + + test_no_grad() + test_all_input_no_grad() diff --git a/tests/test_optim.py b/tests/test_optim.py index 81356ede..fdb64521 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,5 +1,6 @@ import torch import bmtrain as bmt +from bmtrain import optim class TestModule(torch.nn.Module): def __init__(self): @@ -29,8 +30,8 @@ def main(): model2 = model2.cuda() model3 = model3.cuda() - opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = 
optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) for _ in range(100): diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py index df697f49..cef06734 100644 --- a/tests/test_optim_state.py +++ b/tests/test_optim_state.py @@ -2,6 +2,7 @@ import bmtrain as bmt import os from copy import deepcopy +from bmtrain import optim class TestSubModule(bmt.DistributedModule): def __init__(self): @@ -67,10 +68,10 @@ def main(): bmt.load(model2, f"test_optim_state_model1.pt") bmt.load(model3, f"test_optim_state_model1.pt") - opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) - optim_manager = bmt.optim.OptimManager(loss_scale=256) + optim_manager = optim.OptimManager(loss_scale=256) optim_manager.add_optimizer(opt1) optim_manager.add_optimizer(opt2) optim_manager.add_optimizer(opt3) @@ -121,4 +122,4 @@ def main(): if __name__ == "__main__": bmt.init_distributed() - main() \ No newline at end of file + main() diff --git a/tests/test_other_hidden.py b/tests/test_other_hidden.py index d1e317ad..1f6c8c65 100644 --- a/tests/test_other_hidden.py +++ b/tests/test_other_hidden.py @@ -7,6 +7,7 @@ from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -142,22 +143,22 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post loss = (pre.weight * last_weight).sum() loss.backward() ret += f"========================only last========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if only_post: loss = (post.weight * last_weight).sum() loss.backward() ret += f"========================only middle========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if mix_test: loss = (pre.weight * last_weight).sum() + (post.weight * last_weight).sum() loss.backward() ret += f"========================mix========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) return ret + "\n" # replace for matching None grad with zero_grad diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index 83fe8d17..943275c3 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -7,6 +7,7 @@ from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -25,19 +26,22 @@ def __init__(self, in_features : int, out_features: int, init_weight = None, ini else: self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), 
init_method=torch.nn.init.zeros_) - def forward(self, input): + def forward(self, input, other_bias): ret = F.linear(input, self.weight, self.bias) + ret += other_bias return ret def run(m, a, b): inp = torch.rand((1, 10, 256)).cuda()*100 - logits = m(inp) + bias = torch.rand(256).cuda()*100 + logits = m(inp, bias) loss = logits.sum() loss.backward() bmt.synchronize() - sm = bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + sm = inspect.format_summary( + inspect.inspect_model(m, '*') ) + assert_eq(bias.requires_grad, False) return a.weight.grad is None, a.bias.grad is None, sm def test_main(): @@ -100,4 +104,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=1) test_main() - test_main_pipe() \ No newline at end of file + test_main_pipe() diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index ebea096e..4a2670ae 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -2,11 +2,11 @@ import bmtrain as bmt import torch -from bmtrain import config from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -35,8 +35,8 @@ def run(m, a, b): loss = logits.sum() loss.backward() - sm = bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + sm = inspect.format_summary( + inspect.inspect_model(m, '*') ) return sm @@ -93,4 +93,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=2) test_main() - test_main_pipe() \ No newline at end of file + test_main_pipe() diff --git a/tests/test_training.py b/tests/test_training.py index 7342fe6c..1d6481c9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -6,8 +6,8 @@ import math import torch.nn.functional as F import bmtrain as bmt -from bmtrain.global_var import config import os +from bmtrain import inspect class Attention(torch.nn.Module): def __init__(self, @@ -151,6 +151,7 @@ def __init__(self, ) for _ in range(num_layers) ]) + self.run_unroll = False self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype) @@ -166,7 +167,7 @@ def forward(self, input_emb = self.pos_emb(pos) + self.word_emb(input) out = input_emb - if isinstance(self.transformers, torch.nn.ModuleList): + if isinstance(self.transformers, torch.nn.ModuleList) or self.run_unroll: for layer in self.transformers: out = layer(out, mask_2d, None) else: @@ -250,7 +251,7 @@ def sub_train_torch(model, loss_func_cls, optimizer_cls): )) logs.append(global_loss) - summary = bmt.inspect.inspect_model(model, "*") + summary = inspect.inspect_model(model, "*") return logs, summary def sub_train(model, loss_func_cls, optimizer_cls): @@ -311,7 +312,7 @@ def sub_train(model, loss_func_cls, optimizer_cls): )) logs.append(global_loss) - summary = bmt.inspect.inspect_model(model, "*") + summary = inspect.inspect_model(model, "*") return logs, summary def train(model, loss_func, optimizer): @@ -376,11 +377,20 @@ def pipe_model(): bmt.load(pipe_model, ckpt_path) return model + def unroll_list_model(): + model = GPT(**kwargs) + list_model = bmt.BMTrainModelWrapper(model) + list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers]) + bmt.load(list_model, ckpt_path) + model.run_unroll = True + return model + models = { "torch": 
torch_model, "wrapper": wrap_model, "blocklist": list_model, "pipelist": pipe_model, + "unroll_blocklist": unroll_list_model, } loss_funcs = { "bmt_entropy": bmt.loss.FusedCrossEntropy, @@ -406,6 +416,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "bmt_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "bmt_entropy", "bmt_adam_offload") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -419,6 +430,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") add_to_check_list("blocklist", "torch_entropy", "torch_adam") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -442,4 +454,4 @@ def check_param(info1, info2): if __name__ == '__main__': bmt.init_distributed(pipe_size=2) - test_main(test_fp16=True, test_fp32=True) \ No newline at end of file + test_main(test_fp16=True, test_fp32=True)
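
The example and test changes above all perform the same migration: user code moves from `bmt.optim.*` / `bmt.inspect.*` attribute access and the example-local `layers.Linear` to the `bmtrain.optim`, `bmtrain.inspect`, `bmtrain.lr_scheduler`, and `bmtrain.nn` submodules that this patch re-exports from `bmtrain/__init__.py`. A minimal, self-contained sketch of the migrated API is below. It is illustrative only: the toy `Net` module, tensor shapes, and the `OptimManager` zero_grad/backward/step sequence follow the pattern used in `example/train.py` and the tests, and are not themselves part of the patch.

import torch
import bmtrain as bmt
from bmtrain import optim, inspect   # submodules newly re-exported in bmtrain/__init__.py
from bmtrain.nn import Linear        # replaces the example-local layers.Linear

class Net(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(32, 64)
        self.fc2 = Linear(64, 32)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        inspect.record_tensor(h, "hidden")   # record an intermediate activation for the inspector
        return self.fc2(h)

bmt.init_distributed()
model = Net()
bmt.init_parameters(model)

opt = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
lr_scheduler = bmt.lr_scheduler.Noam(opt, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
optim_manager = optim.OptimManager(loss_scale=2**20)
optim_manager.add_optimizer(opt, lr_scheduler)

x = torch.randn(8, 32, device="cuda")
with inspect.inspect_tensor() as inspector:
    loss = model(x).sum()

# Training step via the OptimManager, as in example/train.py
optim_manager.zero_grad()
optim_manager.backward(loss)
optim_manager.step()

# Summaries use the new top-level inspect submodule instead of bmt.inspect.*
bmt.print_rank(inspect.format_summary(inspector.get_summary()))
bmt.print_rank(inspect.format_summary(inspect.inspect_model(model, "*")))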