Skip to content

Commit

Permalink
QLoRA (#9340)
Browse files Browse the repository at this point in the history
* temp qlora implementation

Signed-off-by: Chen Cui <[email protected]>

* swap nf4 after model instantiation

Signed-off-by: Chen Cui <[email protected]>

* load model on cpu and then quantize on gpu

Signed-off-by: Chen Cui <[email protected]>

* model init on cpu to prevent memory spike

Signed-off-by: Chen Cui <[email protected]>

* account for TE versions

Signed-off-by: Chen Cui <[email protected]>

* guard use_cpu_initialization

Signed-off-by: Chen Cui <[email protected]>

* fix layernorm autograd Function

Signed-off-by: Chen Cui <[email protected]>

* add unit tests

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* move cpu init to library code

Signed-off-by: Chen Cui <[email protected]>

* copyright header and nf4 quantize on GPU

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* fix cpu init

Signed-off-by: Chen Cui <[email protected]>

* comments

Signed-off-by: Chen Cui <[email protected]>

* fix test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
  • Loading branch information
cuichenx and cuichenx authored Jun 7, 2024
1 parent c665430 commit ceffb49
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
on_cpu=cfg.get('fsdp', False) and cfg.get('use_cpu_initialization', False),
on_cpu=cfg.get('use_cpu_initialization', False),
)

# if we're not using interleaved, then self.model is a module.
Expand Down Expand Up @@ -887,10 +887,18 @@ def training_step(self, dataloader_iter):
self.megatron_timer_stop('allreduce_first_last_embeddings')

