-
Notifications
You must be signed in to change notification settings - Fork 455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add text2text-generation-with-past test for encoder-decoder model #1338
Add text2text-generation-with-past test for encoder-decoder model #1338
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few nits.
I do not understand everyting, but looks good to me from what I could read.
Thanks @mht-sharma !
for model_id in model_ids: | ||
if ( | ||
model_arch == "encoder-decoder" | ||
and use_cache is True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and use_cache is True | |
and use_cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this is the common terminology used in the repo, to explicitly check the value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's me who advocated to use that :D But more and more I agree with Michael, we should not explicitly check is True
when the type is bool
, we should only for Optional[bool]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this explicit check is something which is also followed in the transformers, I remember Lewis mentioning this, but yeah I agree
def _get_model_ids(self, model_arch): | ||
model_ids = MODEL_NAMES[model_arch] | ||
if isinstance(model_ids, dict): | ||
model_ids = list(model_ids.keys()) | ||
else: | ||
model_ids = [model_ids] | ||
return model_ids |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So what you call model_arch
is a "model type" like bert
, t5
and so on?
This function returns a list of model ids from the Hub for a given architecture?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a helper function for the test to get the list of model_ids
for a given model_arch(model_type)
from MODEL_NAMES
. Earlier there was one model_id
for a model_arch
. But with this PR there can be multiple (Format is same as used in onnx export).
for model_id in model_ids: | ||
if ( | ||
model_arch == "encoder-decoder" | ||
and use_cache is True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit
and use_cache is True | |
and use_cache |
for model_id in model_ids: | ||
if ( | ||
model_arch == "encoder-decoder" | ||
and use_cache is True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks for adding the tests! It is fine to me as long as they pass.
Merging as the test error are not related to PR |
What does this PR do?
Creating a test for an
text2text-generation-with-past
tasks for an encoder-decoder model #851--with-past
with the modelBefore submitting