Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Transformer Head Pruner #3884

Merged
merged 67 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
1a1172c
local code sync
xiaowu0162 Jun 29, 2021
5f27c35
graph-based weight grouping
xiaowu0162 Jun 30, 2021
d960426
fix for pipeline
xiaowu0162 Jun 30, 2021
faedb0f
pipeline related
xiaowu0162 Jun 30, 2021
c62b9a1
add activation-based maskers; refactor example
xiaowu0162 Jul 1, 2021
6877b64
minor fix
xiaowu0162 Jul 1, 2021
595864e
change graph-based grouping logic
xiaowu0162 Jul 2, 2021
bd7ff9f
remove redundant code
xiaowu0162 Jul 2, 2021
b28725f
Add taylor masker
xiaowu0162 Jul 6, 2021
80bdf06
debug
xiaowu0162 Jul 6, 2021
d5582dd
debug
xiaowu0162 Jul 6, 2021
0715a70
Add global sorting
xiaowu0162 Jul 6, 2021
d1e5d8d
debug
xiaowu0162 Jul 6, 2021
aece26a
debug
xiaowu0162 Jul 6, 2021
9cf94fe
Add iterative pruning
xiaowu0162 Jul 6, 2021
7c73fc8
debug
xiaowu0162 Jul 6, 2021
79186e2
Simplify API; add doc strings
xiaowu0162 Jul 7, 2021
690969a
debug
xiaowu0162 Jul 7, 2021
9d34493
docstring
xiaowu0162 Jul 7, 2021
1e2329e
example v1
xiaowu0162 Jul 7, 2021
7121051
Merge branch 'microsoft:master' into bertpruner
xiaowu0162 Jul 8, 2021
94e4804
doc skeleton
xiaowu0162 Jul 9, 2021
a5a92d9
doc update
xiaowu0162 Jul 9, 2021
5399dd8
doc update
xiaowu0162 Jul 9, 2021
ec5cdf2
doc update
xiaowu0162 Jul 9, 2021
f747a9a
doc update
xiaowu0162 Jul 9, 2021
f70343c
update
xiaowu0162 Jul 9, 2021
e6b6d84
doc debug
xiaowu0162 Jul 11, 2021
5aa63d2
update examples
xiaowu0162 Jul 11, 2021
b0b01fc
debug
xiaowu0162 Jul 11, 2021
af51872
debug
xiaowu0162 Jul 11, 2021
521b4c8
debug
xiaowu0162 Jul 11, 2021
42b3d5a
debug
xiaowu0162 Jul 11, 2021
59b8adb
fix ungrouped module removing logic
xiaowu0162 Jul 16, 2021
af2144a
Update shape dependency to align with master
xiaowu0162 Jul 16, 2021
d8a11c2
Merge branch 'master' into bertpruner
xiaowu0162 Jul 16, 2021
3e446ed
doc string debug
xiaowu0162 Jul 16, 2021
8fa6263
resolve comments
xiaowu0162 Jul 16, 2021
f45865f
update docs
xiaowu0162 Jul 16, 2021
ec35267
redo example
xiaowu0162 Jul 18, 2021
b46bcee
docstring
xiaowu0162 Jul 18, 2021
8f83131
debug
xiaowu0162 Jul 18, 2021
941a301
doc
xiaowu0162 Jul 18, 2021
be5f38b
debug
xiaowu0162 Jul 18, 2021
f193106
debug
xiaowu0162 Jul 19, 2021
b9837e8
unit test
xiaowu0162 Jul 19, 2021
53ab047
ut debug
xiaowu0162 Jul 19, 2021
416f680
replace torch.linalg.norm with torch.norm
xiaowu0162 Jul 19, 2021
4d91a1e
update ut
xiaowu0162 Jul 20, 2021
c0efcd5
improve docs
xiaowu0162 Jul 26, 2021
bb8f437
debug
xiaowu0162 Jul 26, 2021
61679f1
example fix
xiaowu0162 Jul 26, 2021
c0b93ed
handle empty groups caused by config
xiaowu0162 Jul 26, 2021
856ee2a
sanity check
xiaowu0162 Jul 26, 2021
32bf76a
head indexing
xiaowu0162 Jul 27, 2021
ed11f42
default args
xiaowu0162 Jul 27, 2021
50d9cb9
improve docs
xiaowu0162 Jul 27, 2021
a6727d5
debug
xiaowu0162 Jul 27, 2021
9411998
debug
xiaowu0162 Jul 27, 2021
ae31c18
epoch parameter to trainer
xiaowu0162 Jul 27, 2021
266ec1c
update example
xiaowu0162 Jul 27, 2021
83cfd4a
Merge branch 'microsoft:master' into bertpruner
xiaowu0162 Jul 27, 2021
a512ef3
forward_runner API v1
xiaowu0162 Jul 27, 2021
5e5b962
Merge branch 'bertpruner' of https://github.com/xiaowu0162/nni into b…
xiaowu0162 Jul 27, 2021
00ad0ef
delete some usages
xiaowu0162 Jul 27, 2021
a62f9de
update docs
xiaowu0162 Jul 27, 2021
80383af
update ut
xiaowu0162 Jul 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/en_US/Compression/CompressionReference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ Pruners
.. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner
:members:

