Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add adapter support to CLIP #483

Merged
merged 8 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions adapter_docs/classes/models/clip.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
CLIP
=====

.. note::
Adapter implementation notes:
- CLIP consists of two separate Transformer encoder models, a ViT-style Transformer for visual features and a language model for textual features. Both encoders can be fitted with adapters. As usual, the ``leave_out`` parameter can be used to specify the layers in which adapters should be added. For CLIP, layer IDs are counted globally across both encoders, starting from the text encoder. I.e., for a CLIP model with 12 layers in each Transformer encoder, the text encoder will have IDs 0-11 and the vision encoder will have IDs 12-23.
- As CLIP does not come with pre-supported task-specific prediction heads, there is currently no ``CLIPAdapterModel`` class. Use ``CLIPModel`` instead.

The CLIP model was proposed in `Learning Transferable Visual Models From Natural Language Supervision <https://arxiv.org/abs/2103.00020>`_ by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh,
Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. CLIP
(Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be
instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing
for the task, similarly to the zero-shot capabilities of GPT-2 and 3.

The abstract from the paper is the following:

*State-of-the-art computer vision systems are trained to predict a fixed set of predetermined object categories. This
restricted form of supervision limits their generality and usability since additional labeled data is needed to specify
any other visual concept. Learning directly from raw text about images is a promising alternative which leverages a
much broader source of supervision. We demonstrate that the simple pre-training task of predicting which caption goes
with which image is an efficient and scalable way to learn SOTA image representations from scratch on a dataset of 400
million (image, text) pairs collected from the internet. After pre-training, natural language is used to reference
learned visual concepts (or describe new ones) enabling zero-shot transfer of the model to downstream tasks. We study
the performance of this approach by benchmarking on over 30 different existing computer vision datasets, spanning tasks
such as OCR, action recognition in videos, geo-localization, and many types of fine-grained object classification. The
model transfers non-trivially to most tasks and is often competitive with a fully supervised baseline without the need
for any dataset specific training. For instance, we match the accuracy of the original ResNet-50 on ImageNet zero-shot
without needing to use any of the 1.28 million training examples it was trained on. We release our code and pre-trained
model weights at this https URL.*

CLIPTextModel
~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPTextModel
:members:
:inherited-members: CLIPPreTrainedModel

CLIPVisionModel
~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPVisionModel
:members:
:inherited-members: CLIPPreTrainedModel

CLIPModel
~~~~~~~~~

.. autoclass:: transformers.CLIPModel
:members:
:inherited-members: CLIPPreTrainedModel
1 change: 1 addition & 0 deletions adapter_docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/bart
classes/models/beit
classes/models/bert
classes/models/clip
classes/models/deberta
classes/models/deberta_v2
classes/models/distilbert
Expand Down
1 change: 1 addition & 0 deletions adapter_docs/model_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The table below further shows which model architectures support which adaptation
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"model_mixin": [
"EmbeddingAdaptersMixin",
"InvertibleAdaptersMixin",
"InvertibleAdaptersWrapperMixin",
"ModelAdaptersMixin",
"ModelWithHeadsAdaptersMixin",
],
Expand Down Expand Up @@ -200,6 +201,7 @@
from .model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
InvertibleAdaptersWrapperMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
Expand Down
17 changes: 4 additions & 13 deletions src/transformers/adapters/mixins/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
InvertibleAdaptersWrapperMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
Expand All @@ -31,9 +31,11 @@ def _init_adapter_modules(self):
self.cross_attention_adapters._init_adapter_modules()


class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin):
"""Adds adapters to the BartModel class."""

invertible_adapters_base_name = "encoder"

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
if hasattr(self, "encoder"):
for i, layer in enumerate(self.encoder.layers):
Expand All @@ -44,17 +46,6 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.decoder.layers):
yield i, layer

def _init_adapter_modules(self):
if hasattr(self, "encoder"):
# In BART, the invertible adapters are implemented by the encoder module.
# Therefore, relay mixin calls to the encoder here.
self.invertible_adapters = self.encoder.invertible_adapters
self.add_invertible_adapter = self.encoder.add_invertible_adapter
self.get_invertible_adapter = self.encoder.get_invertible_adapter
self.enable_invertible_adapters = self.encoder.enable_invertible_adapters
self.invertible_adapters_forward = self.encoder.invertible_adapters_forward
super()._init_adapter_modules()


class BartModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
51 changes: 51 additions & 0 deletions src/transformers/adapters/mixins/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersWrapperMixin,
ModelAdaptersMixin,
)


class CLIPEncoderLayerAdaptersMixin:
"""Adds adapters to the CLIPEncoderLayer module of CLIP."""

