From 59e3426407cf888025e8d8ab4521d198b595f475 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Mon, 21 Mar 2022 12:48:35 +0100 Subject: [PATCH] Allow `source=None` to search all source in `load_adapter()` --- src/transformers/adapters/model_mixin.py | 6 +++--- src/transformers/adapters/utils.py | 25 +++++++++++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 1ae4737e47..91a2ba0cfa 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -444,7 +444,7 @@ def load_adapter( version: str = None, model_name: str = None, load_as: str = None, - source: str = "ah", + source: str = None, custom_weights_loaders: Optional[List[WeightsLoader]] = None, leave_out: Optional[List[int]] = None, id2label=None, @@ -471,7 +471,7 @@ def load_adapter( - "ah" (default): search on AdapterHub. - "hf": search on HuggingFace model hub. - - None: only search on local file system + - None: search on all sources leave_out: Dynamically drop adapter modules in the specified Transformer layers when loading the adapter. set_active (bool, optional): Set the loaded adapter to be the active one. By default (False), the adapter is loaded but not activated. @@ -876,7 +876,7 @@ def load_adapter( version: str = None, model_name: str = None, load_as: str = None, - source: str = "ah", + source: str = None, with_head: bool = True, custom_weights_loaders: Optional[List[WeightsLoader]] = None, leave_out: Optional[List[int]] = None, diff --git a/src/transformers/adapters/utils.py b/src/transformers/adapters/utils.py index 2df3a66009..4a44728672 100644 --- a/src/transformers/adapters/utils.py +++ b/src/transformers/adapters/utils.py @@ -422,7 +422,7 @@ def resolve_adapter_path( model_name: str = None, adapter_config: Union[dict, str] = None, version: str = None, - source: str = "ah", + source: str = None, **kwargs ) -> str: """ @@ -438,6 +438,11 @@ def resolve_adapter_path( model_name (str, optional): The identifier of the pre-trained model for which to load an adapter. adapter_config (Union[dict, str], optional): The configuration of the adapter to be loaded. version (str, optional): The version of the adapter to be loaded. Defaults to None. + source (str, optional): Identifier of the source(s) from where to get adapters. Can be either: + + - "ah": search on AdapterHub.ml. + - "hf": search on HuggingFace model hub (huggingface.co). + - None (default): search on all sources Returns: str: The local path from where the adapter module can be loaded. @@ -466,6 +471,24 @@ def resolve_adapter_path( ) elif source == "hf": return pull_from_hf_model_hub(adapter_name_or_path, version=version, **kwargs) + elif source is None: + try: + logger.info("Attempting to load adapter from source 'ah'...") + return pull_from_hub( + adapter_name_or_path, model_name, adapter_config=adapter_config, version=version, **kwargs + ) + except EnvironmentError as ex: + logger.info(ex) + logger.info("Attempting to load adapter from source 'hf'...") + try: + return pull_from_hf_model_hub(adapter_name_or_path, version=version, **kwargs) + except Exception as ex: + logger.info(ex) + raise EnvironmentError( + "Unable to load adapter {} from any source. Please check the name of the adapter or the source.".format( + adapter_name_or_path + ) + ) else: raise ValueError("Unable to identify {} as a valid module location.".format(adapter_name_or_path))