Commit e070dc7

Merge pull request #75 from hSterz/master

Label Information

hSterz authored Oct 27, 2020
2 parents 64d6cda + 5901aeb
Showing 4 changed files with 177 additions and 9 deletions.
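In short: the head-adding methods in src/transformers/adapter_bert.py gain an optional id2label argument, add_prediction_head generates LABEL_0..LABEL_n defaults when no mapping is given, the mapping is persisted on save and restored on load (src/transformers/adapter_config.py and src/transformers/adapter_model_mixin.py), and new get_labels() / get_labels_dict() accessors expose it. tests/test_save_id2label.py covers the round trip.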
62 changes: 56 additions & 6 deletions src/transformers/adapter_bert.py
@@ -394,7 +394,6 @@ def adapter_fusion(self, hidden_states, adapter_stack, residual, query):
                up_list.append(up)

        if len(up_list) > 0:
-
            up_list = torch.stack(up_list)
            up_list = up_list.permute(1, 2, 0, 3)

@@ -604,7 +603,14 @@ def set_active_adapters(self, adapter_names: list):
logger.info("No prediction head for task_name '{}' available.".format(head_name))

def add_classification_head(
self, head_name, num_labels=2, layers=2, activation_function="tanh", overwrite_ok=False, multilabel=False
self,
head_name,
num_labels=2,
layers=2,
activation_function="tanh",
overwrite_ok=False,
multilabel=False,
id2label=None,
):
"""Adds a sequence classification head on top of the model.
@@ -620,17 +626,17 @@ def add_classification_head(
head_type = "multilabel_classification"
else:
head_type = "classification"

config = {
"head_type": head_type,
"num_labels": num_labels,
"layers": layers,
"activation_function": activation_function,
"label2id": {label: id_ for id_, label in id2label.items()} if id2label else None,
}
self.add_prediction_head(head_name, config, overwrite_ok)

def add_multiple_choice_head(
self, head_name, num_choices=2, layers=2, activation_function="tanh", overwrite_ok=False,
self, head_name, num_choices=2, layers=2, activation_function="tanh", overwrite_ok=False, id2label=None
):
"""Adds a multiple choice head on top of the model.
@@ -646,11 +652,12 @@
"num_choices": num_choices,
"layers": layers,
"activation_function": activation_function,
"label2id": {label: id_ for id_, label in id2label.items()} if id2label else None,
}
self.add_prediction_head(head_name, config, overwrite_ok)

def add_tagging_head(
self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False,
self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None
):
"""Adds a token classification head on top of the model.
@@ -666,17 +673,19 @@
"num_labels": num_labels,
"layers": layers,
"activation_function": activation_function,
"label2id": {label: id_ for id_, label in id2label.items()} if id2label else None,
}
self.add_prediction_head(head_name, config, overwrite_ok)

def add_qa_head(
self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False,
self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None
):
config = {
"head_type": "question_answering",
"num_labels": num_labels,
"layers": layers,
"activation_function": activation_function,
"label2id": {label: id_ for id_, label in id2label.items()} if id2label else None,
}
self.add_prediction_head(head_name, config, overwrite_ok)

@@ -686,6 +695,12 @@ def add_prediction_head(
        if head_name not in self.config.prediction_heads or overwrite_ok:
            self.config.prediction_heads[head_name] = config

+            if "label2id" not in config.keys() or config["label2id"] is None:
+                if "num_labels" in config.keys():
+                    config["label2id"] = {"LABEL_" + str(num): num for num in range(config["num_labels"])}
+                if "num_choices" in config.keys():
+                    config["label2id"] = {"LABEL_" + str(num): num for num in range(config["num_choices"])}
+
            logger.info(f"Adding head '{head_name}' with config {config}.")
            self._add_prediction_head_module(head_name)
            self.active_head = head_name
@@ -810,3 +825,38 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None
raise ValueError("Unknown head_type '{}'".format(head["head_type"]))

return outputs # (loss), logits, (hidden_states), (attentions)

def get_labels_dict(self, head_name=None):
"""
Returns the id2label dict for the given head
Args:
head_name: (str, optional) the name of the head which labels should be returned. Default is None.
If the name is None the labels of the active head are returned
Returns: id2label
"""
if head_name is None:
head_name = self.active_head
if head_name is None:
raise ValueError("No head name given and no active head in the model")
if "label2id" in self.config.prediction_heads[head_name].keys():
return {id_: label for label, id_ in self.config.prediction_heads[head_name]["label2id"].items()}
else:
return None

def get_labels(self, head_name=None):
"""
Returns the labels the given head is assigning/predicting
Args:
head_name: (str, optional) the name of the head which labels should be returned. Default is None.
If the name is None the labels of the active head are returned
Returns: labels
"""
label_dict = self.get_labels_dict(head_name)
if label_dict is None:
return None
else:
return list(label_dict.values())
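Taken together, these adapter_bert.py changes let a flex head carry human-readable labels end to end. A minimal usage sketch, assuming an adapter-transformers build that includes this commit; the checkpoint name and tag set below are illustrative, not taken from the diff:

from transformers import AutoModelWithHeads

# Hypothetical example; "bert-base-uncased" and the POS tags are illustrative.
model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
id2label = {0: "NOUN", 1: "VERB", 2: "ADJ"}

# New keyword argument introduced by this commit:
model.add_tagging_head("pos_head", num_labels=3, id2label=id2label)

print(model.get_labels("pos_head"))       # ["NOUN", "VERB", "ADJ"]
print(model.get_labels_dict("pos_head"))  # {0: "NOUN", 1: "VERB", 2: "ADJ"}

# Without id2label, add_prediction_head falls back to generated defaults:
model.add_classification_head("cls_head", num_labels=2)
print(model.get_labels("cls_head"))       # ["LABEL_0", "LABEL_1"]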
4 changes: 3 additions & 1 deletion src/transformers/adapter_config.py
@@ -301,9 +301,11 @@ def to_dict(self):
        return output_dict


-def build_full_config(adapter_config, model_config, **kwargs):
+def build_full_config(adapter_config, model_config, save_id2label=False, **kwargs):
    config_dict = {"model_type": model_config.model_type, "hidden_size": model_config.hidden_size}
    config_dict.update(kwargs)
+    if not hasattr(model_config, "prediction_heads") and save_id2label:
+        config_dict["label2id"] = model_config.label2id
    if is_dataclass(adapter_config):
        config_dict["config"] = adapter_config.to_dict()
    else:
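For models that lack a prediction_heads attribute (plain BertForSequenceClassification-style models), build_full_config now copies model_config.label2id into the saved head config when save_id2label is set. A hypothetical sketch of the resulting config dict; the field values are illustrative, not taken from the diff:

# Hypothetical shape of the head config that save() now writes for a plain
# classification model; label names and most values here are illustrative.
head_config = {
    "model_type": "bert",
    "hidden_size": 768,
    "name": None,                                # passed through **kwargs by save()
    "model_name": "bert-base-uncased",
    "model_class": "BertForSequenceClassification",
    "label2id": {"negative": 0, "positive": 1},  # persisted when save_id2label=True
    "config": None,                              # no flex-head config for plain models
}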
16 changes: 14 additions & 2 deletions src/transformers/adapter_model_mixin.py
@@ -558,6 +558,7 @@ def save(self, save_directory: str, name: str = None):
            name=name,
            model_name=self.model.model_name,
            model_class=self.model.__class__.__name__,
+            save_id2label=True,
        )
        self.weights_helper.save_weights_config(save_directory, config_dict)

@@ -588,6 +589,9 @@ def load(self, save_directory, load_as=None, loading_info=None):
        # Load head config if available - otherwise just blindly try to load the weights
        if isfile(join(save_directory, HEAD_CONFIG_NAME)):
            config = self.weights_helper.load_weights_config(save_directory)
+            if (not config["config"] is None) and "label2id" in config["config"].keys():
+                config["config"]["label2id"] = {label: id_ for label, id_ in config["config"]["label2id"].items()}
+                config["config"]["id2label"] = {id_: label for label, id_ in config["config"]["label2id"].items()}
            # make sure that the model class of the loaded head matches the current class
            if self.model.__class__.__name__ != config["model_class"]:
                if self.error_on_missing:
@@ -603,7 +607,10 @@
                if head_name in self.model.config.prediction_heads:
                    logger.warning("Overwriting existing head '{}'".format(head_name))
                self.model.add_prediction_head(head_name, config["config"], overwrite_ok=True)
-
+            else:
+                if "label2id" in config.keys():
+                    self.model.config.id2label = {int(id_): label for label, id_ in config["label2id"].items()}
+                    self.model.config.label2id = {label: int(id_) for label, id_ in config["label2id"].items()}
        # Load head weights
        filter_func = self.filter_func(head_name)
        if load_as:
@@ -623,7 +630,6 @@ class ModelAdaptersMixin(ABC):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.model_name = None
-
        self._active_adapter_names = None

    # These methods have to be implemented by every deriving class:
@@ -1024,3 +1030,9 @@ def save_all_adapters(
        super().save_all_adapters(
            save_directory, meta_dict=meta_dict, custom_weights_loaders=custom_weights_loaders,
        )
+
+    def get_labels(self):
+        return list(self.config.id2label.values())
+
+    def get_labels_dict(self):
+        return self.config.id2label
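A note on the int(id_) casts in the load path above: head configs are persisted as JSON, and JSON object keys are always strings, so a dict with integer keys like id2label does not survive a round trip unchanged. Presumably that is why the commit stores label2id (string label keys) and rebuilds both mappings with explicit int casts on load. A standard-library sketch of the effect:

import json

# label2id has string keys and int values, so it survives a JSON round trip:
label2id = {"NOUN": 0, "VERB": 1}
print(json.loads(json.dumps(label2id)))  # {'NOUN': 0, 'VERB': 1}

# id2label has int keys, which JSON silently turns into strings:
id2label = {0: "NOUN", 1: "VERB"}
restored = json.loads(json.dumps(id2label))
print(restored)  # {'0': 'NOUN', '1': 'VERB'}

# Hence the int(id_) cast when rebuilding the mappings after load:
fixed = {int(id_): label for id_, label in restored.items()}
print(fixed)  # {0: 'NOUN', 1: 'VERB'}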
104 changes: 104 additions & 0 deletions tests/test_save_id2label.py
@@ -0,0 +1,104 @@
import unittest
from tempfile import TemporaryDirectory
from typing import Dict

from transformers import AutoConfig, AutoModelForTokenClassification, AutoModelWithHeads, BertForSequenceClassification


def get_default(num_label):
    labels = ["LABEL_" + str(i) for i in range(num_label)]
    label_dict = {id_: label for id_, label in enumerate(labels)}
    return labels, label_dict


class TestSaveLabel(unittest.TestCase):
    def setUp(self):
        self.labels = [
            "ADJ",
            "ADP",
            "ADV",
            "AUX",
            "CCONJ",
            "DET",
            "INTJ",
            "NOUN",
            "NUM",
            "PART",
            "PRON",
            "PROPN",
            "PUNCT",
            "SCONJ",
            "SYM",
            "VERB",
            "X",
        ]
        self.label_map: Dict[int, str] = {i: label for i, label in enumerate(self.labels)}
        self.model_name = "bert-base-uncased"
        self.config = AutoConfig.from_pretrained(
            self.model_name,
            num_labels=len(self.labels),
            id2label=self.label_map,
            label2id={label: i for i, label in enumerate(self.labels)},
        )

    def test_classification_model_head_labels(self):
        model = AutoModelForTokenClassification.from_pretrained(self.model_name, config=self.config)
        with TemporaryDirectory() as temp_dir:
            model.save_head(temp_dir)
            model.load_head(temp_dir)

        self.assertEqual(self.labels, model.get_labels())
        self.assertDictEqual(self.label_map, model.get_labels_dict())

    def test_sequ_classification_model_head_labels(self):
        model = BertForSequenceClassification.from_pretrained(self.model_name, config=self.config)
        with TemporaryDirectory() as temp_dir:
            model.save_head(temp_dir)
            model.load_head(temp_dir)

        self.assertEqual(self.labels, model.get_labels())
        self.assertDictEqual(self.label_map, model.get_labels_dict())

    def test_model_with_heads_tagging_head_labels(self):
        model = AutoModelWithHeads.from_pretrained(self.model_name, config=self.config)
        model.add_tagging_head("test_head", num_labels=len(self.labels), id2label=self.label_map)
        with TemporaryDirectory() as temp_dir:
            model.save_head(temp_dir, "test_head")
            model.load_head(temp_dir)
        # this is just loaded to test whether loading an adapter changes the label information
        model.load_adapter("sst-2", "text_task")

        self.assertEqual(self.labels, model.get_labels())
        self.assertDictEqual(self.label_map, model.get_labels_dict())

    def test_multiple_heads_label(self):
        model = AutoModelWithHeads.from_pretrained(self.model_name, config=self.config)
        model.add_tagging_head("test_head", num_labels=len(self.labels), id2label=self.label_map)
        with TemporaryDirectory() as temp_dir:
            model.save_head(temp_dir, "test_head")
            model.load_head(temp_dir)
        # adapter loaded for testing whether it changes label information
        model.load_adapter("sst-2", "text_task")
        model.add_classification_head("classification_head")
        default_label, default_label_dict = get_default(2)

        self.assertEqual(model.get_labels("classification_head"), default_label)
        self.assertEqual(model.get_labels_dict("classification_head"), default_label_dict)

    def test_model_with_heads_multiple_heads(self):
        model = AutoModelWithHeads.from_pretrained(self.model_name, config=self.config)
        model.add_tagging_head("test_head", num_labels=len(self.labels), id2label=self.label_map)
        model.add_classification_head("second_head", num_labels=5)
        with TemporaryDirectory() as temp_dir:
            model.save_head(temp_dir + "/test_head", "test_head")
            model.load_head(temp_dir + "/test_head")
            model.save_head(temp_dir + "/second_head", "second_head")
            model.load_head(temp_dir + "/second_head")
        model.load_adapter("sst-2", "text_task")

        self.assertEqual(model.get_labels("test_head"), self.labels)
        self.assertEqual(model.get_labels_dict("test_head"), self.label_map)


if __name__ == "__main__":
    unittest.main()