.. autoclass:: nni.algorithms.compression.pytorch.pruning.transformer_pruner.TransformerHeadPruner
:members:

Quantizers
^^^^^^^^^^
Expand Down
4 changes: 3 additions & 1 deletion docs/en_US/Compression/Overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms.
Pruning Algorithms
^^^^^^^^^^^^^^^^^^

Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue.
Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue.

.. list-table::
:header-rows: 1
Expand Down Expand Up @@ -73,6 +73,8 @@ Pruning algorithms compress the original network by removing redundant weights o
- Automatic pruning by iteratively call SimulatedAnnealing Pruner and ADMM Pruner `Reference Paper <https://arxiv.org/abs/1907.03141>`__
* - `AMC Pruner <../Compression/Pruner.rst#amc-pruner>`__
- AMC: AutoML for Model Compression and Acceleration on Mobile Devices `Reference Paper <https://arxiv.org/pdf/1802.03494.pdf>`__
* - `Transformer Head Pruner <../Compression/Pruner.rst#transformer-head-pruner>`__
- Pruning attention heads from transformer models either in one shot or iteratively.


You can refer to this `benchmark <../CommunitySharings/ModelCompressionComparison.rst>`__ for the performance of these pruners on some benchmark problems.
Expand Down
126 changes: 126 additions & 0 deletions docs/en_US/Compression/Pruner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
**Others**

* `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__
* `Transformer Head Pruner <#transformer-head-pruner>`__

Level Pruner
------------
Expand Down Expand Up @@ -722,3 +723,128 @@ User configuration for Sensitivity Pruner
**PyTorch**

.. autoclass:: nni.algorithms.compression.pytorch.pruning.SensitivityPruner

Transformer Head Pruner
-----------------------

Transformer Head Pruner is a tool designed for pruning attention heads from the models belonging to the `Transformer family <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__. The following image from `Efficient Transformers: A Survey <https://arxiv.org/pdf/2009.06732.pdf>`__ gives a good overview the general structure of the Transformer.

.. image:: ../../img/transformer_structure.png
:target: ../../img/transformer_structure.png
:alt:

Typically, each attention layer in the Transformer models consists of four weights: three projection matrices for query, key, value, and an output projection matrix. The outputs of the former three matrices contains the projected results for all heads. Normally, the results are then reshaped so that each head performs that attention computation independently. The final results are concatenated back before fed into the output projection. Therefore, when an attention head is pruned, the same weights corresponding to that heads in the three projection matrices are pruned. Also, the weights in the output projection corresponding to the head's output are pruned. In our implementation, we calculate and apply masks to the four matrices together.

The pruner implements the following algorithm:

.. code-block:: bash

Repeat for each pruning iteration (1 for one-shot pruning):
1. Calculate importance scores for each head in each specified layer using a specific criterion
2. Sort heads locally or globally, and prune out some heads with lowest scores. The number of pruned heads is determined according to the sparsity specified in the config.
3. If the specified pruning iteration is larger than 1 (iterative pruning), finetune the model for a while before the next pruning iteration.

Currently, the following head sorting criteria are supported:

* "l1_weight": rank heads by the L1-norm of weights of the query, key, and value projection matrices.
* "l2_weight": rank heads by the L2-norm of weights of the query, key, and value projection matrices.
* "l1_activation": rank heads by the L1-norm of their attention computation output.
* "l2_activation": rank heads by the L2-norm of their attention computation output.
* "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper <https://arxiv.org/abs/1905.10650>`__ and `this one <https://arxiv.org/abs/1611.06440>`__.

We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), and you can control by setting the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be prune to the same sparsity as specified. This sparsity value will be interpreted as a global sparsity, and each layer is likely to have different sparsity after pruning by global sort.

In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules (usage 1 below) to the pruner, or simply pass a dummy input and the pruner will run ``torch.jit.trace`` to group the weights (usage 2 below).

However, if you would like to assign different sparsity to each layer, currently you could only use the first option, i.e., passing names of the weights to the pruner (usage 3 below). Also note that weights belong to the same layer must have the same sparsity.

In addition to the following usage guide, we provide a more detailed example of pruning BERT for tasks from the GLUE benchmark. Please find it in this :githublink:`page <examples/model_compress/pruning/transformers>`.

Usage
^^^^^

Usage 1: one-shot pruning, same sparsity for all the layers (PyTorch code)

