LlamaRMSNorm() Dtype Casting Error #30236

Closed
Ritz111 opened this issue Apr 13, 2024 · 4 comments

Comments


Ritz111 commented Apr 13, 2024

System Info

transformers==4.37.2

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

@ArthurZucker @younesbelkada
Hi~ I found a bug in LlamaRMSNorm(nn.Module) (lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py):

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # upcast to float32 for a numerically stable variance computation
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # cast back to the input dtype before scaling by the learned weight
        return self.weight * hidden_states.to(input_dtype)

On the last line, if input_dtype is bfloat16, the returned tensor will still be float32, because self.weight is initialized as float32. So the last line should be changed to:

return (self.weight * hidden_states).to(input_dtype)
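
A minimal sketch, assuming self.weight stays in its default float32 while the input is bfloat16, showing what each expression returns:

import torch
import torch.nn as nn

hidden_states = torch.randn(2, 4, dtype=torch.bfloat16)
weight = nn.Parameter(torch.ones(4))       # float32 by default, as in LlamaRMSNorm.__init__
input_dtype = hidden_states.dtype

h = hidden_states.to(torch.float32)
variance = h.pow(2).mean(-1, keepdim=True)
h = h * torch.rsqrt(variance + 1e-6)

current = weight * h.to(input_dtype)       # fp32 weight * bf16 tensor -> promoted to fp32
proposed = (weight * h).to(input_dtype)    # multiply in fp32, then cast back to bf16

print(current.dtype)   # torch.float32
print(proposed.dtype)  # torch.bfloat16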

Expected behavior

See above. Looking forward to your reply~ Thank you!

@younesbelkada
Contributor

Hi @Ritz111
Thanks! I think this is not a bug; see #23535 for more details.

@GuWei007

Why should LlamaRMSNorm do hidden_states = hidden_states.to(torch.float32)? Why not follow the type promotion rules of PyTorch ops?


GuWei007 commented May 13, 2024

self.weight is bf16, hidden_states is fp32.
I found that the dtypes in these two methods are different.
Method 1:
return (self.weight * hidden_states).to(input_dtype)  # (bf16 * fp32).to(input_dtype)
Method 2:
return self.weight * hidden_states.to(input_dtype)    # bf16 * bf16
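
For illustration, a minimal sketch of the promotion behavior, assuming self.weight has been cast to bfloat16 and hidden_states has already been upcast to float32 inside forward:

import torch

weight = torch.ones(4, dtype=torch.bfloat16)             # e.g. after model.to(torch.bfloat16)
hidden_states = torch.randn(2, 4, dtype=torch.float32)   # upcast inside forward
input_dtype = torch.bfloat16                             # dtype of the original input

method1 = (weight * hidden_states).to(input_dtype)  # bf16 * fp32 promotes to fp32, then cast to bf16
method2 = weight * hidden_states.to(input_dtype)    # bf16 * bf16 stays in bf16

print((weight * hidden_states).dtype)  # torch.float32 (PyTorch type promotion)
print(method1.dtype, method2.dtype)    # torch.bfloat16 torch.bfloat16

Both expressions end up in bfloat16 here, but Method 1 performs the weight multiplication in float32 before downcasting, while Method 2 performs it in bfloat16, so the values can differ slightly.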


github-actions bot commented Jun 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
