
truncation cause error while using bert_score #2673

Closed
zhoubay opened this issue Aug 3, 2024 · 4 comments · Fixed by #2776

Comments

zhoubay commented Aug 3, 2024

While using bert_score, I'm trying to use a model that is restricted to 512 tokens, i.e. BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext. But there's an error implying I'm passing a tensor larger than 512 tokens, even though I set max_length to something lower than 512, namely 500.

  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torchmetrics/metric.py", line 236, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torchmetrics/metric.py", line 303, in _forward_reduce_state_update
    batch_val = self.compute()
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torchmetrics/metric.py", line 532, in wrapped_func
    value = compute(*args, **kwargs)
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torchmetrics/text/bert.py", line 221, in compute
    return bert_score(
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torchmetrics/functional/text/bert.py", line 414, in bert_score
    preds_embeddings, preds_idf_scale = _get_embeddings_and_idf_scale(
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torchmetrics/functional/text/bert.py", line 98, in _get_embeddings_and_idf_scale
    out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True)
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/pl/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 979, in forward
    buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
RuntimeError: The expanded size of the tensor (523) must match the existing size (512) at non-singleton dimension 1.  Target sizes: [1, 523].  Tensor sizes: [1, 512]

The snippet is here:

pred = ["abc "*2000]
gt = ["def "*2000]
bert_score = BERTScore(model_name_or_path="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext", device="cuda:0", max_length=500, lang="en")
result = bert_score(pred, gt)

After checking the code in detail, the culprit is the following line, which sets truncation to False.

truncation=False,

May I ask why this is used? Or could there be a flag that controls whether to truncate or not?
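For reference, a workaround that avoids the crash is to truncate the texts yourself before scoring. This is a minimal sketch assuming the Hugging Face transformers tokenizer API; the truncate helper below is my own, not part of torchmetrics:

from transformers import AutoTokenizer
from torchmetrics.text import BERTScore

model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def truncate(texts, max_length=500):
    # Round-trip through the tokenizer so every text fits in max_length tokens.
    ids = tokenizer(texts, truncation=True, max_length=max_length)["input_ids"]
    return tokenizer.batch_decode(ids, skip_special_tokens=True)

pred = truncate(["abc " * 2000])
gt = truncate(["def " * 2000])
bert_score = BERTScore(model_name_or_path=model_name, max_length=500, lang="en")
result = bert_score(pred, gt)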

github-actions bot commented Aug 3, 2024

Hi! Thanks for your contribution, great first issue!

@Borda changed the title truncation cause error while using bert_score Aug 5, 2024
Borda (Member) commented Aug 5, 2024

May I ask why this is used? Or could there be a flag that controls whether to truncate or not?

That is a good question, @stancld?
From my perspective, we can make it an argument... @zhoubay, could you please send a PR with a proposed implementation?

In the context of large language models (LLMs), truncation refers to shortening the input text so it fits within the model's maximum token limit. Language models like GPT-4 have a maximum number of tokens (words, subwords, or characters) they can process in a single input. If the input text exceeds this limit, truncation is applied so the input does not surpass the model's capacity.
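For example, with a Hugging Face tokenizer (a minimal illustration; exact token counts will vary by vocabulary):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
long_text = "abc " * 2000  # far longer than a 512-token limit

# Without truncation the tokenizer emits more tokens than the model accepts;
# with truncation=True the sequence is cut down to max_length.
full = tokenizer(long_text)["input_ids"]
cut = tokenizer(long_text, truncation=True, max_length=512)["input_ids"]
print(len(full), len(cut))  # e.g. 2002 and 512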

zhoubay (Author) commented Aug 6, 2024

May I ask why this is used? Or could there be a flag that controls whether to truncate or not?

That is a good question, @stancld? From my perspective, we can make it an argument... @zhoubay, could you please send a PR with a proposed implementation?

In the context of large language models (LLMs), truncation refers to shortening the input text so it fits within the model's maximum token limit. Language models like GPT-4 have a maximum number of tokens (words, subwords, or characters) they can process in a single input. If the input text exceeds this limit, truncation is applied so the input does not surpass the model's capacity.

Sure, my pleasure! I'll get down to it as soon as possible!
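A rough sketch of the API I have in mind (hypothetical; the final signature in the PR may differ):

from torchmetrics.text import BERTScore

bert_score = BERTScore(
    model_name_or_path="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
    max_length=500,
    truncation=True,  # proposed flag, forwarded to the tokenizer instead of the hard-coded False
    lang="en",
)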

@Borda added the bug / fix (Something isn't working) and question (Further information is requested) labels Aug 21, 2024
SkafteNicki (Member) commented

@zhoubay how is it going here?
