-
Notifications
You must be signed in to change notification settings - Fork 455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformers v4.45 support #2023
Changes from 20 commits
cae68e5
69a1d2e
450a5a4
e29f466
e69658c
0cc167d
c98d5d6
9a6f601
fadadc9
9fa9e9f
bf913c2
94dee27
3bfa30e
7bf1d30
f01cccf
e206d44
3572a0b
d6e97cf
0e2ed87
bc28f03
e7d3ba4
e146328
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
|
||
import numpy as np | ||
import onnx | ||
import transformers | ||
from transformers.modeling_utils import get_parameter_dtype | ||
from transformers.utils import is_tf_available, is_torch_available | ||
|
||
|
@@ -34,6 +35,7 @@ | |
DEFAULT_DUMMY_SHAPES, | ||
ONNX_WEIGHTS_NAME, | ||
TORCH_MINIMUM_VERSION, | ||
check_if_transformers_greater, | ||
is_diffusers_available, | ||
is_torch_onnx_support_available, | ||
logging, | ||
|
@@ -999,6 +1001,10 @@ def onnx_export_from_model( | |
>>> onnx_export_from_model(model, output="gpt2_onnx/") | ||
``` | ||
""" | ||
if check_if_transformers_greater("4.44.99"): | ||
raise ImportError( | ||
f"ONNX conversion disabled for now for transformers version greater than v4.45, found {transformers.__version__}" | ||
) | ||
Comment on lines +1004 to +1007
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Disabling ONNX export for now, as additional fixes are needed (but we don't want to block the latest transformers release for the other subpackages). @michaelbenayoun @xenova — There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — OK by me; maybe specify that it's temporary. |
||
|
||
TasksManager.standardize_model_attributes(model) | ||
|
||
|
@@ -1120,6 +1126,18 @@ def onnx_export_from_model( | |
if isinstance(atol, dict): | ||
atol = atol[task.replace("-with-past", "")] | ||
|
||
if check_if_transformers_greater("4.44.99"): | ||
misplaced_generation_parameters = model.config._get_non_default_generation_parameters() | ||
if model.can_generate() and len(misplaced_generation_parameters) > 0: | ||
logger.warning( | ||
"Moving the following attributes in the config to the generation config: " | ||
f"{misplaced_generation_parameters}. You are seeing this warning because you've set " | ||
"generation parameters in the model config, as opposed to in the generation config.", | ||
) | ||
for param_name, param_value in misplaced_generation_parameters.items(): | ||
setattr(model.generation_config, param_name, param_value) | ||
setattr(model.config, param_name, None) | ||
|
||
# Saving the model config and preprocessor as this is needed sometimes. | ||
model.config.save_pretrained(output) | ||
generation_config = getattr(model, "generation_config", None) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe add docstring on how this custom implementation solves the problem.