diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py
index d9c7386fec..201dd9bac7 100644
--- a/src/adapters/heads/base.py
+++ b/src/adapters/heads/base.py
@@ -20,6 +20,7 @@
 from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
 from ..context import AdapterSetup, ForwardContext
+from ..loading import PredictionHeadLoader
 from ..methods.modeling import Activation_Function_Class
 from ..model_mixin import ModelWithHeadsAdaptersMixin
@@ -891,3 +892,48 @@ def get_labels(self, head_name=None):
             return None
         else:
             return list(label_dict.values())
+
+    # This method is called during model loading in from_pretrained() to apply the state_dict to the model.
+    # Override it to inject adapter head logic.
+    @classmethod
+    def _load_pretrained_model(
+        cls,
+        model,
+        state_dict,
+        loaded_keys,
+        *args,
+        **kwargs,
+    ):
+        # Keep only weights that are not part of the base model
+        if state_dict is not None:
+            head_state_dict = {
+                key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix)
+            }
+        else:
+            head_state_dict = None
+        head_name = "default"
+        loader = PredictionHeadLoader(model, error_on_missing=False, convert_to_flex_head=True)
+        head_config, new_head_state_dict = loader.convert_static_to_flex_head(head_state_dict, load_as=head_name)
+
+        if head_config is not None:
+            # add head from config
+            if head_name in model.heads:
+                logger.warning("Overwriting existing head '{}'".format(head_name))
+
+            model.add_prediction_head_from_config(head_name, head_config, overwrite_ok=True)
+
+        if new_head_state_dict is not None:
+            for k in head_state_dict:
+                del state_dict[k]
+                loaded_keys.remove(k)
+            for k in new_head_state_dict:
+                state_dict[k] = new_head_state_dict[k]
+                loaded_keys.append(k)
+
+        return super()._load_pretrained_model(
+            model,
+            state_dict,
+            loaded_keys,
+            *args,
+            **kwargs,
+        )
diff --git a/src/adapters/loading.py b/src/adapters/loading.py
index 57d03ea62f..25ad7e4fc1 100644
--- a/src/adapters/loading.py
+++ b/src/adapters/loading.py
@@ -766,3 +766,43 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs):
             )

         return save_directory, head_name
+
+    def convert_static_to_flex_head(self, state_dict, load_as="default"):
+        """
+        Loads a prediction head module from the given state dict, which contains a static head checkpoint.
+
+        Args:
+            state_dict (dict): The static head checkpoint from which to load the head module. Can be None.
+            load_as (str, optional): Load the weights with this name. Defaults to "default".
+
+        Returns:
+            Tuple[dict, dict]: A tuple consisting of the head config and the state dict of the loaded weights.
+        """
+        assert self.convert_to_flex_head, "convert_static_to_flex_head() can only be used with convert_to_flex_head=True."
+        assert hasattr(self.model, "heads"), "convert_static_to_flex_head() can only be used with a flex heads model class."
+
+        conversion_rename_func = None
+
+        original_model_class = self.model.config.architectures[0] if self.model.config.architectures else None
+        if original_model_class in STATIC_TO_FLEX_HEAD_MAP:
+            head_config, conversion_rename_func = get_head_config_and_rename_list(
+                original_model_class,
+                load_as,
+                getattr(self.model.config, "label2id"),
+            )
+        elif self.error_on_missing:
+            raise ValueError(
+                f"Cannot automatically convert prediction head of model class {original_model_class} to flex head."
+            )
+        else:
+            return None, None
+
+        # Load head weights
+        if state_dict is not None:
+            new_state_dict = {}
+            for k, v in state_dict.items():
+                new_k = conversion_rename_func(k)
+                new_state_dict[new_k] = v
+        else:
+            new_state_dict = None
+        return head_config, new_state_dict
diff --git a/tests_adapters/test_adapter_conversion.py b/tests_adapters/test_adapter_conversion.py
index ac57daa315..df209c12ba 100644
--- a/tests_adapters/test_adapter_conversion.py
+++ b/tests_adapters/test_adapter_conversion.py
@@ -198,3 +198,39 @@ def test_equivalent_language_generation(self):

         self.assertEquals(model_gen.shape, flex_model_gen.shape)
         self.assertTrue(torch.equal(model_gen, flex_model_gen))
+
+    def test_full_model_conversion(self):
+        if self.config_class not in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING:
+            self.skipTest("No sequence classification class.")
+
+        static_head_model = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[self.config_class](self.config())
+        adapters.init(static_head_model)
+        static_head_model.eval()
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            static_head_model.save_pretrained(temp_dir)
+
+            flex_head_model, loading_info = AutoAdapterModel.from_pretrained(temp_dir, output_loading_info=True)
+
+        # Roberta-based models always have a pooler, which is not used by the tested head
+        keys_to_ignore = ["roberta.pooler.dense.weight", "roberta.pooler.dense.bias"]
+
+        missing_keys = [k for k in loading_info["missing_keys"] if k not in keys_to_ignore]
+
+        self.assertEqual(0, len(missing_keys), "Missing keys: {}".format(", ".join(missing_keys)))
+        self.assertEqual(
+            0,
+            len(loading_info["unexpected_keys"]),
+            "Unexpected keys: {}".format(", ".join(loading_info["unexpected_keys"])),
+        )
+
+        # static head is re-loaded as "default"
+        self.assertIn("default", flex_head_model.heads)
+
+        # check equal output
+        in_data = self.get_input_samples(config=flex_head_model.config)
+        static_head_model.to(torch_device)
+        flex_head_model.to(torch_device)
+        output1 = static_head_model(**in_data)
+        output2 = flex_head_model(**in_data, head="default")
+        self.assertTrue(torch.allclose(output1.logits, output2.logits))
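For context, a minimal usage sketch of the conversion path this diff adds, mirroring test_full_model_conversion above. It is not part of the diff: the small RobertaConfig values and the temporary directory are illustrative assumptions, while adapters.init(), AutoAdapterModel, and the "default" head name come from the adapters package and the code above.

# Illustrative sketch: load a static-head checkpoint into a flex-head model.
import tempfile

import torch
from transformers import RobertaConfig, RobertaForSequenceClassification

import adapters
from adapters import AutoAdapterModel

# Tiny config (illustrative values only, to keep the example fast).
config = RobertaConfig(
    vocab_size=1000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=3,
)
static_model = RobertaForSequenceClassification(config)
adapters.init(static_model)  # mirrors the test: prepare the static-head model for adapters
static_model.eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    static_model.save_pretrained(tmp_dir)
    # The overridden _load_pretrained_model() detects the static classification head
    # and re-registers its weights as the flex head "default".
    flex_model, loading_info = AutoAdapterModel.from_pretrained(tmp_dir, output_loading_info=True)

flex_model.eval()
print("default" in flex_model.heads)  # expected: True

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    static_logits = static_model(input_ids).logits
    flex_logits = flex_model(input_ids, head="default").logits
print(torch.allclose(static_logits, flex_logits))  # expected: True

Note that the conversion is keyed on config.architectures, so it only applies to checkpoints saved from a static head class listed in STATIC_TO_FLEX_HEAD_MAP; other architectures pass through unchanged because the loader is constructed with error_on_missing=False.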