Issue with eos_token_id in LanguageModel.generate() when using local models #177

Open
zzn-nzz opened this issue Jul 20, 2024 · 2 comments
zzn-nzz commented Jul 20, 2024

Description

When running the following code, which calls the generate method with different models (e.g., Mistral-7B-Instruct-v0.2 and meta-llama-3-8B):

from transformers import AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel

model_path = "Mistral-7B-Instruct-v0.2"
base_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = LanguageModel(base_model, tokenizer=tokenizer)
llm.generate(prompt="Hi", max_new_tokens=1000)

I encountered the following warning, and the generation process did not complete:

Setting `pad_token_id` to `eos_token_id`:FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64) for open-end generation.
`eos_token_id` should consist of positive integers, but is FakeTensor(..., device='cuda:0', size=(1,), dtype=torch.int64). Your generation will not stop until the maximum length is reached. Depending on other flags, it may even crash.

Execution then ends in failure (see the full traceback below).

This issue occurs with all models I've tried, including Mistral-7B-Instruct-v0.2, meta-llama-3-8B, and gemma-2-9b.

nnsight<0.2 works, but it is not compatible with the newest transformers library.

Environment:

torch==2.3.1
nnsight==0.2.19
transformers==4.42.4
Python 3.8.19

Reproducibility:

The issue occurs consistently every time the code is run with the mentioned models.

Full Traceback:

Setting `pad_token_id` to `eos_token_id`:FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64) for open-end generation.
`eos_token_id` should consist of positive integers, but is FakeTensor(..., device='cuda:0', size=(1,), dtype=torch.int64). Your generation will not stop until the maximum length is reached. Depending on other flags, it may even crash.
Traceback (most recent call last):
  File "play.py", line 82, in <module>
    response = get_response(now_agent, oppo_agent, goals, history, i, save_reps=True)
  File "play.py", line 46, in get_response
    resp = llm.generate_response(prompt)
  File "/mnt/cvda/zhangzhining/new_desktop/ibs/LM_hf.py", line 30, in generate_response
    pass
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/nnsight/contexts/Runner.py", line 40, in __exit__
    raise exc_val
  File "/mnt/cvda/zhangzhining/new_desktop/ibs/LM_hf.py", line 29, in generate_response
    with generator.invoke(prompt) as invoker:
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/nnsight/contexts/Invoker.py", line 67, in __enter__
    self.tracer._model._execute(
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/nnsight/models/mixins/Generation.py", line 19, in _execute
    return self._execute_generate(prepared_inputs, *args, **kwargs)
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/nnsight/models/LanguageModel.py", line 288, in _execute_generate
    output = self._model.generate(
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/transformers/generation/utils.py", line 2648, in _sample
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/transformers/models/mistral/modeling_mistral.py", line 1273, in prepare_inputs_for_generation
    input_ids = input_ids[:, past_length:]
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1780, in __torch_function__
    return func(*args, **kwargs)
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/torch/fx/experimental/sym_node.py", line 352, in guard_int
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/torch/fx/experimental/recording.py", line 231, in wrapper
    return fn(*args, **kwargs)
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4138, in evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File "/DATA/disk1/zhangzhining/anaconda3/envs/LLM/lib/python3.8/site-packages/transformers/models/mistral/modeling_mistral.py", line 1273, in prepare_inputs_for_generation
    input_ids = input_ids[:, past_length:]

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

Please let me know if you need any more information to reproduce or diagnose this issue.

@OmarMohammed88

@zzn-nzz Have you found a solution to this issue?

@JadenFiotto-Kaufman (Member)

@OmarMohammed88 Does this work?

from transformers import AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel

model_path = "Mistral-7B-Instruct-v0.2"
base_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# dispatch=True: run against the real, already-loaded model rather than nnsight's meta copy
llm = LanguageModel(base_model, tokenizer=tokenizer, dispatch=True)

# scan=False / validate=False skip nnsight's fake-tensor scanning and validation passes
with llm.generate("Hi", max_new_tokens=1000, scan=False, validate=False):
    output = llm.generator.output.save()
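If that runs, the saved proxy should hold the generated token ids once the context exits (via `.value` in nnsight 0.2.x, if I remember that API correctly). A rough sketch for decoding them, assuming the generator output is a batch of token id sequences:

# Sketch only: read the saved result after the trace finishes and decode it
generated_ids = output.value
print(llm.tokenizer.decode(generated_ids[0], skip_special_tokens=True))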
