support copies #32159

Closed · wants to merge 6 commits

Conversation

@ArthurZucker (Collaborator) commented Jul 23, 2024

What does this PR do?

We can't copy the cache 😢; making it inherit from nn.Module fixes this easily. Being unable to copy the cache prevents us from re-using a prompt / system prompt like this:

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

device = "cuda"
ckpt = "path-to-ckpt"
INITIAL_PROMPT = "From now on, you are going to answer all my questions with historical details. Make sure to always add a bit of french here and there, for style."

model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(ckpt)

# Pre-fill the cache once with the system prompt.
prompt_cache = DynamicCache()
inputs = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(device)
prompt_cache = model(**inputs, past_key_values=prompt_cache).past_key_values

# Re-use a deep copy of the pre-filled cache for a first question.
prompt = "Why are french people obsessed with french?"
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(device)
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
print(response)

# Re-use another deep copy of the same cache for a second, unrelated question.
prompt = "What is the best city to swim in?"
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(device)
outputs = model.generate(**new_inputs, past_key_values=copy.deepcopy(prompt_cache), max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
print(response)

@amyeroberts (Collaborator) commented:

"We can't copy the cache 😢"

What kind of copying are we talking about here? Like cache.copy?

@gante (Member) commented Jul 23, 2024

@amyeroberts copy.deepcopy

On main, without the fix, we get

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

Cache copying is needed to reuse the cache from the prompt. E.g. to run new prompts on top of the system prompt without spending compute on the system prompt.
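As a side note on where that error comes from (not part of this PR, just a minimal illustration): PyTorch's deepcopy protocol only covers leaf tensors, and the key/value tensors written into the cache by a forward pass run with autograd enabled are non-leaf, which is exactly what the message above complains about.

import copy
import torch

# A leaf tensor (created directly by the user) can be deep-copied.
leaf = torch.randn(2, 2, requires_grad=True)
copy.deepcopy(leaf)  # fine

# A non-leaf tensor (the result of an autograd-tracked op, analogous to the
# KV tensors a forward pass writes into the cache) cannot, and raises the
# RuntimeError quoted above.
non_leaf = leaf * 2
try:
    copy.deepcopy(non_leaf)
except RuntimeError as err:
    print(err)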


@vladfaust commented:

I'm sorry if this isn't the right place to ask, but: in Llama.cpp it's trivial to save and load state to/from disk to maintain the cache between sessions. Is that currently possible with Transformers, and if so, could you please provide a minimal example or point to the docs?

Cheers,

@gante (Member) commented Aug 6, 2024

@vladfaust yes it is possible, but it requires custom code (i.e. you would need to store and restore the cache's tensors).

We will add a user-friendly API for that in the future :)
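To sketch what that custom code could look like (a rough, unofficial example; the save_cache / load_cache helpers below are just for illustration, and it assumes you are using DynamicCache, whose to_legacy_cache() / from_legacy_cache() helpers convert to and from plain tuples of key/value tensors):

import torch
from transformers import DynamicCache

def save_cache(cache: DynamicCache, path: str) -> None:
    # Convert the cache to plain (key, value) tensor tuples and serialize them.
    torch.save(cache.to_legacy_cache(), path)

def load_cache(path: str, device: str = "cuda") -> DynamicCache:
    # Rebuild a DynamicCache from the stored tensors.
    legacy = torch.load(path, map_location=device)
    return DynamicCache.from_legacy_cache(legacy)

A cache restored this way can then be passed as past_key_values, just like prompt_cache in the snippet from the PR description.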

@nirbenda left a comment:

LGTM

@ArthurZucker (Collaborator, Author) commented:

PS: this was actually already merged in #32168, so I'll close this one!
