Directly use the backbone functional graph for CausalLM.generate() #1862

Open
mattdangerw opened this issue Sep 22, 2024 · 0 comments
Assignees: mattdangerw
Labels: Gemma (Gemma model specific issues), team-created (Issues created by Keras Hub team as part of development roadmap.)

Comments

@mattdangerw
Member

Currently, to support the extra inputs we need for generation (e.g. cache, index, encoder hidden states for seq2seq), we call layers from our backbone class directly, disregarding the functional graph and layer connectivity of the backbone. See call_with_cache. If we could use the backbone graph directly for generation, we would support many more advanced generative use cases.
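
The mismatch can be shown with a small, framework-free sketch. Everything in it is hypothetical: ToyBackbone, ScaledBackbone, ToyCausalLM and decode_outputs are illustrative names, not KerasHub classes, and the "attention" is just a running sum so that the cache is a single float. The only idea carried over from the issue is that the cached path (call_with_cache) re-implements the backbone by calling its sublayers instead of calling the backbone itself.

```python
from typing import List, Tuple


class ToyBackbone:
    """Hypothetical stand-in for a backbone: its forward pass is embed -> block -> norm."""

    def embed(self, token_ids: List[int]) -> List[float]:
        return [float(t) for t in token_ids]

    def block(self, x: List[float]) -> List[float]:
        # Toy causal "attention": position i sees the running sum of positions 0..i.
        out, total = [], 0.0
        for v in x:
            total += v
            out.append(total)
        return out

    def norm(self, x: List[float]) -> List[float]:
        return x  # identity, kept only to show a third sublayer

    def __call__(self, token_ids: List[int]) -> List[float]:
        # The "functional graph": how the sublayers are wired together.
        return self.norm(self.block(self.embed(token_ids)))


class ScaledBackbone(ToyBackbone):
    """A user-customized backbone whose graph doubles the final output."""

    def __call__(self, token_ids: List[int]) -> List[float]:
        return [2.0 * v for v in super().__call__(token_ids)]


class ToyCausalLM:
    def __init__(self, backbone: ToyBackbone):
        self.backbone = backbone

    def call_with_cache(self, token_id: int, cache: float) -> Tuple[float, float]:
        # Re-implements the backbone's forward pass by reaching into its
        # sublayers, so any change to backbone.__call__ is silently ignored.
        x = self.backbone.embed([token_id])[0]
        cache = cache + x  # incremental form of the running-sum block
        return self.backbone.norm([cache])[0], cache

    def decode_outputs(self, token_ids: List[int]) -> List[float]:
        # Token-by-token cached decoding, as a generate() loop would do.
        outputs, cache = [], 0.0
        for t in token_ids:
            y, cache = self.call_with_cache(t, cache)
            outputs.append(y)
        return outputs


lm = ToyCausalLM(ScaledBackbone())
print(lm.backbone([1, 2, 3]))        # [2.0, 6.0, 12.0]  (customized graph)
print(lm.decode_outputs([1, 2, 3]))  # [1.0, 3.0, 6.0]   (customization lost)
```

With an unmodified ToyBackbone the two paths agree; once the backbone's graph is customized, the cached path silently diverges from it.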

Keras recently added support for optional functional inputs. We should build on that by adding several optional inputs to our backbones (e.g. cache, cache_index, token_positions, attention_mask). This would allow customization in several directions:

  • Backbones would be more readily useful for advanced non-generative use cases, without needing to reach into sublayers.
  • Generation would be more easily customizable by passing a modified backbone to a CausalLM.
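
A framework-free sketch of the proposed direction, under heavy simplifying assumptions: the backbone's single forward pass takes the cache as an optional input (standing in for the optional functional inputs mentioned above), returns the updated cache, and generation simply calls that one graph. The names ToyBackbone, ScaledBackbone and ToyCausalLM.run are hypothetical, only a cache input is modeled (not cache_index, token_positions or attention_mask), and the running-sum "attention" is a placeholder for real attention.

```python
from typing import List, Optional, Tuple


class ToyBackbone:
    """Hypothetical backbone whose one forward pass takes an optional cache input."""

    def embed(self, token_ids: List[int]) -> List[float]:
        return [float(t) for t in token_ids]

    def block(self, x: List[float], cache: Optional[float]) -> Tuple[List[float], float]:
        # Toy causal "attention" (running sum); resumes from the cache when given.
        total = 0.0 if cache is None else cache
        out = []
        for v in x:
            total += v
            out.append(total)
        return out, total

    def __call__(
        self, token_ids: List[int], cache: Optional[float] = None
    ) -> Tuple[List[float], float]:
        # Same graph for both uses: cache=None is a plain full-sequence pass,
        # a cache value turns it into an incremental decoding step.
        return self.block(self.embed(token_ids), cache)


class ScaledBackbone(ToyBackbone):
    """A user-customized backbone whose graph doubles the output."""

    def __call__(
        self, token_ids: List[int], cache: Optional[float] = None
    ) -> Tuple[List[float], float]:
        outputs, new_cache = super().__call__(token_ids, cache)
        return [2.0 * v for v in outputs], new_cache


class ToyCausalLM:
    def __init__(self, backbone: ToyBackbone):
        self.backbone = backbone

    def run(self, prompt: List[int], new_tokens: List[int]) -> List[float]:
        # Prefill: one full pass over the prompt, no cache input.
        outputs, cache = self.backbone(prompt)
        # Decode: feed each token plus the cache back through the same graph.
        for t in new_tokens:
            step, cache = self.backbone([t], cache=cache)
            outputs = outputs + step
        return outputs


lm = ToyCausalLM(ScaledBackbone())
full, _ = lm.backbone([1, 2, 3, 4])
print(full)                    # [2.0, 6.0, 12.0, 20.0]
print(lm.run([1, 2], [3, 4]))  # [2.0, 6.0, 12.0, 20.0]
```

Because generation goes through the backbone's own forward pass, the customized graph is respected in both the prefill and decode steps, and the backbone can still be called directly, with the cache omitted, for non-generative use.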

@mattdangerw self-assigned this Sep 22, 2024
@github-actions bot added the Gemma (Gemma model specific issues) label Sep 22, 2024
@mattdangerw changed the title "Directly use the functional graph for generative forward passes" to "Directly use the backbone functional graph for CausalLM.generate()" Sep 22, 2024
@sachinprasadhs added the team-created (Issues created by Keras Hub team as part of development roadmap.) label Nov 8, 2024