def _init_adapter_modules(self):
self.attention_adapters = AdapterLayer("mh_adapter", self.config)
self.output_adapters = AdapterLayer("output_adapter", self.config)
self.attention_adapters._init_adapter_modules()
self.output_adapters._init_adapter_modules()


class CLIPTextModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin):
"""Adds adapters to the CLIPTextModel class."""

invertible_adapters_base_name = "text_model"

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.text_model.encoder.layers):
yield i, layer


class CLIPVisionModelAdaptersMixin(ModelAdaptersMixin):
"""Adds adapters to the a CLIPVisionModel class."""

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.vision_model.encoder.layers):
yield i, layer


class CLIPModelAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin):
"""Adds adapters to the CLIPModel class."""

invertible_adapters_base_name = "text_model"

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.text_model.encoder.layers):
yield i, layer
for i, layer in enumerate(self.vision_model.encoder.layers, start=len(self.text_model.encoder.layers)):
yield i, layer
18 changes: 4 additions & 14 deletions src/transformers/adapters/mixins/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..layer import AdapterLayer
from ..model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
InvertibleAdaptersWrapperMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
Expand All @@ -26,9 +26,11 @@ def __init__(self):
super().__init__("output_adapter", None)


class T5ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
class T5ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin):
"""Adds adapters to the T5Model class."""

invertible_adapters_base_name = "encoder"

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
global_i = 0
if hasattr(self, "encoder"):
Expand All @@ -39,18 +41,6 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.decoder.block, start=global_i):
yield i, layer

def _init_adapter_modules(self):
if hasattr(self, "encoder"):
# In T5, the invertible adapters are implemented by the encoder module.
# Therefore, relay mixin calls to the encoder here.
self.invertible_adapters = self.encoder.invertible_adapters
self.add_invertible_adapter = self.encoder.add_invertible_adapter
self.get_invertible_adapter = self.encoder.get_invertible_adapter
self.enable_invertible_adapters = self.encoder.enable_invertible_adapters
self.invertible_adapters_forward = self.encoder.invertible_adapters_forward
self.delete_invertible_adapter = self.encoder.delete_invertible_adapter
super()._init_adapter_modules()


# EmbeddingAdaptersWrapperMixin not required here as base and heads model are identical
class T5ModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
Expand Down
61 changes: 56 additions & 5 deletions src/transformers/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(self, *args, **kwargs):
self.invertible_adapters = nn.ModuleDict(dict())

# Make sure config is wrapped
self.config = wrap_config(self.config)
if hasattr(self, "config"):
self.config = wrap_config(self.config)

def add_invertible_adapter(self, adapter_name: str):
"""
Expand Down Expand Up @@ -102,6 +103,54 @@ def invertible_adapters_forward(self, hidden_states, rev=False):
return hidden_states


class InvertibleAdaptersWrapperMixin:
"""
Mixin for Transformer models supporting invertible adapters in a child module. When applying this mixin, set
`invertible_adapters_base_name` to the name of the child module that includes `InvertibleAdaptersMixin`.
"""

invertible_adapters_base_name = ""

@property
def invertible_adapters_base(self):
return getattr(self, self.invertible_adapters_base_name, None)

@property
def invertible_adapters(self):
if self.invertible_adapters_base is not None:
return self.invertible_adapters_base.invertible_adapters
return None

def add_invertible_adapter(self, adapter_name: str):
"""
Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an
invertible adapter config, this method does nothing.