.. code-block:: python

from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
kwargs = {'ranking_criterion': "l1_weight",
'global_sort': False,
'num_iterations': 1,
'epochs_per_iteration': 1, # this is ignored when num_iterations = 1
'head_hidden_dim': 64,
'dummy_input': dummy_input,
'trainer': trainer,
'optimizer': optimizer
}
config_list = [{
'sparsity': 0.5,
'op_types': ["Linear"]
xiaowu0162 marked this conversation as resolved.
Show resolved Hide resolved
}]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()

Usage 2: same effect as usage 1, the only change is passing names to the pruner instead of dummy input (PyTorch code)

.. code-block:: python

from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
attention_name_groups = list(zip(['encoder.layer.{}.attention.self.query'.format(i) for i in range(12)],
['encoder.layer.{}.attention.self.key'.format(i) for i in range(12)],
['encoder.layer.{}.attention.self.value'.format(i) for i in range(12)],
['encoder.layer.{}.attention.output.dense'.format(i) for i in range(12)]))
kwargs = {'ranking_criterion': "l1_weight",
'global_sort': False,
'num_iterations': 1,
'epochs_per_iteration': 1, # this is ignored when num_iterations = 1
'head_hidden_dim': 64,
'attention_name_groups': attention_name_groups,
'trainer': trainer,
'optimizer': optimizer
}
config_list = [{
'sparsity': 0.5,
'op_types': ["Linear"]
}]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()

Usage 3: one-shot pruning, setting different sparsity for different layers (PyTorch code)

.. code-block:: python

from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
attention_name_groups = list(zip(['encoder.layer.{}.attention.self.query'.format(i) for i in range(12)],
['encoder.layer.{}.attention.self.key'.format(i) for i in range(12)],
['encoder.layer.{}.attention.self.value'.format(i) for i in range(12)],
['encoder.layer.{}.attention.output.dense'.format(i) for i in range(12)]))
kwargs = {'ranking_criterion': "l1_weight",
'global_sort': False,
'num_iterations': 1,
'epochs_per_iteration': 1, # this is ignored when num_iterations = 1
'head_hidden_dim': 64,
'attention_name_groups': attention_name_groups, # can change to dummy_input here
'trainer': trainer,
'optimizer': optimizer
}
config_list = [{
'sparsity': 0.5,
'op_types': ["Linear"],
'op_names': [x for layer in attention_name_groups[:6] for x in layer] # first six layers
},
{
'sparsity': 0.25,
'op_types': ["Linear"],
'op_names': [x for layer in attention_name_groups[:6] for x in layer] # last six layers
xiaowu0162 marked this conversation as resolved.
Show resolved Hide resolved
}
]
pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()


User configuration for Transformer Head Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**PyTorch**

.. autoclass:: nni.algorithms.compression.pytorch.pruning.TransformerHeadPruner
Binary file added docs/img/transformer_structure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions examples/model_compress/pruning/transformers/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/bin/bash

# Usage: ./run.sh gpu_id glue_task

export CUDA_VISIBLE_DEVICES=$1
TASK_NAME=$2 # "cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli", "rte", "wnli"
PRETRAINED_MODEL="bert-base-uncased" # "distilbert-base-uncased", "roberta-base", "bert-base-cased", ...

# parameters for pruning
# change USAGE to different numbers (1, 2, 3) to run examples with different configs
USAGE=2
SPARSITY=0.5
RANKING_CRITERION=l1_weight # "l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"
NUM_ITERATIONS=1 # 1 for one-shot pruning
EPOCHS_PER_ITERATION=1

# other training parameters, no need to change
MAX_LENGTH=128
BATCH_SIZE=32
LR=2e-5
N_EPOCHS=3

time=$(date "+%Y%m%d%H%M%S")
OUTDIR="models_${PRETRAINED_MODEL}_${TASK_NAME}_$time/"

TASK_LIST=("cola" "sst2" "mrpc" "stsb" "qqp" "mnli" "qnli" "rte" "wnli")
if [[ ${TASK_LIST[*]} =~ (^|[[:space:]])$TASK_NAME($|[[:space:]]) ]]; then
mkdir $OUTDIR
python transformer_pruning.py \
--sparsity $SPARSITY \
--ranking_criterion $RANKING_CRITERION \
--num_iterations $NUM_ITERATIONS \
--epochs_per_iteration $EPOCHS_PER_ITERATION \
--speed_up \
--model_name $PRETRAINED_MODEL \
--task_name $TASK_NAME \
--max_length $MAX_LENGTH \
--batch_size $BATCH_SIZE \
--learning_rate $LR \
--num_train_epochs $N_EPOCHS \
--output_dir $OUTDIR \
2>&1 | tee "$OUTDIR/output.log"
else
echo "Unsupported task $TASK_NAME."
fi
Loading