diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index ca71dca92a..4ec641d6c1 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1671,7 +1671,10 @@ def _infer_library_from_model( if library_name is not None: return library_name - if ( + # SentenceTransformer models have no config attributes + if hasattr(model, "_model_config"): + library_name = "sentence_transformers" + elif ( hasattr(model, "pretrained_cfg") or hasattr(model.config, "pretrained_cfg") or hasattr(model.config, "architecture") @@ -1679,8 +1682,6 @@ def _infer_library_from_model( library_name = "timm" elif hasattr(model.config, "_diffusers_version") or getattr(model, "config_name", "") == "model_index.json": library_name = "diffusers" - elif hasattr(model, "_model_config"): - library_name = "sentence_transformers" else: library_name = "transformers" return library_name @@ -1905,7 +1906,6 @@ def get_model_from_task( model_class = TasksManager.get_model_class_for_task( task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name ) - if library_name == "timm": model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True) model = model.to(torch_dtype).to(device)