Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add alpha scaling to lora #8248

Merged
13 commits merged on
Feb 25, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ model:

lora_tuning:
adapter_dim: 32
alpha: ${model.peft.lora_tuning.adapter_dim}
adapter_dropout: 0.0
column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
gather_output: bool = True,
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
**kwargs,
):
super().__init__()
Expand All @@ -142,6 +143,7 @@ def __init__(
self.activation = activation_registry[activation]()
self.norm_position = norm_position
self.dim = dim
self.alpha = alpha if alpha is not None else self.dim

# megatron_gpt_peft_models provides this arg, but deprecated models do not.
# If this arg is not provided, fall back to the dummy default config.
Expand Down Expand Up @@ -235,6 +237,8 @@ def forward(self, x):
if self.dropout is not None:
x = self.dropout(x)

x = x * (self.alpha / self.dim)

return x


Expand All @@ -250,6 +254,7 @@ class ParallelLinearAdapterConfig(AdapterConfig):
row_init_method: str = 'zero'
gather_output: bool = True
dropout: float = 0.0
alpha: float | None = None
network_alpha: int | None = None
_target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)

Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/nlp/parts/peft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, cfg):
if num_query_groups is None:
num_query_groups = cfg.num_attention_heads
qkv_projection_size = projection_size + (2 * kv_channels * num_query_groups)
alpha = lora_cfg.get("alpha", lora_cfg.adapter_dim)

config_args = {
"in_features": cfg.hidden_size,
Expand All @@ -86,6 +87,7 @@ def __init__(self, cfg):
"row_init_method": lora_cfg.get("row_init_method", "zero"),
"gather_output": False,
"dropout": lora_cfg.adapter_dropout,
"alpha": alpha,
}

if lora_cfg.weight_tying:
Expand Down
Loading