Automatically convert heads when loading with XAdapterModel (#594)
Changes the loading logic of XAdapterModel classes such that static heads of a loaded checkpoint are automatically converted to flex heads.

Limitation: does not work for sharded checkpoints.
calpt authored Nov 11, 2023
1 parent dfe17e9 commit 6ae327a
Showing 3 changed files with 122 additions and 0 deletions.
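
For context, a minimal usage sketch of the behavior this commit enables (the checkpoint name below is only a placeholder assumption; any checkpoint saved with a static head class should behave the same way):

from adapters import AutoAdapterModel

# Loading a checkpoint that was saved with a static head class
# (e.g. BertForSequenceClassification) now converts that head into a
# flex head registered under the name "default".
model = AutoAdapterModel.from_pretrained("textattack/bert-base-uncased-SST-2")
print(model.heads.keys())  # expected to include "default"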
46 changes: 46 additions & 0 deletions src/adapters/heads/base.py
@@ -20,6 +20,7 @@

from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
from ..context import AdapterSetup, ForwardContext
from ..loading import PredictionHeadLoader
from ..methods.modeling import Activation_Function_Class
from ..model_mixin import ModelWithHeadsAdaptersMixin

@@ -891,3 +892,48 @@ def get_labels(self, head_name=None):
return None
else:
return list(label_dict.values())

# This method is called during model loading in from_pretrained() to apply the state_dict to the model.
# Override it to inject adapter head logic.
@classmethod
def _load_pretrained_model(
cls,
model,
state_dict,
loaded_keys,
*args,
**kwargs,
):
# Filter only weights not part of base model
if state_dict is not None:
head_state_dict = {
key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix)
}
else:
head_state_dict = None
head_name = "default"
loader = PredictionHeadLoader(model, error_on_missing=False, convert_to_flex_head=True)
head_config, new_head_state_dict = loader.convert_static_to_flex_head(head_state_dict, load_as=head_name)

if head_config is not None:
# add head from config
if head_name in model.heads:
logger.warning("Overwriting existing head '{}'".format(head_name))

model.add_prediction_head_from_config(head_name, head_config, overwrite_ok=True)

if new_head_state_dict is not None:
for k in head_state_dict:
del state_dict[k]
loaded_keys.remove(k)
for k in new_head_state_dict:
state_dict[k] = new_head_state_dict[k]
loaded_keys.append(k)

return super()._load_pretrained_model(
model,
state_dict,
loaded_keys,
*args,
**kwargs,
)
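
As a rough illustration of the prefix filter in _load_pretrained_model above (the key names are hypothetical examples for a BERT-style model whose base_model_prefix is "bert"): any weight whose key does not start with the base model prefix is treated as part of the static head.

checkpoint_keys = [
    "bert.encoder.layer.0.attention.self.query.weight",  # base model weight
    "classifier.weight",  # static head weight
    "classifier.bias",  # static head weight
]
head_keys = [k for k in checkpoint_keys if not k.startswith("bert")]
print(head_keys)  # ['classifier.weight', 'classifier.bias']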
40 changes: 40 additions & 0 deletions src/adapters/loading.py
@@ -766,3 +766,43 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs):
)

return save_directory, head_name

def convert_static_to_flex_head(self, state_dict, load_as="default"):
"""
Loads a prediction head module from the given state dict, which contains a static head checkpoint.

Args:
    state_dict (dict): The static head checkpoint from which to load the head module. Can be None.
    load_as (str, optional): Load the weights with this name. Defaults to "default".

Returns:
    Tuple[dict, dict]: A tuple consisting of the head config and the state dict of the converted weights.
"""
assert self.convert_to_flex_head, "convert_static_to_flex_head() can only be used with convert_to_flex_head=True."
assert hasattr(self.model, "heads"), "convert_static_to_flex_head() can only be used with a flex heads model class."

conversion_rename_func = None

original_model_class = self.model.config.architectures[0] if self.model.config.architectures else None
if original_model_class in STATIC_TO_FLEX_HEAD_MAP:
head_config, conversion_rename_func = get_head_config_and_rename_list(
original_model_class,
load_as,
getattr(self.model.config, "label2id"),
)
elif self.error_on_missing:
raise ValueError(
f"Cannot automatically convert prediction head of model class {original_model_class} to flex head."
)
else:
return None, None

# Load head weights
if state_dict is not None:
new_state_dict = {}
for k, v in state_dict.items():
new_k = conversion_rename_func(k)
new_state_dict[new_k] = v
else:
new_state_dict = None
return head_config, new_state_dict
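
Purely as an illustration of how the rename function returned by get_head_config_and_rename_list is applied (the rename rule and key names below are hypothetical; the real mapping depends on the original model class):

import torch

static_head_state_dict = {
    "classifier.weight": torch.zeros(2, 768),
    "classifier.bias": torch.zeros(2),
}

def example_rename(key):
    # Hypothetical rename rule, not the library's actual mapping.
    return key.replace("classifier.", "heads.default.")

flex_head_state_dict = {example_rename(k): v for k, v in static_head_state_dict.items()}
print(list(flex_head_state_dict))  # ['heads.default.weight', 'heads.default.bias']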
36 changes: 36 additions & 0 deletions tests_adapters/test_adapter_conversion.py
@@ -198,3 +198,39 @@ def test_equivalent_language_generation(self):

self.assertEquals(model_gen.shape, flex_model_gen.shape)
self.assertTrue(torch.equal(model_gen, flex_model_gen))

def test_full_model_conversion(self):
if self.config_class not in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING:
self.skipTest("No sequence classification class.")

static_head_model = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[self.config_class](self.config())
adapters.init(static_head_model)
static_head_model.eval()

with tempfile.TemporaryDirectory() as temp_dir:
static_head_model.save_pretrained(temp_dir)

flex_head_model, loading_info = AutoAdapterModel.from_pretrained(temp_dir, output_loading_info=True)

# Roberta-based models always have a pooler, which is not used by the tested head
keys_to_ignore = ["roberta.pooler.dense.weight", "roberta.pooler.dense.bias"]

missing_keys = [k for k in loading_info["missing_keys"] if k not in keys_to_ignore]

self.assertEqual(0, len(missing_keys), "Missing keys: {}".format(", ".join(missing_keys)))
self.assertEqual(
0,
len(loading_info["unexpected_keys"]),
"Unexpected keys: {}".format(", ".join(loading_info["unexpected_keys"])),
)

# static head is re-loaded as "default"
self.assertIn("default", flex_head_model.heads)

# check equal output
in_data = self.get_input_samples(config=flex_head_model.config)
static_head_model.to(torch_device)
flex_head_model.to(torch_device)
output1 = static_head_model(**in_data)
output2 = flex_head_model(**in_data, head="default")
self.assertTrue(torch.allclose(output1.logits, output2.logits))
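
Assuming pytest is used as the test runner, the new test can be selected per model test class with something like: pytest tests_adapters/test_adapter_conversion.py -k test_full_model_conversion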