if self.log_memory_usage:
mem_reserved = torch.cuda.max_memory_reserved()
max_memory_reserved = torch.cuda.max_memory_reserved()
memory_allocated = torch.cuda.memory_allocated()
self.log(
'peak_memory_usage',
mem_reserved,
max_memory_reserved,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
self.log(
'memory_allocated',
memory_allocated,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
Expand Down
26 changes: 19 additions & 7 deletions nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@


class NLPModel(ModelPT, Exportable):
"""Base class for NLP Models.
"""
"""Base class for NLP Models."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None, no_lm_init=False):

Expand Down Expand Up @@ -120,7 +119,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None, no_lm_init=False):
if cfg.get('language_model').get('config_file'):
config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file)
bert_model = get_lm_model(
config_file=config_file, config_dict=config_dict, vocab_file=vocab_file, trainer=trainer, cfg=cfg,
config_file=config_file,
config_dict=config_dict,
vocab_file=vocab_file,
trainer=trainer,
cfg=cfg,
)
# set the tokenizer if it is not initialized explicitly
if ((hasattr(self, 'tokenizer') and self.tokenizer is None) or not hasattr(self, 'tokenizer')) and hasattr(
Expand All @@ -146,16 +149,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None, no_lm_init=False):
self.register_bert_model()

def register_artifact(
self, config_path: str, src: str, verify_src_exists: bool = False,
self,
config_path: str,
src: str,
verify_src_exists: bool = False,
):
""" Overrides ModelPT register_artifact default behavior.
"""Overrides ModelPT register_artifact default behavior.
NLP models usually need artifacts that are optional."""
return super().register_artifact(config_path, src, verify_src_exists=verify_src_exists)

@rank_zero_only
def register_bert_model(self):
"""Adds encoder config to .nemo archive for Jarvis.
"""
"""Adds encoder config to .nemo archive for Jarvis."""
# check if there is an encoder, warn if not
if self.bert_model is not None:
# get encoder config and create source for artifact
Expand Down Expand Up @@ -462,6 +467,13 @@ def restore_from(
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(restore_path):
save_restore_connector.model_extracted_dir = restore_path
if (
isinstance(override_config_path, DictConfig)
and override_config_path.get('use_cpu_initialization', False)
and map_location is None
):
logging.info('use_cpu_initialization is True, loading checkpoint on CPU')
map_location = 'cpu'
return super().restore_from(
restore_path, override_config_path, map_location, strict, return_config, save_restore_connector, trainer
)
246 changes: 246 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/adapters/qlora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib.metadata import version
from typing import TYPE_CHECKING, Dict, Optional

import torch
import torch.nn.functional as F
from pkg_resources import packaging
from torch import Tensor, nn

from nemo.collections.nlp.parts.peft_config import LORA_CONFIG_TO_MCORE_MAP, get_target_modules
from nemo.utils import logging

te_version = packaging.version.Version(version("transformer-engine"))

if TYPE_CHECKING:
from megatron.core.models.gpt import MCoreGPTModel
from omegaconf import DictConfig


class NF4Weight(nn.Parameter):
def __new__(
cls,
data: torch.Tensor,
is_nf4_quantized: bool = False,
block_size: int = 64,
scale_block_size: int = 256,
):
self = torch.Tensor._make_subclass(cls, data, require_grad=False)
self._nf4_quantizer = None
self.is_nf4_quantized = is_nf4_quantized
self.block_size = block_size
self.scale_block_size = scale_block_size
return self

def quantize(self, device='cuda') -> torch.Tensor:
from modelopt.torch.quantization.nn import TensorQuantizer
from modelopt.torch.quantization.tensor_quant import QuantDescriptor

# initialize the quantizer
nf4_desc = QuantDescriptor(
num_bits=4,
block_sizes={-1: self.block_size, "scale_bits": 8, "scale_block_sizes": {-1: self.scale_block_size}},
fake_quant=False,
)
self._nf4_quantizer = TensorQuantizer(nf4_desc)

# quantize on GPU directly
nf4_tensor = self._nf4_quantizer(self.data.to(device))
self.quantized_data = nf4_tensor
self.is_nf4_quantized = True
return self

def dequantize(self):
assert self.is_nf4_quantized, "NF4 Tensor is not yet quantized, cannot dequantize."
return self._nf4_quantizer(self.quantized_data)

def cuda(self, device=None, non_blocking=False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if device is not None and device.type == "cuda":
# Note: self.data remains on CPU. Only self.quantized_data is on GPU
return self.quantize() if not self.is_nf4_quantized else self
else:
return NF4Weight(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
self.is_nf4_quantized,
self.block_size,
self.scale_block_size,
)

def __repr__(self, *, tensor_contents=None):
if self.is_nf4_quantized:
return f"NF4Weight(is_nf4_quantized=True, quantized_data={self.quantized_data}"
else:
return f"NF4Weight(is_nf4_quantized=False, data={self.data}"


class _LinearNF4(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, weight: NF4Weight):
ctx.nf4_weight = weight
return F.linear(input, weight.dequantize().to(input.device))

@staticmethod
def backward(ctx, grad_output):
weight: NF4Weight = ctx.nf4_weight
return grad_output @ weight.dequantize().to(grad_output.device), None


class NF4LinearWrapper(nn.Module):
"""
NF4 Linear Layer for QLoRA as introduced in `QLORA: Efficient Finetuning of Quantized LLMs <https://arxiv.org/abs/2305.14314>`_.
This wrapper module is instantiated in `on_load_checkpoint` and replaces TERowParallelLinear
Tensor Parallel is not supported.
Args:
bf16_linear_weight: Weight tensor in BF16 to wrap with NF4Weight
"""

def __init__(self, bf16_linear_weight: torch.Tensor):
super().__init__()

# quantize the weight upon initialization
self.weight = NF4Weight(bf16_linear_weight).cuda()

def forward(self, x: torch.Tensor):
"""
Args:
x (Tensor): input tensor with shape ``(..., in_dim)``
Returns:
Tensor: output tensor with shape ``(..., out_dim)``
"""
return _LinearNF4.apply(x, self.weight), None


class NF4LayerNormLinearWrapper(NF4LinearWrapper):
"""
Layernorm + NF4 Linear for QLoRA.
This class only combines the two modules for compatibility with TE's LayernormLinear layer, so that
the implementation for LoRA and QLoRA can share the same code path.
It does NOT fuse the two operations like TE does.
This wrapper module is instantiated in `on_load_checkpoint` and replaces TELayerNormColumnParallelLinear
Tensor Parallel is not supported.
Args:
bf16_linear_weight: Weight tensor in BF16 to wrap with NF4Weight
layer_norm_weight: layernorm weight tensor
layer_norm_bias: layernorm bias tensor, only if normalization is LayerNorm
normalization: Same as TELayerNormColumnParallelLinear.config.normalization
zero_centered_gamma: Same as TELayerNormColumnParallelLinear.config.zero_centered_gamma
"""

def __init__(
self,
bf16_linear_weight: torch.Tensor,
layer_norm_weight: torch.Tensor,
layer_norm_bias: Optional[torch.Tensor],
normalization: str,
zero_centered_gamma: bool,
):
super().__init__(bf16_linear_weight)
self.layer_norm_weight = nn.Parameter(layer_norm_weight)
if normalization != "RMSNorm":
self.layer_norm_bias = nn.Parameter(layer_norm_bias)
else:
self.layer_norm_bias = None

self.zero_centered_gamma = zero_centered_gamma
self.normalization = normalization
self.layer_norm_fn = self._create_layer_norm_fn()
self.te_return_bias = False

def _create_layer_norm_fn(self):
'''
create the layernorm function signature in TE. Assume this layer is already running without gradients
since this is for QLoRA.
'''
if self.normalization == 'LayerNorm':
from transformer_engine.pytorch.module.layernorm import _LayerNorm

layer_norm_fn = _LayerNorm.apply
elif self.normalization == 'RMSNorm':
from transformer_engine.pytorch.module.rmsnorm import _RMSNorm

layer_norm_fn = _RMSNorm.apply
else:
raise ValueError("Unsupported normalization type:", self.normalization)

return layer_norm_fn

def forward(self, x):
layer_norm_args = [
x, # inp
self.layer_norm_weight,
1e-5, # eps,
0, # fwd_rmsnorm_sm_margin,
0, # bwd_rmsnorm_sm_margin,
self.zero_centered_gamma,
True, # is_grad_enabled,
x.dtype, # activation_dtype,
]
if te_version >= packaging.version.Version("1.6"):
layer_norm_args.insert(5, 0) # inf_rmsnorm_sm_margin
if self.normalization == "LayerNorm":
layer_norm_args.insert(2, self.layer_norm_bias)
layernorm_output = self.layer_norm_fn(*layer_norm_args)
linear_output = _LinearNF4.apply(layernorm_output, self.weight)
return (linear_output, layernorm_output), None


def qlora_load_model(model: 'MCoreGPTModel', model_cfg: 'DictConfig', checkpoint: Dict[str, Tensor]):
# swap linear layer and cast weight to nf4
qlora_targets = [
LORA_CONFIG_TO_MCORE_MAP[x] for x in get_target_modules(model_cfg.peft.lora_tuning, default=('all',))
]

# if not load directly on device, need to load the rest of the model
# this block should only load word_embeddings, final_layernorm and output_layer weights.
if not model_cfg.get("dist_ckpt_load_on_device", True):
checkpoint_state_dict = {}
for key, value in checkpoint.items():
if not any(qlora_target in key for qlora_target in qlora_targets):
checkpoint_state_dict[key.replace('model.', '')] = value
model.load_state_dict(checkpoint_state_dict, strict=False)

def replace_linear(module: nn.Module, prefix=""):
for name, child in module.named_children():
if name in qlora_targets:
bf16_weight = checkpoint[f"{prefix}.{name}.weight"]
logging.info(f'QLoRA: Quantizing linear layer: {prefix}.{name}')
if name in ['linear_proj', 'linear_fc2']:
setattr(module, name, NF4LinearWrapper(bf16_weight))
else: # name in ['linear_qkv', 'linear_fc1']
layer_norm_weight = checkpoint[f"{prefix}.{name}.layer_norm_weight"]
layer_norm_bias = checkpoint.get(f"{prefix}.{name}.layer_norm_bias", None)
normalization = module.config.normalization
zero_centered_gamma = module.config.layernorm_zero_centered_gamma
setattr(
module,
name,
NF4LayerNormLinearWrapper(
bf16_weight, layer_norm_weight, layer_norm_bias, normalization, zero_centered_gamma
),
)
else:
replace_linear(child, prefix=f"{prefix}.{name}")

replace_linear(model, prefix="model")
10 changes: 9 additions & 1 deletion nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,15 @@ def on_load_checkpoint(self, checkpoint) -> None:
self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
else:
super().on_load_checkpoint(checkpoint)
cfg_peft = self.cfg.get('peft', None)
if cfg_peft and cfg_peft['peft_scheme'] == 'qlora':
from nemo.collections.nlp.modules.common.megatron.adapters.qlora import qlora_load_model

qlora_load_model(
self.model.module if self.megatron_amp_O2 else self.model, self.cfg, checkpoint['state_dict']
)
else:
super().on_load_checkpoint(checkpoint)

@classmethod
def merge_cfg_with(cls, path: str, cfg: DictConfig) -> DictConfig:
Expand Down
16 changes: 14 additions & 2 deletions nemo/collections/nlp/parts/peft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,16 @@
"all": "all",
}

LORA_CONFIG_TO_MCORE_MAP = {
"attention_qkv": "linear_qkv",
"attention_dense": "linear_proj",
"mlp_fc1": "linear_fc1",
"mlp_fc2": "linear_fc2",
}


def get_target_modules(lora_cfg):
original_target_modules = lora_cfg.get("target_modules", ["attention_qkv"])
def get_target_modules(lora_cfg, default=("attention_qkv",)):
original_target_modules = lora_cfg.get("target_modules", default)
target_modules = []

for module in original_target_modules:
Expand Down Expand Up @@ -251,6 +258,10 @@ def _create_lora_config(
return adapter_cfg


class QLoraPEFTConfig(LoraPEFTConfig):
pass


class IA3PEFTConfig(PEFTConfig):
def __init__(self, cfg):
mlp_infused_adapter_cfg = MLPInfusedAdapterConfig(
Expand Down Expand Up @@ -360,6 +371,7 @@ def __init__(self, cfg):
"ia3": IA3PEFTConfig,
"ptuning": PtuningPEFTConfig,
"lora": LoraPEFTConfig,
"qlora": QLoraPEFTConfig,
"selective": SelectivePEFTConfig,
'none': None,
None: None,
Expand Down
Loading

0 comments on commit ceffb49

Please sign in to comment.