Refactorings in model config & base model classes (#304)
calpt authored Mar 23, 2022
1 parent 3884182 commit 5893a89
Showing 28 changed files with 171 additions and 202 deletions.
6 changes: 0 additions & 6 deletions adapter_docs/classes/model_mixins.rst
@@ -10,12 +10,6 @@ InvertibleAdaptersMixin
.. autoclass:: transformers.InvertibleAdaptersMixin
:members:

ModelConfigAdaptersMixin
----------------------------------

.. autoclass:: transformers.ModelConfigAdaptersMixin
:members:

ModelAdaptersMixin
------------------

5 changes: 3 additions & 2 deletions adding_adapters_to_a_model.md
@@ -32,8 +32,9 @@ For this purpose, there typically exists a module `src/transformers/adapters/mix
- The model classes with heads (e.g. `BertForSequenceClassification`) should directly implement `ModelWithHeadsAdaptersMixin`.
- To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningShim` module in the respective attention layer.
- Again, have a look at existing implementations, e.g. `modeling_distilbert.py` or `modeling_bart.py`.
- Add the mixin for config classes, `ModelConfigAdaptersMixin`, to the model configuration class in `configuration_<model_type>`.
- There are some naming differences on the config attributes of different model architectures. The adapter implementation requires some additional attributes with a specific name to be available. These currently are `hidden_dropout_prob` and `attention_probs_dropout_prob` as in the `BertConfig` class.
- Adapt the config class to the requirements of adapters in `src/transformers/adapters/wrappers/configuration.py`.
- There are some naming differences on the config attributes of different model architectures. The adapter implementation requires some additional attributes with a specific name to be available. These currently are `num_attention_heads`, `hidden_size`, `hidden_dropout_prob` and `attention_probs_dropout_prob` as in the `BertConfig` class.
If your model config does not provide these, add corresponding mappings to `CONFIG_CLASS_KEYS_MAPPING`.

### `...AdapterModel` class

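A hedged sketch of the documented step above: for a new architecture whose config uses its own attribute names, the required keys would be mapped in `CONFIG_CLASS_KEYS_MAPPING` (defined in `src/transformers/adapters/wrappers/configuration.py`, added further down in this commit). The model type "mymodel" and the right-hand attribute names are hypothetical; in practice the entry is added directly to the dictionary in that file.

# Hypothetical example: registering config-key mappings for a made-up model type
# "mymodel" whose config names these attributes "n_heads", "dim", "dropout" and
# "attention_dropout" (all assumed for illustration).
from transformers.adapters.wrappers.configuration import CONFIG_CLASS_KEYS_MAPPING

CONFIG_CLASS_KEYS_MAPPING["mymodel"] = {
    "num_attention_heads": "n_heads",
    "hidden_size": "dim",
    "hidden_dropout_prob": "dropout",
    "attention_probs_dropout_prob": "attention_dropout",
}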
2 changes: 1 addition & 1 deletion setup.py
@@ -348,7 +348,7 @@ def run(self):

setup(
name="adapter-transformers",
version="2.3.0a0",
version="3.0.0a0",
author="Jonas Pfeiffer, Andreas Rücklé, Clifton Poth, Hannah Sterz, based on work by Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
author_email="[email protected]",
description="A friendly fork of Huggingface's Transformers, adding Adapters to PyTorch language models",
5 changes: 1 addition & 4 deletions src/transformers/__init__.py
@@ -22,7 +22,6 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).

__adapters_version__ = "2.3.0a0"
__version__ = "4.12.5"

# Work around to update TensorFlow's absl.logging threshold which alters the
@@ -1397,7 +1396,6 @@
"MBartModelWithHeads",
"ModelAdaptersConfig",
"ModelAdaptersMixin",
"ModelConfigAdaptersMixin",
"ModelWithFlexibleHeadsAdaptersMixin",
"ModelWithHeadsAdaptersMixin",
"MultiLingAdapterArguments",
@@ -3181,7 +3179,6 @@
MBartModelWithHeads,
ModelAdaptersConfig,
ModelAdaptersMixin,
ModelConfigAdaptersMixin,
ModelWithFlexibleHeadsAdaptersMixin,
ModelWithHeadsAdaptersMixin,
MultiLingAdapterArguments,
@@ -3708,7 +3705,7 @@
globals()["__file__"],
_import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__, "__adapters_version__": __adapters_version__},
extra_objects={"__version__": __version__},
)


17 changes: 9 additions & 8 deletions src/transformers/adapters/__init__.py
@@ -16,6 +16,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "3.0.0a0"

from typing import TYPE_CHECKING

from ..file_utils import _LazyModule
@@ -77,7 +79,6 @@
"model_mixin": [
"InvertibleAdaptersMixin",
"ModelAdaptersMixin",
"ModelConfigAdaptersMixin",
"ModelWithHeadsAdaptersMixin",
],
"models.auto": [
@@ -184,12 +185,7 @@
TaggingHead,
)
from .layer import AdapterLayer, AdapterLayerBase
from .model_mixin import (
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelConfigAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
from .model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin, ModelWithHeadsAdaptersMixin
from .models.auto import ADAPTER_MODEL_MAPPING, MODEL_WITH_HEADS_MAPPING, AutoAdapterModel, AutoModelWithHeads
from .models.bart import BartAdapterModel, BartModelWithHeads
from .models.bert import BertAdapterModel, BertModelWithHeads
@@ -213,4 +209,9 @@
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
extra_objects={"__version__": __version__},
)
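A minimal sketch of the effect of this change, assuming the lazy module exposes `__version__` via `extra_objects` as configured above: the adapters version is now read from the `transformers.adapters` subpackage rather than from `transformers.__adapters_version__`.

# Minimal sketch: the subpackage now carries its own version string
# ("3.0.0a0" at this commit).
import transformers.adapters as adapters

print(adapters.__version__)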
5 changes: 2 additions & 3 deletions src/transformers/adapters/configuration.py
@@ -4,8 +4,7 @@
from dataclasses import FrozenInstanceError, asdict, dataclass, field, replace
from typing import List, Optional, Union

from transformers import __adapters_version__

from . import __version__
from .composition import AdapterCompositionBlock
from .utils import get_adapter_config_hash, resolve_adapter_config

@@ -706,7 +705,7 @@ def build_full_config(adapter_config, model_config, save_id2label=False, **kwarg
config_dict["config"] = adapter_config.to_dict()
else:
config_dict["config"] = adapter_config
config_dict["version"] = __adapters_version__
config_dict["version"] = __version__
return config_dict


2 changes: 1 addition & 1 deletion src/transformers/adapters/layer.py
@@ -439,7 +439,7 @@ def adapter_layer_forward(self, hidden_states, input_tensor, layer_norm):
"""
Called for each forward pass through adapters.
"""
if hasattr(self.config, "adapters"):
if getattr(self.config, "is_adaptable", False):
# First check current context before falling back to defined setup
context = AdapterSetup.get_context()
if context is not None:
51 changes: 12 additions & 39 deletions src/transformers/adapters/model_mixin.py
@@ -10,20 +10,15 @@
from torch import nn

from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition
from .configuration import (
AdapterConfig,
AdapterConfigBase,
AdapterFusionConfig,
ModelAdaptersConfig,
get_adapter_config_hash,
)
from .configuration import AdapterConfig, AdapterConfigBase, AdapterFusionConfig, get_adapter_config_hash
from .context import AdapterSetup, ForwardContext
from .hub_mixin import PushAdapterToHubMixin
from .layer import AdapterLayer, AdapterLayerBase
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock
from .prefix_tuning import PrefixTuningPool, PrefixTuningShim
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, inherit_doc
from .wrappers.configuration import wrap_config


logger = logging.getLogger(__name__)
@@ -36,6 +31,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.invertible_adapters = nn.ModuleDict(dict())

# Make sure config is wrapped
self.config = wrap_config(self.config)

def add_invertible_adapter(self, adapter_name: str):
"""
Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an
@@ -97,46 +95,21 @@ def invertible_adapters_forward(self, hidden_states, rev=False):
return hidden_states


class ModelConfigAdaptersMixin(ABC):
"""
Mixin for model config classes, adding support for adapters.
Besides adding this mixin to the config class of a model supporting adapters, make sure the following attributes/
properties are present: hidden_dropout_prob, attention_probs_dropout_prob.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# adapter configuration
adapter_config_dict = kwargs.pop("adapters", None)
if adapter_config_dict:
self.adapters = ModelAdaptersConfig(**adapter_config_dict)
else:
self.adapters = ModelAdaptersConfig()
# Convert AdapterFusions from old format for backwards compatibility
fusion_models = kwargs.pop("adapter_fusion_models", [])
fusion_config = kwargs.pop("adapter_fusion", None)
for fusion_adapter_names in fusion_models:
self.adapters.add_fusion(fusion_adapter_names, config=fusion_config)


class ModelAdaptersMixin(PushAdapterToHubMixin, ABC):
"""Mixin for transformer models adding support for loading/ saving adapters."""

def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.model_name = None
if config.name_or_path and not os.path.exists(config.name_or_path):
self.model_name = config.name_or_path
else:
self.model_name = None
self.loaded_embeddings = {}
self.shared_parameters = nn.ModuleDict()
self._active_embedding = "default"

# In some cases, the config is not an instance of a directly supported config class such as BertConfig.
# Thus, we check the adapters config here to make sure everything is correct.
if not hasattr(config, "adapters"):
config.adapters = ModelAdaptersConfig()
elif config.adapters is not None and not isinstance(config.adapters, ModelAdaptersConfig):
config.adapters = ModelAdaptersConfig(**config.adapters)
# Make sure config is wrapped
self.config = wrap_config(self.config)

def _link_prefix_to_pool(self, layer):
if isinstance(layer, PrefixTuningShim):
@@ -217,7 +190,7 @@ def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBloc
# TODO implement fusion for invertible adapters

def has_adapters(self):
if not getattr(self.config, "adapters", None):
if not getattr(self.config, "is_adaptable", None):
return False
return len(self.config.adapters.adapters) > 0

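A hedged usage sketch of the refactored mixin above: `ModelAdaptersMixin.__init__` now routes the model config through `wrap_config`, so `config.is_adaptable` and `config.adapters` are expected to be present on any adapter model, which is what `has_adapters()` checks. The checkpoint "bert-base-uncased" and the adapter name "demo_task" are illustrative choices only.

# Sketch: after the refactoring, adapter state lives on the wrapped config.
from transformers.adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
print(model.config.is_adaptable)  # True: set by wrap_config in the mixin __init__
print(model.has_adapters())       # False: no adapters registered yet

model.add_adapter("demo_task")    # hypothetical adapter name
print(model.has_adapters())       # True: config.adapters now contains "demo_task"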
2 changes: 1 addition & 1 deletion src/transformers/adapters/prefix_tuning.py
@@ -269,7 +269,7 @@ def get_adapter(self, adapter_name):
return None

def forward(self, key_states, value_states, attention_mask=None, invert_mask=True):
if hasattr(self.config, "adapters"):
if getattr(self.config, "is_adaptable", False):
# First check current context before falling back to defined setup
context = AdapterSetup.get_context()
if context is not None:
13 changes: 3 additions & 10 deletions src/transformers/adapters/utils.py
@@ -7,7 +7,7 @@
import shutil
import tarfile
from collections.abc import Mapping
from dataclasses import asdict, dataclass, is_dataclass
from dataclasses import dataclass
from enum import Enum
from os.path import basename, isdir, isfile, join
from pathlib import Path
@@ -19,8 +19,8 @@
from filelock import FileLock
from huggingface_hub import HfApi, snapshot_download

from .. import __adapters_version__
from ..file_utils import get_from_cache, is_remote_url, torch_cache_home
from . import __version__


logger = logging.getLogger(__name__)
@@ -110,13 +110,6 @@ class AdapterInfo:
sha1_checksum: Optional[str] = None


class DataclassJSONEncoder(json.JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o)
return super().default(o)


def _minimize_dict(d):
if isinstance(d, Mapping):
return {k: _minimize_dict(v) for (k, v) in d.items() if v}
@@ -428,7 +421,7 @@ def pull_from_hf_model_hub(specifier: str, version: str = None, **kwargs) -> str
revision=version,
cache_dir=kwargs.pop("cache_dir", None),
library_name="adapter-transformers",
library_version=__adapters_version__,
library_version=__version__,
)
return download_path

Empty file.
106 changes: 106 additions & 0 deletions src/transformers/adapters/wrappers/configuration.py
@@ -0,0 +1,106 @@
import types

from ...configuration_utils import PretrainedConfig
from ...models.encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..configuration import ModelAdaptersConfig


CONFIG_CLASS_KEYS_MAPPING = {
"bart": {
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
"hidden_dropout_prob": "dropout",
"attention_probs_dropout_prob": "attention_dropout",
},
"bert": {},
"distilbert": {
"hidden_dropout_prob": "dropout",
"attention_probs_dropout_prob": "attention_dropout",
},
"gpt2": {
"hidden_dropout_prob": "resid_pdrop",
"attention_probs_dropout_prob": "attn_pdrop",
},
"mbart": {
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
"hidden_dropout_prob": "dropout",
"attention_probs_dropout_prob": "attention_dropout",
},
"roberta": {},
"t5": {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"hidden_dropout_prob": "dropout_rate",
"attention_probs_dropout_prob": "dropout_rate",
},
"xlm_roberta": {},
}


def _to_dict_new(self):
output = self._to_dict_original()
if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
output["adapters"] = self.adapters.to_dict()
if "custom_heads" in output.keys():
del output["custom_heads"]

# delete handles to overriden methods
del output["to_dict"]
del output["_to_dict_original"]
del output["is_adaptable"]

return output


def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
"""
Makes required changes to a model config class to allow usage with adapters.
Args:
config (PretrainedConfig): The config to be wrapped.
Returns:
PretrainedConfig: The same config object, with modifications applied.
"""
if getattr(config, "is_adaptable", False):
return config

# Init ModelAdaptersConfig
if not hasattr(config, "adapters"):
config.adapters = ModelAdaptersConfig()
elif config.adapters is not None and not isinstance(config.adapters, ModelAdaptersConfig):
config.adapters = ModelAdaptersConfig(**config.adapters)

# Convert AdapterFusions from old format for backwards compatibility
fusion_models = getattr(config, "adapter_fusion_models", [])
fusion_config = getattr(config, "adapter_fusion", None)
for fusion_adapter_names in fusion_models:
config.adapters.add_fusion(fusion_adapter_names, config=fusion_config)

# Ensure missing keys are in class
if config.model_type in CONFIG_CLASS_KEYS_MAPPING:
for key, value in CONFIG_CLASS_KEYS_MAPPING[config.model_type].items():
if key not in config.attribute_map:
config.attribute_map[key] = value

# Override to_dict() to add adapters
if not hasattr(config, "_to_dict_original"):
config._to_dict_original = config.to_dict
config.to_dict = types.MethodType(_to_dict_new, config)

# Ensure custom_heads attribute is present
if not hasattr(config, "custom_heads"):
config.custom_heads = {}

if isinstance(config, EncoderDecoderConfig):
# make sure adapter config is shared
wrap_config(config.encoder)
wrap_config(config.decoder)
config.decoder.adapters = config.encoder.adapters
config.adapters = config.encoder.adapters

config.is_adaptable = True

return config
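A brief usage sketch of `wrap_config` as defined above, assuming the module is importable as `transformers.adapters.wrappers.configuration`; `DistilBertConfig` is used purely for illustration because its mapping above is non-empty.

# Sketch: wrapping a config in place adds the adapter bookkeeping attributes.
from transformers import DistilBertConfig
from transformers.adapters.wrappers.configuration import wrap_config

config = wrap_config(DistilBertConfig())
print(config.is_adaptable)              # True once wrapped
print(type(config.adapters).__name__)   # "ModelAdaptersConfig"
# The "distilbert" entry in CONFIG_CLASS_KEYS_MAPPING maps the BERT-style names,
# so adapter code can read e.g. hidden_dropout_prob from DistilBERT's `dropout`.
print(config.hidden_dropout_prob)       # resolves via attribute_map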
