Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass device in Logits Processor's init #29804

Merged
merged 14 commits into from
Jun 4, 2024

Conversation

zucchini-nlp
Copy link
Member

What does this PR do?

This PR adds the ability to pass in device when initializing LogitsProcessors and is one more step towards compile compatibility.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall notes before going to details:

  1. In the processors that take eos_token_id as input: see Generate: consistently handle special tokens as tensors #29788. In this PR, the special tokens are treated as tensors by default, solving most of the needed changes. I would rebase this PR on main after that PR is merged, as some of the changes here will become redundant :)
  2. On the processors that don't need to use device, such as TemperatureLogitsWarper -- let's not add unused arguments. Clean interfaces are important 🧼 (unless there are significant benefits from standardizing them)
  3. Let's not throw a warning when the device is not passed and tensors are initialized on CPU. A .to operation is not that expensive :)

@zucchini-nlp
Copy link
Member Author

  1. Cool, I did not notice that
    2 and 3. Okay, thought we need it for consistency like we had with other new args in public classes. Will remove it and rebase later

@zucchini-nlp
Copy link
Member Author

Not stale

@zucchini-nlp
Copy link
Member Author

This PR now can be reviewed. Rebased main and updated the changes. All the tests from RUN_SLOW=1 pytest tests/generation are passing on my end

@zucchini-nlp zucchini-nlp requested a review from gante May 9, 2024 20:22
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for improving generate :D

src/transformers/models/whisper/generation_whisper.py Outdated Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

zucchini-nlp commented May 22, 2024

@gante Ah I forgot whisper is encoder-decoder. Oke, now it infers device from one of the inputs passed by the user.

@huggingface huggingface deleted a comment from github-actions bot May 22, 2024
@ArthurZucker
Copy link
Collaborator

How could the bot come 🤣 anyways on it!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, not sure input_ids device is the always the best, and we need a small test to see which feature is enable by this potentially!

Comment on lines 146 to 148
if device is None:
device = "cpu"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue that we can just set it to "cpu" in the arg no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly for users who use/pass LogitsProcessor as a standalone kwarg, because 'generate()' takes care that device is not None.

I think we should raise warning for BC saying users to pass-in the device, but let's ask @gante if he's okay with it. If I am not misunderstanding, we shouldn't raise warnings 🤔

Let's not throw a warning when the device is not passed and tensors are initialized on CPU. A .to operation is not that expensive :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, don't think it's a problem to silently do this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Down to just default to CPU which was already the behaviour by default before this PR no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh my bad, didn't read carefully the first comment. Setting in the arg as default is better, right

My concern is that before this PR, we were placing these on scores.device during "_ [call]_ " , but anyway I still get lost at when to do BC deprecation and when to not do 😄

self.eos_token_id = self.eos_token_id.to(scores.device)

@@ -1700,7 +1737,7 @@ def generate(
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
device=input_ids.device,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this required ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right! I thought that it was me who changed to inputs_tensor and was trying to revert 😆 I'll revert it back, no difference whichever tensor we use here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be use self.device? or lm_head.device? (which is not always there but still)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to make sure dive placement on multi GPU works, might already be tested !

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll try to add a test if I can. But we can be quite sure device placement of inputs is the correct one, as discusses with @gante this PR recommends to use input's device and not model params in multiGPU setting

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it. Any how LGTM

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you rebase your branch ? (format changes seems unrelated?)

@zucchini-nlp
Copy link
Member Author

Oke, rebased main and the unnecessary formatting is removed. Will merge as I guess we don't need to add warnings :)

@zucchini-nlp zucchini-nlp merged commit 83238ee into huggingface:main Jun 4, 2024
23 checks passed
zucchini-nlp added a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
* add device in logits processor

* remove device when not needed

* codestyle

* tests

* forgot `melody` version

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Joao Gante <[email protected]>

* codestyle

* updates

---------

Co-authored-by: Joao Gante <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants