Jiaqiz/option to disable adapters & merge all lora layers (#8029)
* Added LoRA support for the Dense layer of Attention

* Added LoRA MLP support to MCore and NeMo models.

* Change LoRA config default to QKV.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed bug with ddp training.

* use adapter only when it is enabled

Signed-off-by: jiaqi zeng <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lora merge script (#8113)

Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: Adi Renduchintala <[email protected]>

* add peft ckpt to nemo

Signed-off-by: Jiaqi Zeng <[email protected]>

* merge lora weights for all layers, mcore only

Signed-off-by: Jiaqi Zeng <[email protected]>

* support/fix cpu initialization

Signed-off-by: Chen Cui <[email protected]>

* add example usage

Signed-off-by: Chen Cui <[email protected]>

* fix TP due to distributed checkpoint

Signed-off-by: Chen Cui <[email protected]>

* updating the logic of merging lora weights for all layers, mcore only

Signed-off-by: Jiaqi Zeng <[email protected]>

* MCoreMixin changes.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* merge in fp32 then cast back

Signed-off-by: Jiaqi Zeng <[email protected]>

* remove ckpt to nemo

Signed-off-by: Jiaqi Zeng <[email protected]>

* fix import

Signed-off-by: Jiaqi Zeng <[email protected]>

---------

Signed-off-by: jiaqi zeng <[email protected]>
Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: Jiaqi Zeng <[email protected]>
Co-authored-by: Tugrul Konuk <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
5 people authored and yaoyu-33 committed Feb 26, 2024
1 parent 6f33603 commit 7b41390
Showing 7 changed files with 82 additions and 35 deletions.
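The diffs below apply one recurring change: every adapter call is now gated not only on the adapter module existing, but also on its `enabled` flag in `adapter_cfg`. A minimal sketch of that pattern, assuming the adapter-mixin API used in the hunks (`is_adapter_available`, `get_adapter_module`) and an `adapter_cfg` dict keyed by adapter name; this is illustrative, not NeMo code:

# Illustrative sketch of the per-adapter gating added throughout this commit.
def apply_adapter_if_enabled(module, adapter_name, adapter_input, base_output):
    """Add an adapter's output to the base output only if the adapter exists and is enabled."""
    if not module.is_adapter_available():
        return base_output
    adapter = module.get_adapter_module(adapter_name)
    if adapter and module.adapter_cfg[adapter_name]['enabled']:
        return base_output + adapter(adapter_input)
    return base_output

Setting enabled to False for an adapter in its config therefore skips it at run time without removing its weights from the module.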
@@ -93,7 +93,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         # LoRA logic
         if self.is_adapter_available():
             lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
-            if lora_kqv_adapter:
+            if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
                 if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear):
                     lora_mixed_qkv = lora_kqv_adapter(layernorm_output)
                 elif isinstance(self.linear_qkv, TEColumnParallelLinear):
@@ -138,11 +138,11 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         if self.is_adapter_available():
             key_infused_adapter = self.get_adapter_module(AdapterName.KEY_INFUSED)
             value_infused_adapter = self.get_adapter_module(AdapterName.VALUE_INFUSED)
-            if key_infused_adapter:
+            if key_infused_adapter and self.adapter_cfg[AdapterName.KEY_INFUSED]['enabled']:
                 assert value_infused_adapter is not None, "Expected value_infused_adapter not found!"
                 kls = key.shape
                 key = key_infused_adapter(key.reshape(kls[0], kls[1], -1)).reshape(kls).to(query.dtype)
-            if value_infused_adapter:
+            if value_infused_adapter and self.adapter_cfg[AdapterName.VALUE_INFUSED]['enabled']:
                 assert key_infused_adapter is not None, "Expected key_infused_adapter not found!"
                 vls = value.shape
                 value = value_infused_adapter(value.reshape(vls[0], vls[1], -1)).reshape(vls).to(query.dtype)
@@ -229,7 +229,7 @@ def forward(
         # LoRA logic
         if self.is_adapter_available():
             lora_linear_proj_adapter = self.get_adapter_module(AdapterName.LORA_DENSE_ATTENTION_ADAPTER)
-            if lora_linear_proj_adapter:
+            if lora_linear_proj_adapter and self.adapter_cfg[AdapterName.LORA_DENSE_ATTENTION_ADAPTER]['enabled']:
                 lora_output = lora_linear_proj_adapter(core_attn_out)
                 output = output + lora_output

@@ -252,7 +252,7 @@ def forward(self, hidden_states):
         # LoRA logic
         if self.is_adapter_available():
             lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
-            if lora_linear_fc1_adapter:
+            if lora_linear_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
                 lora_output = lora_linear_fc1_adapter(hidden_states)
                 intermediate_parallel = intermediate_parallel + lora_output

@@ -283,7 +283,7 @@ def glu(x):
         # LoRA logic
         if self.is_adapter_available():
             lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
-            if lora_linear_fc2_adapter:
+            if lora_linear_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']:
                 lora_output = lora_linear_fc2_adapter(intermediate_parallel)
                 output = output + lora_output
         return output, output_bias
@@ -303,7 +303,9 @@ def forward(self, input_ids, position_ids):
             _sq, _bs, _hs = encoder_input.size()
             ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER)
             v = ptuning_adapter.virtual_tokens
-            if ptuning_adapter and _sq >= v: # The sequence should be longer the v to insert virtual embeddings.
+            if (
+                ptuning_adapter and self.adapter_cfg[AdapterName.PTUNING_ADAPTER]['enabled'] and _sq >= v
+            ): # The sequence should be longer the v to insert virtual embeddings.
                 virtual_embeddings = ptuning_adapter(_bs)
                 encoder_input = encoder_input[
                     v:, :, :
@@ -349,7 +351,7 @@ def forward(
         # adapter logic
         if self.is_adapter_available():
             adapter_1 = self.get_adapter_module(AdapterName.PRE_ATTN_ADAPTER)
-            if adapter_1:
+            if adapter_1 and self.adapter_cfg[AdapterName.PRE_ATTN_ADAPTER]['enabled']:
                 attention_output, bias = attention_output_with_bias
                 attention_output = (
                     adapter_1(attention_output) + attention_output
@@ -399,7 +401,7 @@ def forward(
         # adapter logic
         if self.is_adapter_available():
             adapter_2 = self.get_adapter_module(AdapterName.POST_ATTN_ADAPTER)
-            if adapter_2:
+            if adapter_2 and self.adapter_cfg[AdapterName.POST_ATTN_ADAPTER]['enabled']:
                 mlp_output, bias = mlp_output_with_bias
                 mlp_output = adapter_2(mlp_output) + mlp_output # simple adapter call with residual connection
                 mlp_output_with_bias = (mlp_output, bias)
12 changes: 6 additions & 6 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -415,7 +415,7 @@ def forward(
             mixed_x_layer, _ = self.query_key_value(hidden_states)
             if self.is_adapter_available():
                 lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
-                if lora_kqv_adapter:
+                if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
                     lora_mixed_x_layer = lora_kqv_adapter(hidden_states)
                     mixed_x_layer = mixed_x_layer + lora_mixed_x_layer

@@ -437,7 +437,7 @@ def forward(
             mixed_kv_layer, _ = self.key_value(encoder_output)
             if self.is_adapter_available():
                 lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER)
-                if lora_kv_adapter:
+                if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']:
                     lora_mixed_kv_layer = lora_kv_adapter(encoder_output)
                     mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer

@@ -459,7 +459,7 @@ def forward(
             query_layer, _ = self.query(hidden_states)
             if self.is_adapter_available():
                 lora_q_adapter = self.get_adapter_module(AdapterName.LORA_Q_ADAPTER)
-                if lora_q_adapter:
+                if lora_q_adapter and self.adapter_cfg[AdapterName.LORA_Q_ADAPTER]['enabled']:
                     lora_q_layer = lora_q_adapter(hidden_states)
                     query_layer = query_layer + lora_q_layer
             # [sq, b, hp] --> [sq, b, np, hn]
@@ -472,11 +472,11 @@ def forward(
         if self.is_adapter_available():
             key_infused_adapter = self.get_adapter_module(AdapterName.KEY_INFUSED)
             value_infused_adapter = self.get_adapter_module(AdapterName.VALUE_INFUSED)
-            if key_infused_adapter:
+            if key_infused_adapter and self.adapter_cfg[AdapterName.KEY_INFUSED]['enabled']:
                 assert value_infused_adapter is not None, "Expected value_infused_adapter not found!"
                 kls = key_layer.shape
                 key_layer = key_infused_adapter(key_layer.reshape(kls[0], kls[1], -1)).reshape(kls)
-            if value_infused_adapter:
+            if value_infused_adapter and self.adapter_cfg[AdapterName.VALUE_INFUSED]['enabled']:
                 assert key_infused_adapter is not None, "Expected key_infused_adapter not found!"
                 vls = value_layer.shape
                 value_layer = value_infused_adapter(value_layer.reshape(vls[0], vls[1], -1)).reshape(vls)
@@ -574,7 +574,7 @@ def forward(
         output, bias = self.dense(context_layer)
         if self.is_adapter_available():
             lora_dense_adapter = self.get_adapter_module(AdapterName.LORA_DENSE_ATTENTION_ADAPTER)
-            if lora_dense_adapter:
+            if lora_dense_adapter and self.adapter_cfg[AdapterName.LORA_DENSE_ATTENTION_ADAPTER]['enabled']:
                 lora_dense_output = lora_dense_adapter(context_layer)
                 output = output + lora_dense_output

@@ -764,7 +764,9 @@ def forward(
                 _sq, _bs, _hs = encoder_input.size()
                 ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER)
                 v = ptuning_adapter.virtual_tokens
-                if ptuning_adapter and _sq >= v: # The sequence should be longer the v to insert virtual embeddings.
+                if (
+                    ptuning_adapter and self.adapter_cfg[AdapterName.PTUNING_ADAPTER]['enabled'] and _sq >= v
+                ): # The sequence should be longer the v to insert virtual embeddings.
                     virtual_embeddings = ptuning_adapter(_bs)
                     encoder_input = encoder_input[
                         v:, :, :
4 changes: 2 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/mlp.py
@@ -222,7 +222,7 @@ def forward(self, hidden_states):
         intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
         if self.is_adapter_available():
             lora_dense_h_to_4h_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
-            if lora_dense_h_to_4h_adapter:
+            if lora_dense_h_to_4h_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
                 lora_intermediate_parallel = lora_dense_h_to_4h_adapter(hidden_states)
                 intermediate_parallel = intermediate_parallel + lora_intermediate_parallel

@@ -270,7 +270,7 @@ def forward(self, hidden_states):
         output, output_bias = self.dense_4h_to_h(intermediate_parallel)
         if self.is_adapter_available():
             lora_dense_4h_to_h_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
-            if lora_dense_4h_to_h_adapter:
+            if lora_dense_4h_to_h_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']:
                 lora_output = lora_dense_4h_to_h_adapter(intermediate_parallel)
                 output = output + lora_output
         return output, output_bias
@@ -551,7 +551,7 @@ def forward(
             ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER)
             v = ptuning_adapter.virtual_tokens
             if (
-                ptuning_adapter and _sq >= v
+                ptuning_adapter and self.adapter_cfg[AdapterName.PTUNING_ADAPTER]['enabled'] and _sq >= v
             ):  # The sequence should be longer the v to insert virtual embeddings.
                 virtual_embeddings = ptuning_adapter(_bs)
                 enc_input = enc_input[
4 changes: 2 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -538,7 +538,7 @@ def forward(

         if self.is_adapter_available():
             adapter_1 = self.get_adapter_module(AdapterName.PRE_ATTN_ADAPTER)
-            if adapter_1:
+            if adapter_1 and self.adapter_cfg[AdapterName.PRE_ATTN_ADAPTER]['enabled']:
                 attention_output = (
                     adapter_1(attention_output) + attention_output
                 ) # simple adapter call with residual connection
@@ -615,7 +615,7 @@ def forward(
         if self.is_adapter_available():
             # TODO: (@adithyre) was able to move adapter_2 back to the end of the transformer after ptl 1.7 update.
             adapter_2 = self.get_adapter_module(AdapterName.POST_ATTN_ADAPTER)
-            if adapter_2:
+            if adapter_2 and self.adapter_cfg[AdapterName.POST_ATTN_ADAPTER]['enabled']:
                 mlp_output = adapter_2(mlp_output) + mlp_output # simple adapter call with residual connection

         residual = layernorm_input
71 changes: 57 additions & 14 deletions scripts/nlp_language_modeling/merge_lora_weights/merge.py
@@ -91,7 +91,8 @@ def load_lora(lora_nemo, tp):
 def fix_for_O2(state_dict):
     new_state_dict = {}
     for k, v in state_dict.items():
-        new_state_dict[k.replace('model.language_model', 'model.module.language_model')] = v
+        if "model.module." not in k:
+            new_state_dict[k.replace('model.', 'model.module.')] = v
     return new_state_dict


@@ -110,22 +111,61 @@ def merge(
         curr_rank: current tp rank of the base model which is being merged with Lora.
         mcore: whether the model uses megatron core.
     """

-    for nl in range(num_layers):
-        if mcore:
-            key_self_attn_kqv = f'model.decoder.layers.{nl}.self_attention.linear_qkv.weight'
-            key_lora_in = f'model.decoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight'
-            key_lora_out = f'model.decoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight'
-        else:
+    mcore_layer_to_lora = {}
+    mcore_layer_to_lora["attention_qkv"] = {
+        "base_model_layer": "self_attention.linear_qkv.weight",
+        "lora_in": "self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight",
+        "lora_out": "self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight",
+    }
+    mcore_layer_to_lora["attention_dense"] = {
+        "base_model_layer": "self_attention.linear_proj.weight",
+        "lora_in": "self_attention.adapter_layer.lora_dense_attention_adapter.linear_in.weight",
+        "lora_out": "self_attention.adapter_layer.lora_dense_attention_adapter.linear_out.weight",
+    }
+    mcore_layer_to_lora["mlp_fc1"] = {
+        "base_model_layer": "mlp.linear_fc1.weight",
+        "lora_in": "mlp.adapter_layer.lora_hto4h_adapter.linear_in.weight",
+        "lora_out": "mlp.adapter_layer.lora_hto4h_adapter.linear_out.weight",
+    }
+    mcore_layer_to_lora["mlp_fc2"] = {
+        "base_model_layer": "mlp.linear_fc2.weight",
+        "lora_in": "mlp.adapter_layer.lora_4htoh_adapter.linear_in.weight",
+        "lora_out": "mlp.adapter_layer.lora_4htoh_adapter.linear_out.weight",
+    }

+    if mcore:
+        for nl in range(num_layers):
+            for key in mcore_layer_to_lora.keys():
+                key_base = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["base_model_layer"]}'
+                key_lora_in = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_in"]}'
+                key_lora_out = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_out"]}'
+                if key_lora_in in lora_state_dict[0] and key_lora_out in lora_state_dict[0]:
+                    if key in ["attention_qkv", 'mlp_fc1']:
+                        wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
+                    else:
+                        wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=1).float()

+                    wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
+                    wt_base = base_model_state_dict[key_base]
+                    wt_lora = wt_lora_out @ wt_lora_in
+                    base_model_state_dict[key_base] = (wt_base.float() + wt_lora.to(wt_base.device)).type_as(wt_base)
+                    print(f'merging for weight {key_base}')
+    else:
+        logging.warning("Non-mcore model only supports merging lora weights for attention_qkv layers")
+        for nl in range(num_layers):
             key_self_attn_kqv = f'model.language_model.encoder.layers.{nl}.self_attention.query_key_value.weight'
             key_lora_in = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight'
             key_lora_out = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight'
-        wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0)
-        wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0)
-        wt_self_attn = base_model_state_dict[key_self_attn_kqv]
-        wt_lora = wt_lora_out @ wt_lora_in
-        base_model_state_dict[key_self_attn_kqv] = wt_self_attn + wt_lora.type_as(wt_self_attn)
-        print("merging for weight", key_self_attn_kqv)

+            wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
+            wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
+            wt_self_attn = base_model_state_dict[key_self_attn_kqv]
+            wt_lora = wt_lora_out @ wt_lora_in
+            base_model_state_dict[key_self_attn_kqv] = (
+                wt_self_attn.float() + wt_lora.to(wt_self_attn.device)
+            ).type_as(wt_self_attn)
+            print("merging for weight", key_self_attn_kqv)

     return base_model_state_dict


@@ -214,6 +254,9 @@ def main(cfg) -> None:
     # load the merged_weights back into the base model, for this current rank.
     if model.cfg.megatron_amp_O2:
         merged_weights = fix_for_O2(merged_weights)
+    model.cfg.use_cpu_initialization = (
+        False  # set it back to False otherwise the merged model won't be loaded properly for futher tuning
+    )
     model.load_state_dict(merged_weights)

     if cfg.trainer.accelerator != 'cpu' and model.global_rank == 0:
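For reference, the per-weight merge that merge.py performs is the standard LoRA fold: W_merged = W + B @ A, where A is the adapter's linear_in weight and B its linear_out weight. As the diff above shows, the matmul and the add now happen in fp32 before casting back to the base weight's dtype ("merge in fp32 then cast back"), and under tensor parallelism the per-rank LoRA shards are concatenated first (linear_in along dim 0 for attention_qkv and mlp_fc1, along dim 1 for attention_dense and mlp_fc2; linear_out always along dim 0). A self-contained sketch of the arithmetic with toy shapes, not the script itself:

import torch

# Toy shapes; real weights come from the base .nemo checkpoint and the LoRA state dicts.
hidden, rank = 8, 2
w_base = torch.randn(hidden, hidden, dtype=torch.bfloat16)  # frozen base weight
lora_in = torch.randn(rank, hidden)                          # A (linear_in), fp32
lora_out = torch.randn(hidden, rank)                         # B (linear_out), fp32

w_lora = lora_out @ lora_in                                  # full-size update, computed in fp32
w_merged = (w_base.float() + w_lora).type_as(w_base)         # merge in fp32, cast back to bf16
assert w_merged.shape == w_base.shape and w_merged.dtype == w_base.dtype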

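One more note on the merge script: fix_for_O2 now rewrites 'model.' to 'model.module.' for every key that is not already wrapped, instead of only rewriting 'model.language_model' keys, presumably so that megatron_amp_O2 loading also works for mcore models whose keys start with 'model.decoder'. A small illustration of the behavior (the key below is hypothetical, not taken from a real checkpoint):

def fix_for_O2(state_dict):
    # Same logic as the updated script: add the megatron_amp_O2 'module' prefix,
    # skipping keys that are already wrapped.
    new_state_dict = {}
    for k, v in state_dict.items():
        if "model.module." not in k:
            new_state_dict[k.replace('model.', 'model.module.')] = v
    return new_state_dict

print(fix_for_O2({'model.decoder.layers.0.self_attention.linear_qkv.weight': 0}))
# {'model.module.decoder.layers.0.self_attention.linear_qkv.weight': 0}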