Add warning message for beta and gamma parameters (huggingface#31654)
* Add warning message for beta and gamma parameters

* Fix when the warning is raised

* Formatting changes

* Improve testing and remove duplicated warning from _fix_key
OmarManzoor authored and MHRDYN7 committed Jul 23, 2024
1 parent 6d252c0 commit 4e52922
Showing 2 changed files with 57 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
@@ -104,6 +104,8 @@

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."


if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
@@ -662,8 +664,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
@@ -807,8 +811,10 @@ def _load_state_dict_into_meta_model(
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
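For a concrete feel for what these two hunks do, here is a minimal standalone sketch of the renaming loop. The toy state_dict keys and the plain logging setup are assumptions made for this example; the message template and the gamma-to-weight / beta-to-bias replacements mirror the diff above. This is an illustration, not the library's loading code.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo.param_renaming")

PARAM_RENAME_WARNING = (
    "A parameter name that contains `{}` will be renamed internally to `{}`. "
    "Please use a different name to suppress this warning."
)

# Toy checkpoint with TensorFlow-style LayerNorm parameter names (example data, not a real checkpoint).
state_dict = {
    "encoder.layer_norm.gamma": [1.0],
    "encoder.layer_norm.beta": [0.0],
    "encoder.dense.weight": [2.0],
}

renamed = {}
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
        new_key = key.replace("gamma", "weight")
    if "beta" in key:
        logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
        new_key = key.replace("beta", "bias")
    renamed[new_key or key] = state_dict[key]

print(sorted(renamed))
# ['encoder.dense.weight', 'encoder.layer_norm.bias', 'encoder.layer_norm.weight']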
51 changes: 51 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -1511,6 +1511,57 @@ def test_model_from_pretrained_from_mlx(self):
        outputs_from_saved = new_model(input_ids)
        self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))

    def test_warning_for_beta_gamma_parameters(self):
        class TestModelGamma(PreTrainedModel):
            def __init__(self, config):
                super().__init__(config)
                self.gamma_param = nn.Parameter(torch.ones(10))
                self.post_init()

            def forward(self):
                return self.gamma_param.sum()

        logger = logging.get_logger("transformers.modeling_utils")
        config = PretrainedConfig()
        warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
        model = TestModelGamma(config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl1:
                    _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)

            missing_keys = loading_info["missing_keys"]
            unexpected_keys = loading_info["unexpected_keys"]
            self.assertIn(warning_msg_gamma, cl1.out)
            self.assertIn("gamma_param", missing_keys)
            self.assertIn("weight_param", unexpected_keys)

        class TestModelBeta(PreTrainedModel):
            def __init__(self, config):
                super().__init__(config)
                self.beta_param = nn.Parameter(torch.ones(10))
                self.post_init()

            def forward(self):
                return self.beta_param.sum()

        warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
        model = TestModelBeta(config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            with LoggingLevel(logging.WARNING):
                with CaptureLogger(logger) as cl2:
                    _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)

            missing_keys = loading_info["missing_keys"]
            unexpected_keys = loading_info["unexpected_keys"]
            self.assertIn(warning_msg_beta, cl2.out)
            self.assertIn("beta_param", missing_keys)
            self.assertIn("bias_param", unexpected_keys)


@slow
@require_torch
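Why the test expects gamma_param under missing keys and weight_param under unexpected keys: the checkpoint stores the key gamma_param, loading renames it to weight_param, and that renamed key no longer matches the module's own parameter name. A rough sketch of that bookkeeping follows, using the toy key names from the test above; this is not the library's actual loading code.

# Keys the module defines vs. keys read back from the checkpoint (toy names from the test above).
model_keys = {"gamma_param"}
checkpoint_keys = {"gamma_param"}

# The loading code renames gamma -> weight before matching checkpoint keys against the model.
loaded_keys = {key.replace("gamma", "weight") for key in checkpoint_keys}

missing_keys = sorted(model_keys - loaded_keys)      # ['gamma_param']
unexpected_keys = sorted(loaded_keys - model_keys)   # ['weight_param']
print(missing_keys, unexpected_keys)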
