Skip to content

Commit

Permalink
Add download redirect for AH adapters to HF (adapter-hub#704)
Browse files Browse the repository at this point in the history
This PR implements an automatic redirect for `load_adapter()` from Hub
repo adapters to the HF Hub:
- If `source` is not specified, will automatically redirect attempts to
load from Hub repo to HF Hub. Will also issue a warning to set new id
explicitly.
- if `source="ah"` explicitly, will still load from Hub repo link and
issue a warning.

Also updated docs accordingly & removed sections specific to Hub repo
loading.

---------

Co-authored-by: Timo Imhof <[email protected]>
  • Loading branch information
dainis-boumber and TimoImhof committed Aug 24, 2024
1 parent 87519c0 commit 5519a8d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 50 deletions.
48 changes: 12 additions & 36 deletions docs/loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

## Finding pre-trained adapters

**[AdapterHub.ml](https://adapterhub.ml/explore)** provides a central collection of all pre-trained adapters uploaded via [our Hub repository](https://github.com/adapter-hub/hub) or Hugging Face's [Model Hub](https://huggingface.co/models).
You can easily find pre-trained adapters for your task of interest along with all relevant information and code snippets to get started (also see below).
**[AdapterHub.ml](https://adapterhub.ml/explore)** provides a central collection of all pre-trained adapters uploaded via Hugging Face's [Model Hub](https://huggingface.co/models).
You can easily find pre-trained adapters for your task of interest along with all relevant information and code snippets to get started.

```{eval-rst}
.. note::
The original `Hub repository <https://github.com/adapter-hub/hub>`_ (via ``source="ah"``) has been archived and migrated to the HuggingFace Model Hub. The Adapters library supports automatic redirecting to the HF Model Hub when attempting to load adapters from the original Hub repository.
```

Alternatively, [`list_adapters()`](adapters.utils.list_adapters) provides a programmatical way of accessing all available pre-trained adapters.
This will return an [`AdapterInfo`](adapters.utils.AdapterInfo) object for each retrieved adapter.
Expand All @@ -12,8 +17,8 @@ E.g., we can use it to retrieve information for all adapters trained for a speci
```python
from adapters import list_adapters

# source can be "ah" (AdapterHub), "hf" (huggingface.co) or None (for both, default)
adapter_infos = list_adapters(source="ah", model_name="bert-base-uncased")
# source can be "ah" (archived Hub repo), "hf" (huggingface.co) or None (for both, default)
adapter_infos = list_adapters(source="hf", model_name="bert-base-uncased")

for adapter_info in adapter_infos:
print("Id:", adapter_info.adapter_id)
Expand Down Expand Up @@ -51,11 +56,13 @@ adapter_name = model.load_adapter('sst-2')

In the minimal case, that's everything we need to specify to load a pre-trained task adapter for sentiment analysis, trained on the `sst-2` dataset using BERT base and a suitable adapter configuration.
The name of the adapter is returned by [`load_adapter()`](adapters.ModelWithHeadsAdaptersMixin.load_adapter), so we can [activate it](adapter_composition.md) in the next step:

```python
model.set_active_adapters(adapter_name)
```

As the second example, let's have a look at how to load an adapter based on the [`AdapterInfo`](adapters.utils.AdapterInfo) returned by the [`list_adapters()`](adapters.utils.list_adapters) method from [above](#finding-pre-trained-adapters):

```python
from adapters import AutoAdapterModel, list_available_adapters

Expand Down Expand Up @@ -84,7 +91,7 @@ model.load_adapter(

We will go through the different arguments and their meaning one by one:

- The first argument passed to the method specifies the name of the adapter we want to load from Adapter-Hub. The library will search for an available adapter module with this name that matches the model architecture as well as the adapter type and configuration we requested. As the identifier `sst-2` resolves to a unique entry in the Hub, the corresponding adapter can be successfully loaded based on this information. To get an overview of all available adapter identifiers, please refer to [the Adapter-Hub website](https://adapterhub.ml/explore). The different format options of the identifier string are further described in [How adapter resolving works](#how-adapter-resolving-works).
- The first argument passed to the method specifies the name of the adapter we want to load from Adapter-Hub. The library will search for an available adapter module with this name that matches the model architecture as well as the adapter type and configuration we requested. As the identifier `sst-2` resolves to a unique entry in the Hub, the corresponding adapter can be successfully loaded based on this information. To get an overview of all available adapter identifiers, please refer to [the Adapter-Hub website](https://adapterhub.ml/explore).

- The `config` argument defines the adapter architecture the loaded adapter should have.
The value of this parameter can be either a string identifier for one of the predefined architectures, the identifier of an architecture available in the Hub or a dictionary representing a full adapter configuration.
Expand All @@ -101,34 +108,3 @@ To load the adapter using a custom name, we can use the `load_as` parameter.

- Finally the `source` parameter provides the possibility to load adapters from alternative adapter repositories.
Besides the default value `ah`, referring to AdapterHub, it's also possible to pass `hf` to [load adapters from Hugging Face's Model Hub](huggingface_hub.md).

## How adapter resolving works

As described in the previous section, the methods for loading adapters are able to resolve the correct adapter weights
based on the given identifier string, the model name and the adapter configuration.
Using this information, the `adapters` library searches for a matching entry in the index of the [Hub GitHub repo](https://github.com/adapter-hub/hub).

The identifier string used to find a matching adapter follows a format consisting of three components:
```
<task>/<subtask>@<username>
```

- `<task>`: A generic task identifier referring to a category of similar tasked (e.g. `sentiment`, `nli`)
- `<subtask>`: A dataset or domain, on which the adapter was trained (e.g. `multinli`, `wiki`)
- `<username>`: The name of the user or organization that uploaded the pre-trained adapter

An example of a full identifier following this format might look like `qa/squad1.1@example-org`.

```{eval-rst}
.. important::
In many cases, you don't have to give the full string identifier with all three components to successfully load an adapter from the Hub. You can drop the `<username>` you don't care about the uploader of the adapter. Also, if the resulting identifier is still unique, you can drop the ``<task>`` or the ``<subtask>``. So, ``qa/squad1.1``, ``squad1.1`` or ``squad1.1@example-org`` all may be valid identifiers.
```

An alternative adapter identifier format is given by:

```
@<username>/<filename>
```

where `<filename>` refers to the name of a adapter file in the [Hub repo](https://github.com/adapter-hub/hub).
In contrast to the previous three-component identifier, this identifier is guaranteed to be unique.
10 changes: 6 additions & 4 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,9 +805,11 @@ def load_adapter(
saved will be used.
source (str, optional): Identifier of the source(s) from where to load the adapter. Can be:
- "ah" (default): search on AdapterHub.
- "hf": search on HuggingFace model hub.
- None: search on all sources
- "ah": search on AdapterHub Hub repo.
Note: the Hub repo has been archived and all adapters have been moved to HuggingFace Model Hub.
Loading from this source is deprecated.
- "hf": search on HuggingFace Model Hub.
- None (default): search on all sources
leave_out: Dynamically drop adapter modules in the specified Transformer layers when loading the adapter.
set_active (bool, optional):
Set the loaded adapter to be the active one. By default (False), the adapter is loaded but not
Expand Down Expand Up @@ -1688,4 +1690,4 @@ def freeze_embeddings(self, freeze=True):
p.requires_grad = not freeze
else:
for p in self.get_output_embeddings().parameters():
p.requires_grad = not freeze
p.requires_grad = not freeze
45 changes: 39 additions & 6 deletions src/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ def pull_from_hub(
adapter_config: Optional[Union[dict, str]] = None,
version: str = None,
strict: bool = False,
redirect_to_hf_hub: bool = False,
**kwargs
) -> str:
"""
Expand All @@ -623,6 +624,9 @@ def pull_from_hub(
version (str, optional): The version of the adapter to be loaded. Defaults to None.
strict (bool, optional):
If set to True, only allow adapters exactly matching the given config to be loaded. Defaults to False.
redirect_to_hf_hub (bool, optional):
If set to True, the function will redirect to the HuggingFace Model Hub instead of AdapterHub.
Defaults to False.
Returns:
str: The local path to which the adapter has been downloaded.
Expand All @@ -636,13 +640,27 @@ def pull_from_hub(
hub_entry_url = find_in_index(specifier, model_name, adapter_config=adapter_config, strict=strict)
if not hub_entry_url:
raise EnvironmentError("No adapter with name '{}' was found in the adapter index.".format(specifier))
hub_entry = http_get_json(hub_entry_url)

hf_hub_specifier = "AdapterHub/" + os.path.basename(hub_entry_url).split(".")[0]
if redirect_to_hf_hub:
logger.warning(
"Automatic redirect to HF Model Hub repo '{}'. Please switch to the new ID to remove this warning.".format(
hf_hub_specifier
)
)
return pull_from_hf_model_hub(hf_hub_specifier, version=version, **kwargs)
else:
logger.warning(
"Loading adapters from this source is deprecated. This adapter has moved to '{}'. Please switch to the new"
" ID to remove this warning.".format(hf_hub_specifier)
)

hub_entry = http_get_json(hub_entry_url)
# set version
if not version:
version = hub_entry["default_version"]
elif version not in hub_entry["files"]:
logger.warn("Version '{}' of adapter '{}' not found. Falling back to default.".format(version, specifier))
logger.warning("Version '{}' of adapter '{}' not found. Falling back to default.".format(version, specifier))
version = hub_entry["default_version"]
file_entry = hub_entry["files"][version]

Expand Down Expand Up @@ -672,6 +690,7 @@ def resolve_adapter_path(
adapter_config: Union[dict, str] = None,
version: str = None,
source: str = None,
redirect_to_hf_hub: bool = False,
**kwargs
) -> str:
"""
Expand All @@ -689,10 +708,14 @@ def resolve_adapter_path(
version (str, optional): The version of the adapter to be loaded. Defaults to None.
source (str, optional): Identifier of the source(s) from where to get adapters. Can be either:
- "ah": search on AdapterHub.ml.
- "ah": search on AdapterHub.ml. Note: this source is deprecated in favor of "hf".
- "hf": search on HuggingFace model hub (huggingface.co).
- None (default): search on all sources
redirect_to_hf_hub (bool, optional):
If set to True, the function will redirect to the HuggingFace Model Hub instead of AdapterHub.
Defaults to False.
Returns:
str: The local path from where the adapter module can be loaded.
"""
Expand All @@ -718,7 +741,12 @@ def resolve_adapter_path(
)
elif source == "ah":
return pull_from_hub(
adapter_name_or_path, model_name, adapter_config=adapter_config, version=version, **kwargs
adapter_name_or_path,
model_name,
adapter_config=adapter_config,
version=version,
redirect_to_hf_hub=redirect_to_hf_hub,
**kwargs,
)
elif source == "hf":
return pull_from_hf_model_hub(adapter_name_or_path, version=version, **kwargs)
Expand All @@ -731,7 +759,12 @@ def resolve_adapter_path(
logger.info("Attempting to load adapter from source 'ah'...")
try:
return pull_from_hub(
adapter_name_or_path, model_name, adapter_config=adapter_config, version=version, **kwargs
adapter_name_or_path,
model_name,
adapter_config=adapter_config,
version=version,
redirect_to_hf_hub=True,
**kwargs,
)
except Exception as ex:
logger.info(ex)
Expand Down Expand Up @@ -879,4 +912,4 @@ def patch_forward(module: torch.nn.Module):
# we need to explicitly set to potentially overriden forward methods on adapter init.
# The `add_hook_to_module()` method is e.g. used for `device_map="auto"` in the `PreTrainedModel.from_pretrained()` method.
if hasattr(module, "_old_forward"):
module._old_forward = module.__class__.forward.__get__(module, module.__class__)
module._old_forward = module.__class__.forward.__get__(module, module.__class__)
7 changes: 3 additions & 4 deletions tests/test_adapter_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from adapters import ADAPTER_CONFIG_MAP, AdapterConfig, BertAdapterModel, get_adapter_config_hash
from adapters.trainer import AdapterTrainer as Trainer
from adapters.utils import find_in_index
from transformers import ( # get_adapter_config_hash,
from transformers import (
AutoModel,
AutoTokenizer,
BertForSequenceClassification,
GlueDataset,
GlueDataTrainingArguments,
TrainingArguments,
Expand Down Expand Up @@ -50,7 +49,7 @@ def test_load_task_adapter_from_hub(self):
for config in ["pfeiffer", "houlsby"]:
with self.subTest(config=config):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model = BertAdapterModel.from_pretrained("bert-base-uncased")
adapters.init(model)

loading_info = {}
Expand Down Expand Up @@ -163,4 +162,4 @@ def test_load_adapter_with_head_from_hub(self):
in_data = ids_tensor((1, 128), 1000)
model.to(torch_device)
output = model(in_data)
self.assertEqual([1, 128], list(output[0].size()))
self.assertEqual([1, 128], list(output[0].size()))

0 comments on commit 5519a8d

Please sign in to comment.