Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning message for beta and gamma parameters #31654

Merged
merged 5 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."


if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
Expand Down Expand Up @@ -662,8 +664,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
for key in state_dict.keys():
new_key = None
if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
new_key = key.replace("gamma", "weight")
if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
Expand Down Expand Up @@ -807,8 +811,10 @@ def _load_state_dict_into_meta_model(
for key in state_dict.keys():
new_key = None
if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
new_key = key.replace("gamma", "weight")
if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
Expand Down
51 changes: 51 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,57 @@ def test_model_from_pretrained_from_mlx(self):
outputs_from_saved = new_model(input_ids)
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))

def test_warning_for_beta_gamma_parameters(self):
class TestModelGamma(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.gamma_param = nn.Parameter(torch.ones(10))
self.post_init()

def forward(self):
return self.gamma_param.sum()

logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
model = TestModelGamma(config)
Comment on lines +1514 to +1515
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More importantly, we should check that the parameter is renamed as well

Copy link
Contributor Author

@OmarManzoor OmarManzoor Jul 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this out and it seems that the parameter is not renamed at all. Basically when we load the model using from_pretrained it seems that the parameter is still present with the name gamma_param.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't rename the value in the model, but will rename the value in the state_dict, I believe. Could you dive into the loading logic and verify what's happening?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried updating the tests. Could you kindly have a look?


with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)

missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn(warning_msg_gamma, cl1.out)
self.assertIn("gamma_param", missing_keys)
self.assertIn("weight_param", unexpected_keys)

class TestModelBeta(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.beta_param = nn.Parameter(torch.ones(10))
self.post_init()

def forward(self):
return self.beta_param.sum()

warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
model = TestModelBeta(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)

missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn(warning_msg_beta, cl2.out)
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)


@slow
@require_torch
Expand Down