From c9065b0d02434faf48d41606cced162b16105e27 Mon Sep 17 00:00:00 2001 From: Dainis Boumber Date: Fri, 23 Aug 2024 23:47:36 +0000 Subject: [PATCH] Move custom head dict out of config (adapter-hub#700) To make the model_config serializable and prevent the error mentioned in adapter-hub#680 move the costum_heads dictionary out of the config and make it a separate attribute of the model class. --- src/adapters/heads/model_mixin.py | 12 ++++++------ tests/test_adapter_custom_head.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index 89947732c6..24f27602b9 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -53,8 +53,8 @@ class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._convert_to_flex_head = True - if not hasattr(self.config, "custom_heads"): - self.config.custom_heads = {} + if not hasattr(self, "custom_heads"): + self.custom_heads = {} self._active_heads = [] def head_type(head_type_str: str): @@ -176,7 +176,7 @@ def add_prediction_head_from_config( head_class = MODEL_HEAD_MAP[head_type] head = head_class(self, head_name, **config) self.add_prediction_head(head, overwrite_ok=overwrite_ok, set_active=set_active) - elif head_type in self.config.custom_heads: + elif head_type in self.custom_heads: # we have to re-add the head type for custom heads self.add_custom_head(head_type, head_name, overwrite_ok=overwrite_ok, **config) else: @@ -193,7 +193,7 @@ def get_prediction_heads_config(self): return heads def register_custom_head(self, identifier, head): - self.config.custom_heads[identifier] = head + self.custom_heads[identifier] = head @property def active_head(self) -> Union[str, List[str]]: @@ -253,8 +253,8 @@ def set_active_adapters( ) def add_custom_head(self, head_type, head_name, overwrite_ok=False, set_active=True, **kwargs): - if head_type in self.config.custom_heads: - head = self.config.custom_heads[head_type](self, head_name, **kwargs) + if head_type in self.custom_heads: + head = self.custom_heads[head_type](self, head_name, **kwargs) # When a build-in head is added as a custom head it does not have the head_type property if not hasattr(head.config, "head_type"): head.config["head_type"] = head_type diff --git a/tests/test_adapter_custom_head.py b/tests/test_adapter_custom_head.py index b68662bfc6..8e29636d05 100644 --- a/tests/test_adapter_custom_head.py +++ b/tests/test_adapter_custom_head.py @@ -48,9 +48,10 @@ def test_add_custom_head(self): def test_save_load_custom_head(self): model_name = "bert-base-uncased" model_config = AutoConfig.from_pretrained(model_name) - model_config.custom_heads = {"tag": CustomHead} model1 = AutoAdapterModel.from_pretrained(model_name, config=model_config) model2 = AutoAdapterModel.from_pretrained(model_name, config=model_config) + model1.custom_heads = {"tag": CustomHead} + model2.custom_heads = {"tag": CustomHead} config = {"num_labels": 3, "layers": 2, "activation_function": "tanh"} model1.add_custom_head(head_type="tag", head_name="custom_head", **config)