Add adapter_summary() method #371

Merged · 4 commits · Jun 27, 2022
83 changes: 81 additions & 2 deletions src/transformers/adapters/model_mixin.py
@@ -345,6 +345,9 @@ def delete_adapter(self, adapter_name: str):
            return
        del self.config.adapters.adapters[adapter_name]
        self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name))
        # PHM Layer
        if adapter_name in self.shared_parameters:
            del self.shared_parameters[adapter_name]
        if isinstance(self, InvertibleAdaptersMixin):
            self.delete_invertible_adapter(adapter_name)
        # Reset active adapters if this was the only active adapter
@@ -754,20 +757,96 @@ def get_adapter(self, name) -> dict:

        Returns:
            dict: A nested dictionary containing the weights of the adapter. The dictionary is structured as follows:
            {<layer id>: {<module location>: <nn.Module>}}.
            {<layer id>: {<module location>: <nn.Module>}}. <layer id> = -1 indicates global/shared weights.
        """
        destination = defaultdict(dict)

        # global weights are saved at index -1
        if name in self.shared_parameters:
            destination[-1]["shared"] = self.shared_parameters[name]
        if isinstance(self, InvertibleAdaptersMixin) and name in self.invertible_adapters:
            destination[-1]["invertible"] = self.invertible_adapters[name]

        # use a custom index to ensure numbering is from 0 to N layers
        for i, (_, layer) in enumerate(self.iter_layers()):
            for module in layer.modules():
                if isinstance(module, AdapterLayerBase):
                    adapter_module = module.get_adapter(name)
                    if adapter_module is not None:
                        destination[i][module.location_key] = adapter_module
                        # location_key might already be added before -> concat to ModuleList
                        if module.location_key in destination[i]:
                            old_module = destination[i][module.location_key]
                            if isinstance(old_module, nn.ModuleList):
                                old_module.append(adapter_module)
                            else:
                                destination[i][module.location_key] = nn.ModuleList([old_module, adapter_module])
                        else:
                            destination[i][module.location_key] = adapter_module

        return dict(destination)

    def adapter_summary(self, as_dict=False) -> Union[str, dict]:
        """
        Returns a string summary of all adapters currently added to the model. Each entry in the summary table has the
        following attributes:

            - name: the name of the adapter
            - architecture: the architectural base of the adapter
            - #param: the number of parameters of the adapter
            - %param: the number of parameters of the adapter relative to the full model
            - active: whether the adapter is active
            - train: whether the adapter weights are enabled for training
        """
        # table header
        header = ["name", "architecture", "#param", "%param", "active", "train"]
        # rows containing adapter info
        rows = []
        # fill in data for adapters
        for name, config_name in self.config.adapters.adapters.items():
            config = self.config.adapters.config_map[config_name]
            row = {"name": name, "architecture": config.architecture or "bottleneck"}
            weights = self.get_adapter(name)
            row["active"] = self.active_adapters is not None and name in self.active_adapters.flatten()
            # count parameters
            no_params = 0
            train = True
            for _, module_dict in weights.items():
                for _, module in module_dict.items():
                    no_params += sum(p.numel() for p in module.parameters())
                    train &= all(p.requires_grad for p in module.parameters())
            row["#param"] = no_params
            row["train"] = train
            rows.append(row)
        # count no. of parameters in base network
        model_no_params = sum(p.numel() for p in self.base_model.parameters())
        model_no_params -= sum([r["#param"] for r in rows])
        # add %param info
        for row in rows:
            row["%param"] = row["#param"] / model_no_params * 100
        # add full model info
        rows.append(
            {
                "name": "Full model",
                "#param": model_no_params,
                "%param": 100.0,
                "train": not getattr(self.base_model, "model_frozen", False),
            }
        )

        if as_dict:
            return rows
        else:
            # print
            total_length = 80
            header_format = "{:<25}{:<15}{:>12}{:>12}{:>8}{:>8}"
            row_format = "{:<25}{:<15}{:>12}{:>12.3f}{:>8}{:>8}"
            s = [header_format.format(*map(lambda x: x.title(), header))]
            s.append("-" * total_length)
            for row in rows:
                s.append(row_format.format(*[row.get(h, "") for h in header]))
            s.insert(len(s) - 1, "-" * total_length)
            return "\n".join(s)

    def eject_prefix_tuning(self, name: str):
        """
        Converts the prefix tuning with the given name from the reparameterized form into the flat form.
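For reference, a minimal usage sketch of the adapter_summary() method added above (assuming the adapter-transformers AutoAdapterModel class; the adapter name is illustrative and not part of this diff):

from transformers import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("my_adapter")  # default bottleneck config

# formatted table: one row per adapter plus a final "Full model" row
print(model.adapter_summary())

# machine-readable variant: a list of dicts with the keys
# name, architecture, #param, %param, active, train
rows = model.adapter_summary(as_dict=True)
assert rows[-1]["name"] == "Full model"

Note that %param is computed relative to the base model's parameter count: the adapter parameters are subtracted from the model total before the ratio is taken.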
4 changes: 3 additions & 1 deletion src/transformers/adapters/prefix_tuning.py
@@ -278,7 +278,9 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt
    def get_adapter(self, adapter_name):
        # Make sure to only return params once
        if adapter_name in self.prefixes and self.prefixes[adapter_name] == 0:
            return self.pool.get_prefix(adapter_name)
            prefix_module = self.pool.get_prefix(adapter_name)
            if prefix_module is not None:
                return prefix_module[self.location_key]

        return None

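A small sketch of inspecting the nested dictionary returned by get_adapter(), as documented in the updated docstring in model_mixin.py (adapter name and location keys are illustrative):

weights = model.get_adapter("my_adapter")

# per-layer modules, keyed by a 0..N layer index and then by location key
# (for bottleneck adapters, e.g. "mh_adapter" / "output_adapter")
first_layer = weights.get(0, {})

# global/shared weights (e.g. PHM shared parameters or invertible adapters)
# live under the special layer id -1
shared = weights.get(-1, {})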
16 changes: 16 additions & 0 deletions tests_adapters/methods/test_adapter_common.py
@@ -182,6 +182,22 @@ def test_model_config_serialization(self):
        # should not raise an exception
        model.config.to_json_string()

    def test_model_adapter_summary(self):
        # count model parameters before
        model = self.get_model()
        model_no_params = sum(p.numel() for p in model.parameters())
        for k, v in ADAPTER_CONFIG_MAP.items():
            # HACK: reduce the reduction factor such that
            # the small test model can have a phm_dim of 4
            if hasattr(v, "phm_layer") and v.phm_layer:
                v = v.__class__(reduction_factor=4)
            model.add_adapter(k, config=v)
        summary = model.adapter_summary(as_dict=True)
        self.assertEqual(len(ADAPTER_CONFIG_MAP) + 1, len(summary))
        for name in ADAPTER_CONFIG_MAP.keys():
            self.assertTrue(any([row["name"] == name for row in summary]))
        self.assertEqual(model_no_params, summary[-1]["#param"])

    def test_loading_adapter_weights_with_prefix(self):
        if self.config_class not in ADAPTER_MODEL_MAPPING:
            self.skipTest("Does not support flex heads.")
2 changes: 1 addition & 1 deletion tests_adapters/test_adapter_composition.py
@@ -311,7 +311,7 @@ def test_parallel_training(self):

        train_dataset = self.dataset(tokenizer)
        training_args = TrainingArguments(
            output_dir="./examples", do_train=True, learning_rate=0.1, max_steps=15, no_cuda=True
            output_dir="./examples", do_train=True, learning_rate=0.5, max_steps=20, no_cuda=True
        )

        # evaluate