Args:
adapter_name (str): The name of the adapter for which to add an invertible adapter module.
"""
if self.invertible_adapters_base is not None:
self.invertible_adapters_base.add_invertible_adapter(adapter_name)

def delete_invertible_adapter(self, adapter_name: str):
if self.invertible_adapters_base is not None:
self.invertible_adapters_base.delete_invertible_adapter(adapter_name)

def get_invertible_adapter(self):
if self.invertible_adapters_base is not None:
return self.invertible_adapters_base.get_invertible_adapter()
return None

def enable_invertible_adapters(self, adapter_names):
if self.invertible_adapters_base is not None:
self.invertible_adapters_base.enable_invertible_adapters(adapter_names)

def invertible_adapters_forward(self, hidden_states, rev=False):
if self.invertible_adapters_base is not None:
return self.invertible_adapters_base.invertible_adapters_forward(hidden_states, rev=rev)
return hidden_states


class EmbeddingAdaptersMixin:
"""Mixin for Transformer models adding support for dynamically switching embeddings."""

Expand Down Expand Up @@ -330,7 +379,7 @@ def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], tra
for param in self.base_model.shared_parameters[adapter_name].values():
param.requires_grad = True

if isinstance(self, InvertibleAdaptersMixin):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
self.enable_invertible_adapters(adapter_setup.flatten())
# use the adapters to be trained by default in every forward pass
self.set_active_adapters(adapter_setup)
Expand Down Expand Up @@ -446,7 +495,7 @@ def _add_adapter_weights(self, adapter_name: str):
for module in self.modules():
if isinstance(module, PrefixTuningPool):
module.confirm_prefix(adapter_name)
if isinstance(self, InvertibleAdaptersMixin):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
self.add_invertible_adapter(adapter_name)

def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=None, override_kwargs=None):
Expand Down Expand Up @@ -516,7 +565,7 @@ def delete_adapter(self, adapter_name: str):
# PHM Layer
if adapter_name in self.base_model.shared_parameters:
del self.base_model.shared_parameters[adapter_name]
if isinstance(self, InvertibleAdaptersMixin):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
self.delete_invertible_adapter(adapter_name)
# Reset active adapters if this was the only active adapter
if self.active_adapters == Stack(adapter_name):
Expand Down Expand Up @@ -827,7 +876,9 @@ def get_adapter(self, name) -> dict:
# global weights are saved at index -1
if name in self.base_model.shared_parameters:
destination[-1]["shared"] = self.base_model.shared_parameters[name]
if isinstance(self, InvertibleAdaptersMixin) and name in self.invertible_adapters:
if (
isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin)
) and name in self.invertible_adapters:
destination[-1]["invertible"] = self.invertible_adapters[name]

# use a custom index to ensure numbering is from 0 to N layers
Expand Down
30 changes: 19 additions & 11 deletions src/transformers/adapters/prefix_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,16 @@ def __init__(self, config):
self.prefix_counts = {}
self.prefix_tunings = nn.ModuleDict()

def indicate_prefix(self, prefix_name: str, location_key: str):
def indicate_prefix(self, prefix_name: str, location_key: str, **kwargs):
if prefix_name not in self.prefix_counts:
self.prefix_counts[prefix_name] = {location_key: 1}
self.prefix_counts[prefix_name] = {location_key: {"count": 1, **kwargs}}
elif location_key not in self.prefix_counts[prefix_name]:
self.prefix_counts[prefix_name][location_key] = 1
self.prefix_counts[prefix_name][location_key] = {"count": 1, **kwargs}
else:
self.prefix_counts[prefix_name][location_key] += 1
# TODO-AH: Check if kwargs are the same
self.prefix_counts[prefix_name][location_key]["count"] += 1

return self.prefix_counts[prefix_name][location_key] - 1
return self.prefix_counts[prefix_name][location_key]["count"] - 1

def confirm_prefix(self, prefix_name: str):
"""Create Prefix Tuning module based on shim layer infications."""
Expand All @@ -164,11 +165,11 @@ def confirm_prefix(self, prefix_name: str):
raise ValueError(f"Prefix {prefix_name} not found in PrefixTuningPool")

module_configs = {}
for location_key, count in self.prefix_counts[prefix_name].items():
for location_key, location_config in self.prefix_counts[prefix_name].items():
module_configs[location_key] = {
"n_layers": count,
"n_heads": self.config.num_attention_heads,
"input_size": self.config.hidden_size,
"n_layers": location_config["count"],
"n_heads": location_config["n_heads"],
"input_size": location_config["input_size"],
}
prefix_tuning = PrefixTuningGroup(module_configs, prefix_tuning_config)
prefix_tuning.train(self.training) # make sure training mode is consistent
Expand Down Expand Up @@ -232,10 +233,12 @@ class PrefixTuningShim(AdapterLayerBase, nn.Module):
config (:class:`~transformers.PretrainedConfig`): The model config.
"""

def __init__(self, location_key: str, config):
def __init__(self, location_key: str, config, add_model_type_to_key: bool = False):
super().__init__()
self.config = config
self.location_key = location_key
if add_model_type_to_key:
self.location_key = f"{self.config.model_type}_{self.location_key}"
self.prefixes = {}
self.prefix_gates = nn.ModuleDict()

Expand All @@ -256,7 +259,12 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
location_key=used_location_key,
)
if prefix_tuning_config is not None:
prefix_id = self.pool.indicate_prefix(adapter_name, self.location_key)
prefix_id = self.pool.indicate_prefix(
adapter_name,
self.location_key,
n_heads=self.config.num_attention_heads,
input_size=self.config.hidden_size,
)
self.prefixes[adapter_name] = prefix_id

if prefix_tuning_config.use_gating:
Expand Down
Loading