Refactorings in model config & base model classes (#304)
Showing 28 changed files with 171 additions and 202 deletions.
@@ -348,7 +348,7 @@ def run(self):
 setup(
     name="adapter-transformers",
-    version="2.3.0a0",
+    version="3.0.0a0",
     author="Jonas Pfeiffer, Andreas Rücklé, Clifton Poth, Hannah Sterz, based on work by Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="[email protected]",
     description="A friendly fork of Huggingface's Transformers, adding Adapters to PyTorch language models",
@@ -0,0 +1,106 @@
import types

from ...configuration_utils import PretrainedConfig
from ...models.encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..configuration import ModelAdaptersConfig


CONFIG_CLASS_KEYS_MAPPING = {
    "bart": {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
        "hidden_dropout_prob": "dropout",
        "attention_probs_dropout_prob": "attention_dropout",
    },
    "bert": {},
    "distilbert": {
        "hidden_dropout_prob": "dropout",
        "attention_probs_dropout_prob": "attention_dropout",
    },
    "gpt2": {
        "hidden_dropout_prob": "resid_pdrop",
        "attention_probs_dropout_prob": "attn_pdrop",
    },
    "mbart": {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
        "hidden_dropout_prob": "dropout",
        "attention_probs_dropout_prob": "attention_dropout",
    },
    "roberta": {},
    "t5": {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
        "hidden_dropout_prob": "dropout_rate",
        "attention_probs_dropout_prob": "dropout_rate",
    },
    "xlm_roberta": {},
}


def _to_dict_new(self):
    output = self._to_dict_original()
    if hasattr(self, "adapters") and not isinstance(output["adapters"], dict):
        output["adapters"] = self.adapters.to_dict()
    if "custom_heads" in output.keys():
        del output["custom_heads"]

    # delete handles to overridden methods
    del output["to_dict"]
    del output["_to_dict_original"]
    del output["is_adaptable"]

    return output


def wrap_config(config: PretrainedConfig) -> PretrainedConfig:
    """
    Makes required changes to a model config class to allow usage with adapters.

    Args:
        config (PretrainedConfig): The config to be wrapped.

    Returns:
        PretrainedConfig: The same config object, with modifications applied.
    """
    if getattr(config, "is_adaptable", False):
        return config

    # Init ModelAdaptersConfig
    if not hasattr(config, "adapters"):
        config.adapters = ModelAdaptersConfig()
    elif config.adapters is not None and not isinstance(config.adapters, ModelAdaptersConfig):
        config.adapters = ModelAdaptersConfig(**config.adapters)

    # Convert AdapterFusions from old format for backwards compatibility
    fusion_models = getattr(config, "adapter_fusion_models", [])
    fusion_config = getattr(config, "adapter_fusion", None)
    for fusion_adapter_names in fusion_models:
        config.adapters.add_fusion(fusion_adapter_names, config=fusion_config)

    # Ensure missing keys are in class
    if config.model_type in CONFIG_CLASS_KEYS_MAPPING:
        for key, value in CONFIG_CLASS_KEYS_MAPPING[config.model_type].items():
            if key not in config.attribute_map:
                config.attribute_map[key] = value

    # Override to_dict() to add adapters
    if not hasattr(config, "_to_dict_original"):
        config._to_dict_original = config.to_dict
        config.to_dict = types.MethodType(_to_dict_new, config)

    # Ensure custom_heads attribute is present
    if not hasattr(config, "custom_heads"):
        config.custom_heads = {}

    if isinstance(config, EncoderDecoderConfig):
        # make sure adapter config is shared
        wrap_config(config.encoder)
        wrap_config(config.decoder)
        config.decoder.adapters = config.encoder.adapters
        config.adapters = config.encoder.adapters

    config.is_adaptable = True

    return config
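For orientation, here is a minimal usage sketch of the new wrapper. The import path for wrap_config is an assumption inferred from the relative imports in the file above (the commit view does not show the file's final location), and the snippet assumes an adapter-transformers installation:

# Minimal usage sketch (hypothetical); the wrap_config import path is assumed.
from transformers import BartConfig
from transformers.adapters.wrappers.configuration import wrap_config  # assumed path

config = wrap_config(BartConfig())

# wrap_config is idempotent: a second call returns the same object untouched.
assert wrap_config(config) is config
assert config.is_adaptable

# Standardized attribute names now resolve through attribute_map,
# e.g. "hidden_dropout_prob" falls through to BART's "dropout".
assert config.hidden_dropout_prob == config.dropout

# to_dict() has been patched on the instance, so the adapters section
# is serialized as a plain dict and the patch handles are stripped.
assert isinstance(config.to_dict()["adapters"], dict)

The design choice here is to patch a single config instance in place rather than subclass every model config class: types.MethodType binds _to_dict_new to that instance, which is why _to_dict_new also has to strip the to_dict, _to_dict_original and is_adaptable handles back out of the serialized output.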