diff --git a/docs/en_US/Compression/CompressionReference.rst b/docs/en_US/Compression/CompressionReference.rst index 5903500115..b005ce725f 100644 --- a/docs/en_US/Compression/CompressionReference.rst +++ b/docs/en_US/Compression/CompressionReference.rst @@ -91,6 +91,8 @@ Pruners .. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner :members: +.. autoclass:: nni.algorithms.compression.pytorch.pruning.transformer_pruner.TransformerHeadPruner + :members: Quantizers ^^^^^^^^^^ diff --git a/docs/en_US/Compression/Overview.rst b/docs/en_US/Compression/Overview.rst index 788aa0ac84..723649af58 100644 --- a/docs/en_US/Compression/Overview.rst +++ b/docs/en_US/Compression/Overview.rst @@ -35,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms. Pruning Algorithms ^^^^^^^^^^^^^^^^^^ -Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-ļ¬tting issue. +Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue. .. list-table:: :header-rows: 1 @@ -73,6 +73,8 @@ Pruning algorithms compress the original network by removing redundant weights o - Automatic pruning by iteratively call SimulatedAnnealing Pruner and ADMM Pruner `Reference Paper `__ * - `AMC Pruner <../Compression/Pruner.rst#amc-pruner>`__ - AMC: AutoML for Model Compression and Acceleration on Mobile Devices `Reference Paper `__ + * - `Transformer Head Pruner <../Compression/Pruner.rst#transformer-head-pruner>`__ + - Pruning attention heads from transformer models either in one shot or iteratively. You can refer to this `benchmark <../CommunitySharings/ModelCompressionComparison.rst>`__ for the performance of these pruners on some benchmark problems. diff --git a/docs/en_US/Compression/Pruner.rst b/docs/en_US/Compression/Pruner.rst index d833873a1f..4fb0aa8674 100644 --- a/docs/en_US/Compression/Pruner.rst +++ b/docs/en_US/Compression/Pruner.rst @@ -28,6 +28,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a **Others** * `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__ +* `Transformer Head Pruner <#transformer-head-pruner>`__ Level Pruner ------------ @@ -724,3 +725,95 @@ User configuration for Sensitivity Pruner **PyTorch** .. autoclass:: nni.algorithms.compression.pytorch.pruning.SensitivityPruner + +Transformer Head Pruner +----------------------- + +Transformer Head Pruner is a tool designed for pruning attention heads from the models belonging to the `Transformer family `__. The following image from `Efficient Transformers: A Survey `__ gives a good overview the general structure of the Transformer. + +.. image:: ../../img/transformer_structure.png + :target: ../../img/transformer_structure.png + :alt: + +Typically, each attention layer in the Transformer models consists of four weights: three projection matrices for query, key, value, and an output projection matrix. The outputs of the former three matrices contains the projected results for all heads. Normally, the results are then reshaped so that each head performs that attention computation independently. The final results are concatenated back before fed into the output projection. Therefore, when an attention head is pruned, the same weights corresponding to that heads in the three projection matrices are pruned. Also, the weights in the output projection corresponding to the head's output are pruned. In our implementation, we calculate and apply masks to the four matrices together. + +Note: currently, the pruner can only handle models with projection weights written as separate ``Linear`` modules, i.e., it expects four ``Linear`` modules corresponding to query, key, value, and an output projections. Therefore, in the ``config_list``, you should either write ``['Linear']`` for the ``op_types`` field, or write names corresponding to ``Linear`` modules for the ``op_names`` field. + +The pruner implements the following algorithm: + +.. code-block:: bash + + Repeat for each pruning iteration (1 for one-shot pruning): + 1. Calculate importance scores for each head in each specified layer using a specific criterion. + 2. Sort heads locally or globally, and prune out some heads with lowest scores. The number of pruned heads is determined according to the sparsity specified in the config. + 3. If the specified pruning iteration is larger than 1 (iterative pruning), finetune the model for a while before the next pruning iteration. + +Currently, the following head sorting criteria are supported: + + * "l1_weight": rank heads by the L1-norm of weights of the query, key, and value projection matrices. + * "l2_weight": rank heads by the L2-norm of weights of the query, key, and value projection matrices. + * "l1_activation": rank heads by the L1-norm of their attention computation output. + * "l2_activation": rank heads by the L2-norm of their attention computation output. + * "taylorfo": rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in `this paper `__ and `this one `__. + +We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), and you can control by setting the ``global_sort`` parameter. Note that if ``global_sort=True`` is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be prune to the same sparsity as specified. This sparsity value will be interpreted as a global sparsity, and each layer is likely to have different sparsity after pruning by global sort. + +In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules as the pruner's initialization parameters (usage below), or simply pass a dummy input and the pruner will run ``torch.jit.trace`` to group the weights (experimental feature). However, if you would like to assign different sparsity to each layer, you can only use the first option, i.e., passing names of the weights to the pruner (see usage below). Also note that weights belonging to the same layer must have the same sparsity. + +In addition to the following usage guide, we provide a more detailed example of pruning BERT for tasks from the GLUE benchmark. Please find it in this :githublink:`page `. + +Usage +^^^^^ + +Suppose we want to prune a BERT with Huggingface implementation, which has the following architecture (obtained by calling ``print(model)``). Note that we only show the first layer of the repeated layers in the encoder's ``ModuleList layer``. + +.. image:: ../../img/huggingface_bert_architecture.png + :target: ../../img/huggingface_bert_architecture.png + :alt: + +**Usage Example: one-shot pruning, assigning sparsity 0.5 to the first six layers and sparsity 0.25 to the last six layers (PyTorch code)**. Note that + +* Here we specify ``op_names`` in the config list to assign different sparsity to different layers. +* Meanwhile, we pass ``attention_name_groups`` to the pruner so that the pruner may group together the weights belonging to the same attention layer. +* Since in this example we want to do one-shot pruning, the ``num_iterations`` parameter is set to 1, and the parameter ``epochs_per_iteration`` is ignored. If you would like to do iterative pruning instead, you can set the ``num_iterations`` parameter to the number of pruning iterations, and the ``epochs_per_iteration`` parameter to the number of finetuning epochs between two iterations. +* The arguments ``trainer`` and ``optimizer`` are only used when we want to do iterative pruning, or the ranking criterion is ``taylorfo``. Here these two parameters are ignored by the pruner. +* The argument ``forward_runner`` is only used when the ranking criterion is ``l1_activation`` or ``l2_activation``. Here this parameter is ignored by the pruner. + +.. code-block:: python + + from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner + attention_name_groups = list(zip(["encoder.layer.{}.attention.self.query".format(i) for i in range(12)], + ["encoder.layer.{}.attention.self.key".format(i) for i in range(12)], + ["encoder.layer.{}.attention.self.value".format(i) for i in range(12)], + ["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)])) + kwargs = {"ranking_criterion": "l1_weight", + "global_sort": False, + "num_iterations": 1, + "epochs_per_iteration": 1, # this is ignored when num_iterations = 1 + "head_hidden_dim": 64, + "attention_name_groups": attention_name_groups, + "trainer": trainer, + "optimizer": optimizer, + "forward_runner": forward_runner + } + config_list = [{ + "sparsity": 0.5, + "op_types": ["Linear"], + "op_names": [x for layer in attention_name_groups[:6] for x in layer] # first six layers + }, + { + "sparsity": 0.25, + "op_types": ["Linear"], + "op_names": [x for layer in attention_name_groups[6:] for x in layer] # last six layers + } + ] + pruner = TransformerHeadPruner(model, config_list, **kwargs) + pruner.compress() + + +User configuration for Transformer Head Pruner +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**PyTorch** + +.. autoclass:: nni.algorithms.compression.pytorch.pruning.TransformerHeadPruner diff --git a/docs/img/huggingface_bert_architecture.png b/docs/img/huggingface_bert_architecture.png new file mode 100644 index 0000000000..9187c79c2e Binary files /dev/null and b/docs/img/huggingface_bert_architecture.png differ diff --git a/docs/img/transformer_structure.png b/docs/img/transformer_structure.png new file mode 100644 index 0000000000..bd3fcc78b3 Binary files /dev/null and b/docs/img/transformer_structure.png differ diff --git a/examples/model_compress/pruning/transformers/run.sh b/examples/model_compress/pruning/transformers/run.sh new file mode 100755 index 0000000000..af00f1cf8e --- /dev/null +++ b/examples/model_compress/pruning/transformers/run.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# Usage: ./run.sh gpu_id glue_task + +export CUDA_VISIBLE_DEVICES=$1 +TASK_NAME=$2 # "cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli", "rte", "wnli" +PRETRAINED_MODEL="bert-base-uncased" # "distilbert-base-uncased", "roberta-base", "bert-base-cased", ... + +# parameters for pruning +SPARSITY=0.5 +RANKING_CRITERION=l1_weight # "l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo" +NUM_ITERATIONS=1 # 1 for one-shot pruning +EPOCHS_PER_ITERATION=1 + +# other training parameters, no need to change +MAX_LENGTH=128 +BATCH_SIZE=32 +LR=2e-5 +N_EPOCHS=3 + +time=$(date "+%Y%m%d%H%M%S") +OUTDIR="models_${PRETRAINED_MODEL}_${TASK_NAME}_$time/" + +TASK_LIST=("cola" "sst2" "mrpc" "stsb" "qqp" "mnli" "qnli" "rte" "wnli") +if [[ ${TASK_LIST[*]} =~ (^|[[:space:]])$TASK_NAME($|[[:space:]]) ]]; then + mkdir $OUTDIR + python transformer_pruning.py \ + --sparsity $SPARSITY \ + --ranking_criterion $RANKING_CRITERION \ + --num_iterations $NUM_ITERATIONS \ + --epochs_per_iteration $EPOCHS_PER_ITERATION \ + --speed_up \ + --model_name $PRETRAINED_MODEL \ + --task_name $TASK_NAME \ + --max_length $MAX_LENGTH \ + --batch_size $BATCH_SIZE \ + --learning_rate $LR \ + --num_train_epochs $N_EPOCHS \ + --output_dir $OUTDIR \ + 2>&1 | tee "$OUTDIR/output.log" +else + echo "Unsupported task $TASK_NAME." +fi diff --git a/examples/model_compress/pruning/transformers/transformer_pruning.py b/examples/model_compress/pruning/transformers/transformer_pruning.py new file mode 100644 index 0000000000..c98e0ec744 --- /dev/null +++ b/examples/model_compress/pruning/transformers/transformer_pruning.py @@ -0,0 +1,384 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +import logging +import math +import os +import random + +import torch +from torch.utils.data.dataloader import DataLoader +from tqdm.auto import tqdm + +import nni +from nni.compression.pytorch import ModelSpeedup +from nni.compression.pytorch.utils.counter import count_flops_params +from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner + + +import datasets +from datasets import load_dataset, load_metric +import transformers +from transformers import ( + AdamW, + AutoConfig, + AutoModel, + AutoModelForPreTraining, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + PretrainedConfig, + default_data_collator, + get_scheduler, +) + + +logger = logging.getLogger("bert_pruning_example") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Example: prune a Huggingface transformer and finetune on GLUE tasks.") + + parser.add_argument("--model_name", type=str, required=True, + help="Pretrained model architecture.") + parser.add_argument("--task_name", type=str, default=None, + help="The name of the GLUE task.", + choices=["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]) + parser.add_argument("--output_dir", type=str, default=None, + help="Where to store the model and mask.") + parser.add_argument("--sparsity", type=float, required=True, + help="Sparsity: proportion of heads to prune (should be between 0 and 1)") + parser.add_argument("--global_sort", action="store_true", default=False, + help="Rank the heads globally and prune the heads with lowest scores. If set to False, the " + "heads are only ranked within one layer") + parser.add_argument("--ranking_criterion", type=str, default="l1_weight", + choices=["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"], + help="Criterion by which the attention heads are ranked.") + parser.add_argument("--num_iterations", type=int, default=1, + help="Number of pruning iterations (1 for one-shot pruning).") + parser.add_argument("--epochs_per_iteration", type=int, default=1, + help="Epochs to finetune before the next pruning iteration " + "(only effective if num_iterations > 1).") + parser.add_argument("--speed_up", action="store_true", default=False, + help="Whether to speed-up the pruned model") + + # parameters for model training; no need to change them for running examples + parser.add_argument("--max_length", type=int, default=128, + help=("The maximum total input sequence length after tokenization. Sequences longer than this " + "will be truncated, sequences shorter will be padded if `--pad_to_max_lengh` is passed.")) + parser.add_argument("--batch_size", type=int, default=8, + help="Batch size.") + parser.add_argument("--learning_rate", type=float, default=5e-5, + help="Initial learning rate.") + parser.add_argument("--num_train_epochs", type=int, default=3, + help="Total number of training epochs to perform.") + parser.add_argument("--lr_scheduler_type", default="linear", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", + "constant_with_warmup"]) + parser.add_argument("--num_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler.") + + args = parser.parse_args() + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + return args + + +def get_raw_dataset(task_name): + """ + Get a GLUE dataset using huggingface datasets. + """ + raw_dataset = load_dataset("glue", task_name) + is_regression = task_name == "stsb" + num_labels = 1 if is_regression else len(raw_dataset["train"].features["label"].names) + + return raw_dataset, is_regression, num_labels + + +def preprocess(args, tokenizer, raw_dataset): + """ + Tokenization and column renaming. + """ + assert args.task_name is not None + + task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), + } + sentence1_key, sentence2_key = task_to_keys[args.task_name] + + def tokenize(data): + texts = ( + (data[sentence1_key],) if sentence2_key is None else (data[sentence1_key], data[sentence2_key]) + ) + result = tokenizer(*texts, padding=False, max_length=args.max_length, truncation=True) + + if "label" in data: + result["labels"] = data["label"] + return result + + processed_datasets = raw_dataset.map(tokenize, batched=True, remove_columns=raw_dataset["train"].column_names) + return processed_datasets + + +def get_dataloader_and_optimizer(args, tokenizer, model, train_dataset, eval_dataset): + data_collator = DataCollatorWithPadding(tokenizer) + train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, + batch_size=args.batch_size) + eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, + batch_size=args.batch_size) + + optimizer = AdamW(model.parameters(), lr=args.learning_rate) + + return optimizer, train_dataloader, eval_dataloader, data_collator + + +def train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device): + """ + Train the model using train_dataloader and evaluate after every epoch using eval_dataloader. + This function is called before and after pruning for "pretraining" on the GLUE task and further "finetuning". + """ + train_steps = args.num_train_epochs * len(train_dataloader) + progress_bar = tqdm(range(train_steps), position=0, leave=True) + + for epoch in range(args.num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + for field in batch.keys(): + batch[field] = batch[field].to(device) + outputs = model(**batch) + outputs.loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + + model.eval() + for step, batch in enumerate(eval_dataloader): + for field in batch.keys(): + batch[field] = batch[field].to(device) + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() + metric.add_batch(predictions=predictions, references=batch["labels"]) + + eval_metric = metric.compute() + logger.info(f"epoch {epoch}: {eval_metric}") + + +def trainer_helper(model, train_dataloader, optimizer, device): + """ + This function is used for to create a "trainer" that is passed to the pruner. + Finetune the model for 1 epoch. This function is called by the pruner during pruning iterations (or called to + calculate scores for pruning when ranking criterion is "taylorfo"). + """ + logger.info("Training for 1 epoch...") + progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True) + + train_epoch = 1 + for epoch in range(train_epoch): + for step, batch in enumerate(train_dataloader): + for field in batch.keys(): + batch[field] = batch[field].to(device) + outputs = model(**batch) + outputs.loss.backward() + optimizer.step() + optimizer.zero_grad() + progress_bar.update(1) + + +def forward_runner_helper(model, train_dataloader, device): + """ + This function is used for to create a "forward_runner" that is passed to the pruner. + The function just runs forward on the train set without updating the parameters. + This allows the pruner to collect data for activation-based pruning methods. + """ + logger.info("Running forward on the entire train set without updating parameters...") + progress_bar = tqdm(range(len(train_dataloader)), position=0, leave=True) + + forward_epoch = 1 + for epoch in range(forward_epoch): + for step, batch in enumerate(train_dataloader): + for field in batch.keys(): + batch[field] = batch[field].to(device) + _ = model(**batch) + # note: no loss.backward or optimizer.step() is performed here + progress_bar.update(1) + + +def final_eval_for_mnli(args, model, processed_datasets, metric, data_collator): + """ + If the task is MNLI, perform a final evaluation on mismatched validation set + """ + eval_dataset = processed_datasets["validation_mismatched"] + eval_dataloader = DataLoader( + eval_dataset, collate_fn=data_collator, batch_size=args.batch_size + ) + + model.eval() + for step, batch in enumerate(eval_dataloader): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + metric.add_batch( + predictions=predictions, + references=batch["labels"], + ) + + eval_metric = metric.compute() + logger.info(f"mnli-mm: {eval_metric}") + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args = parse_args() + + ######################################################################### + # Prepare model, tokenizer, dataset, optimizer, and the scheduler + logger.setLevel(logging.INFO) + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + + # Load dataset and tokenizer, and then preprocess the dataset + raw_dataset, is_regression, num_labels = get_raw_dataset(args.task_name) + tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True) + processed_datasets = preprocess(args, tokenizer, raw_dataset) + train_dataset = processed_datasets["train"] + eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] + + # Load pretrained model + config = AutoConfig.from_pretrained(args.model_name, num_labels=num_labels, finetuning_task=args.task_name) + model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config) + model.to(device) + + ######################################################################### + # Finetune on the target GLUE task before pruning + optimizer, train_dataloader, eval_dataloader, data_collator = get_dataloader_and_optimizer(args, tokenizer, + model, + train_dataset, + eval_dataset) + train_steps = args.num_train_epochs * len(train_dataloader) + lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps, + num_training_steps=train_steps) + metric = load_metric("glue", args.task_name) + + logger.info("================= Finetuning before pruning =================") + train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device) + + if args.output_dir is not None: + torch.save(model.state_dict(), args.output_dir + "/model_before_pruning.pt") + + if args.task_name == "mnli": + final_eval_for_mnli(args, model, processed_datasets, metric, data_collator) + + ######################################################################### + # Pruning + optimizer, train_dataloader, eval_dataloader, data_collator = get_dataloader_and_optimizer(args, tokenizer, + model, + train_dataset, + eval_dataset) + dummy_input = next(iter(train_dataloader))["input_ids"].to(device) + flops, params, results = count_flops_params(model, dummy_input) + print(f"Initial model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M") + + # Here criterion is embedded in the model. Upper levels can just pass None to trainer. + def trainer(model, optimizer, criterion, epoch): + return trainer_helper(model, train_dataloader, optimizer, device) + + def forward_runner(model): + return forward_runner_helper(model, train_dataloader, device) + + # example: prune different layers with different sparsity + attention_name_groups = list(zip(["bert.encoder.layer.{}.attention.self.query".format(i) for i in range(12)], + ["bert.encoder.layer.{}.attention.self.key".format(i) for i in range(12)], + ["bert.encoder.layer.{}.attention.self.value".format(i) for i in range(12)], + ["bert.encoder.layer.{}.attention.output.dense".format(i) for i in range(12)])) + + kwargs = {"ranking_criterion": args.ranking_criterion, + "global_sort": args.global_sort, + "num_iterations": args.num_iterations, + "epochs_per_iteration": args.epochs_per_iteration, + "attention_name_groups": attention_name_groups, + "head_hidden_dim": 64, + "trainer": trainer, + "optimizer": optimizer, + "forward_runner": forward_runner} + + config_list = [{ + "sparsity": args.sparsity, + "op_types": ["Linear"], + "op_names": [x for layer in attention_name_groups[:6] for x in layer] + }, + { + "sparsity": args.sparsity / 2, + "op_types": ["Linear"], + "op_names": [x for layer in attention_name_groups[6:] for x in layer] + } + ] + + pruner = TransformerHeadPruner(model, config_list, **kwargs) + pruner.compress() + + ######################################################################### + # uncomment the following part to export the pruned model masks + # model_path = os.path.join(args.output_dir, "pruned_{}_{}.pth".format(args.model_name, args.task_name)) + # mask_path = os.path.join(args.output_dir, "mask_{}_{}.pth".format(args.model_name, args.task_name)) + # pruner.export_model(model_path=model_path, mask_path=mask_path) + + ######################################################################### + # Speedup + # Currently, speeding up Transformers through NNI ModelSpeedup is not supported because of shape inference issues. + # However, if you are using the transformers library, you can use the following workaround: + # The following code gets the head pruning decisions from the pruner and calls the _prune_heads() function + # implemented in models from the transformers library to speed up the model. + if args.speed_up: + speedup_rules = {} + for group_idx, group in enumerate(pruner.attention_name_groups): + # get the layer index + layer_idx = None + for part in group[0].split("."): + try: + layer_idx = int(part) + break + except: + continue + if layer_idx is not None: + speedup_rules[layer_idx] = pruner.pruned_heads[group_idx] + pruner._unwrap_model() + model.bert._prune_heads(speedup_rules) + print(model) + + ######################################################################### + # After pruning, finetune again on the target task + # Get the metric function + metric = load_metric("glue", args.task_name) + + # re-initialize the optimizer and the scheduler + optimizer, _, _, data_collator = get_dataloader_and_optimizer(args, tokenizer, model, train_dataset, + eval_dataset) + lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps, + num_training_steps=train_steps) + + logger.info("================= Finetuning after Pruning =================") + train_model(args, model, is_regression, train_dataloader, eval_dataloader, optimizer, lr_scheduler, metric, device) + + if args.output_dir is not None: + torch.save(model.state_dict(), args.output_dir + "/model_after_pruning.pt") + + if args.task_name == "mnli": + final_eval_for_mnli(args, model, processed_datasets, metric, data_collator) + + flops, params, results = count_flops_params(model, dummy_input) + print(f"Final model FLOPs {flops / 1e6:.2f} M, #Params: {params / 1e6:.2f}M") + + +if __name__ == "__main__": + main() diff --git a/nni/algorithms/compression/pytorch/pruning/__init__.py b/nni/algorithms/compression/pytorch/pruning/__init__.py index f49cf0cb65..2d92454859 100644 --- a/nni/algorithms/compression/pytorch/pruning/__init__.py +++ b/nni/algorithms/compression/pytorch/pruning/__init__.py @@ -3,6 +3,7 @@ from .finegrained_pruning_masker import * from .structured_pruning_masker import * +from .transformer_pruning_head_masker import * from .one_shot_pruner import * from .iterative_pruner import * from .lottery_ticket import LotteryTicketPruner @@ -11,3 +12,4 @@ from .auto_compress_pruner import AutoCompressPruner from .sensitivity_pruner import SensitivityPruner from .amc import AMCPruner +from .transformer_pruner import TransformerHeadPruner diff --git a/nni/algorithms/compression/pytorch/pruning/transformer_pruner.py b/nni/algorithms/compression/pytorch/pruning/transformer_pruner.py new file mode 100644 index 0000000000..5f5d9b5dfd --- /dev/null +++ b/nni/algorithms/compression/pytorch/pruning/transformer_pruner.py @@ -0,0 +1,337 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from schema import And, Optional + +from nni.common.graph_utils import TorchModuleGraph +from nni.compression.pytorch.utils.shape_dependency import AttentionWeightDependency +from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.compressor import Pruner +from . import L1WeightHeadMasker, L2WeightHeadMasker, L1ActivationHeadMasker, L2ActivationHeadMasker, TaylorFOHeadMasker + +__all__ = ['TransformerHeadPruner'] + +MASKER_DICT = { + 'l1_weight': L1WeightHeadMasker, + 'l2_weight': L2WeightHeadMasker, + 'l1_activation': L1ActivationHeadMasker, + 'l2_activation': L2ActivationHeadMasker, + 'taylorfo': TaylorFOHeadMasker +} + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class TransformerHeadPruner(Pruner): + """ + A pruner specialized for pruning attention heads in models belong to the transformer family. + + Parameters + ---------- + model : torch.nn.Module + Model to be pruned. Expect a model from transformers library (e.g., BertModel). + This pruner can work with other customized transformer models, but some ranking modes might fail. + config_list : list + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : Optional. Operation types to prune. (Should be 'Linear' for this pruner.) + - op_names : Optional. Operation names to prune. + head_hidden_dim : int + Dimension of the hidden dimension of each attention head. (e.g., 64 for BERT) + We assume that this head_hidden_dim is constant across the entire model. + attention_name_groups : list (Optional) + List of groups of names for weights of each attention layer. Each element should be a four-element list, with + the first three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj. + dummy_input : torch.Tensor (Optional) + Input to model's forward method, used to infer module grouping if attention_name_groups is not specified. + This tensor is used by the underlying torch.jit.trace to infer the module graph. + ranking_criterion : str + The criterion for ranking attention heads. Currently we support: + - l1_weight: l1 norm of Q_proj, K_proj, and V_proj + - l2_weight: l2 norm of Q_proj, K_proj, and V_proj + - l1_activation: l1 norm of the output of attention computation + - l2_activation: l2 norm of the output of attention computation + - taylorfo: l1 norm of the output of attention computation * gradient for this output + (check more details in the masker documentation) + global_sort : bool + Whether rank the heads globally or locally before deciding heads to prune. + num_iterations : int + Number of pruning iterations. Defaults to 1 (ont-shot pruning). If num_iterations > 1, the pruner will split + the sparsity specified in config_list uniformly and assign a fraction to each pruning iteration. + epochs_per_iteration : int + Number of finetuning epochs before the next pruning iteration. + Only used when num_iterations > 1. + If num_iterations is 1, then no finetuning is performed by the pruner after pruning. + optimizer: torch.optim.Optimizer + Optimizer used to train model + trainer: function + Function used to finetune the model between pruning iterations. + Only used when num_iterations > 1 or ranking_criterion is 'taylorfo'. + Users should write this function as a normal function to train the PyTorch model and include + `model, optimizer, criterion, epoch` as function arguments. Note that the trainer is also used for collecting + gradients for pruning if ranking_criterion is 'taylorfo'. In that case, ``epoch=None`` will be passed. + criterion: function + Function used to calculate the loss between the target and the output. + Only used when num_iterations > 1 or ranking_criterion is 'taylorfo'. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. + forward_runner: function + Function used to perform a "dry run" on the model on the entire train/validation dataset in order to collect + data for pruning required by the criteria 'l1_activation' or 'l2_activation'. + Only used when ranking_criterion is 'l1_activation' or 'l2_activation'. + Users should write this function as a normal function that accepts a PyTorch model and runs forward on the model + using the entire train/validation dataset. This function is not expected to perform any backpropagation or + parameter updates. + """ + def __init__(self, model, config_list,head_hidden_dim, attention_name_groups=None, dummy_input=None, + ranking_criterion='l1_weight', global_sort=False, num_iterations=1, epochs_per_iteration=1, + optimizer=None, trainer=None, criterion=None, forward_runner=None, + **algo_kwargs): + super().__init__(model, config_list) + + self.head_hidden_dim = int(head_hidden_dim) + self.attention_name_groups = attention_name_groups + self.dummy_input = dummy_input + self.ranking_criterion = ranking_criterion + assert self.ranking_criterion in ['l1_weight', 'l2_weight', 'l1_activation', 'l2_activation', 'taylorfo'], \ + "Unsupported ranking criteria." + self.global_sort = global_sort + self.num_iterations = int(num_iterations) + assert self.num_iterations >= 1, "num_iterations must be greater than or equal to 1" + self.epochs_per_iteration = int(epochs_per_iteration) + self._optimizer = optimizer + self._trainer = trainer + self._criterion = criterion + self._forward_runner = forward_runner + if self.ranking_criterion in ['taylorfo'] or num_iterations > 1: + assert self._trainer is not None + assert self._optimizer is not None + if self.ranking_criterion in ['l1_activation', 'l2_activation']: + assert self._forward_runner is not None + + # Group generation: one group per attention layer, four weights per group + self.masking_groups = [] + if self.attention_name_groups is not None: + logger.info("Note: weights for the same attention layer are grouped using the given attention_name_groups.") + self.group_weights_by_name() + else: + assert self.dummy_input is not None + logger.info("Note: weights for the same attention layer are grouped using model graph.") + self._unwrap_model() + self.group_weight_names_by_graph() + self._wrap_model() + + # Group sanity check + self.validate_weight_groups() + + # Remove any mistakenly captured ungrouped modules + self._unwrap_model() + self.remove_ungrouped_modules() + self._wrap_model() + + self.masker = MASKER_DICT[ranking_criterion](model, self, self.head_hidden_dim, **algo_kwargs) + self.pruned_heads = {i: set() for i in range(len(self.masking_groups))} + + def group_weights_by_name(self): + """ + Populate self.masking_groups using the groups specified by user in attention_name_groups. + """ + assert len(self.masking_groups) == 0 + # build up masking groups + name2group = {} + for layer_idx, layer in enumerate(self.attention_name_groups): + errmsg = 'Each name group must contain 4 weights, with the first three corresponding to Q_proj, K_proj, ' \ + 'V_proj (in any order) and the last one being output_proj.' + assert len(layer) == 4, errmsg + self.masking_groups.append([]) + for weight in layer: + name2group[weight] = layer_idx + + # group wrappers + for wrapper in self.get_modules_wrapper(): + if wrapper.name in name2group: + wrapper.group_idx = name2group[wrapper.name] + self.masking_groups[name2group[wrapper.name]].append(wrapper) + + logger.info('Grouping updated:') + logger.info([[x.name for x in group] for group in self.masking_groups]) + + def group_weight_names_by_graph(self): + """ + Populate self.attention_name_groups by running inference on the module graph. + Currently, the group inferred AttentionWeightDependency is limited to a set of four weights, with the first + three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj. + """ + try: + module_graph = TorchModuleGraph(self.bound_model, self.dummy_input) + dependency_tracer = AttentionWeightDependency(traced_model=module_graph.trace) + self.attention_name_groups = dependency_tracer.dependency_sets + self.group_weights_by_name() + + except Exception as e: + raise RuntimeError('Graph trace failed: please check dummy_input, or specify attention_name_groups.\n' + 'Exception message: ' + str(e)) + + def validate_weight_groups(self): + """ + Sanity checks: + - Q, K, V projection weights in each groups must have the same shape + - output projection weight shape must match total hidden dimension (inferred from Q, K, V projection) + - Four weights in a group must have the same sparsity in their config + - If global_sort is specified, all weights must have the same sparsity + - head_hidden_dim must be a divisor of the output dimension of the projection weights (i.e., the resulting + head number must be an integer) + """ + errmsg = 'Attention weight group sanity check not passed' + sparsity = None + for group in self.masking_groups: + # allow empty groups - may be caused by config list filtering + if len(group) == 0: + continue + assert len(group) == 4, errmsg + ': each group must have four weights' + assert group[0].module.weight.size() == group[1].module.weight.size() and \ + group[1].module.weight.size() == group[2].module.weight.size(), \ + errmsg + ': the dimensions of Q, K, V projection matrices must be the same ' + assert group[0].module.weight.size()[0] == group[3].module.weight.size()[1], \ + errmsg + ': the dimension of attention results must match with input for output projection' + assert group[0].config['sparsity'] == group[1].config['sparsity'] == \ + group[2].config['sparsity'] == group[3].config['sparsity'], \ + errmsg + ': the sparsity of matrices in the same layer must be the same' + if sparsity is None: + sparsity = group[0].config['sparsity'] + if self.global_sort: + assert sparsity == group[0].config['sparsity'], \ + errmsg + ': for global_sort=True, the sparsity for all modules must be the same' + assert group[0].module.weight.size(0) % self.head_hidden_dim == 0, \ + errmsg + ': head_hidden_dim must be a divisor of the output dimension of the projection weights' + + def remove_ungrouped_modules(self): + """ + Remove non-attention weights that might be mistakenly captured by a simplified config_list. + Also update the corresponding list of layer information (self.modules_to_compress) + """ + care_of_modules = set([x for layer in self.masking_groups for x in layer]) + + modules_wrapper_new, modules_to_compress_new = [], [] + for wrapper, layer_info in zip(self.modules_wrapper, self.modules_to_compress): + if wrapper in care_of_modules: + modules_wrapper_new.append(wrapper) + modules_to_compress_new.append(layer_info) + + self.modules_wrapper = modules_wrapper_new + self.modules_to_compress = modules_to_compress_new + + def validate_config(self, model, config_list): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + List on pruning configs + """ + schema = CompressorSchema([{ + 'sparsity': And(float, lambda n: 0 < n < 1), + Optional('op_types'): [str], + Optional('op_names'): [str] + }], model, logger) + + schema.validate(config_list) + + def compress(self): + for pruning_iter in range(self.num_iterations): + if self.ranking_criterion in ['l1_activation', 'l2_activation']: + training = self.bound_model.training + self.bound_model.eval() + self._forward_runner(self.bound_model) # dry run, forward only + self.update_mask() + self.bound_model.train(training) + elif self.ranking_criterion in ['taylorfo']: + self._trainer(self.bound_model, optimizer=self._optimizer, criterion=self._criterion, epoch=None) + self.update_mask() + else: + self.update_mask() + + # for iterative pruning, if not the last iteration, finetune before next iteration + # Then, reset the maskers (may create additional hooks) + if self.num_iterations > 1 and pruning_iter != self.num_iterations - 1: + for e in range(self.epochs_per_iteration): + self._trainer(self.bound_model, optimizer=self._optimizer, criterion=self._criterion, epoch=e+1) + self.masker.reset() + + logger.info('Pruned heads after iteration %i', pruning_iter) + logger.info(self.pruned_heads) + + def update_mask(self): + """ + Calculate and update masks for each masking group. If global_sort is set, the masks for all groups are + calculated altogether, and then the groups are updated individually. + """ + masks_for_all_groups = None + if self.global_sort: + masks_for_all_groups = self._calc_mask_global() + assert len(masks_for_all_groups) == len(self.masking_groups) + for group_idx, layer_weight_group in enumerate(self.masking_groups): + if self.global_sort: + masks = masks_for_all_groups[group_idx] + else: + masks = self._calc_mask(layer_weight_group) + if masks is not None: + for i, mask in enumerate(masks): + for mask_type in mask: + assert hasattr(layer_weight_group[i], mask_type), \ + "there is no attribute '%s' in wrapper on %s" % (mask_type, layer_weight_group[i]) + setattr(layer_weight_group[i], mask_type, mask[mask_type]) + logger.debug(f'mask updated: {layer_weight_group[i].name} {mask_type}') + + def _calc_mask(self, weight_group): + """ + Calculate mask for each group using only layer-local information. + When global_sort is set for the pruner, _calc_mask_global should be called instead of this function. + + Parameters + ---------- + weight_group : list + A list of four wrappers generated by self.group_weights_by_name(). + + Returns + ------- + masks : list + A four element list corresponding to the masks for each element in the four-element weight group. + Each element in masks is a dict with keys "weight_mask" and "bias_mask" (optional). + masks can be None if the underlying masker returns None. This means that the mask calculation fails. + The calling function can try recalculate the mask at a later time. Note that the calling function might need + to call masker.reset() before attempting to recalculate the mask. + """ + iter_sparsity = weight_group[0].config['sparsity'] / self.num_iterations + masks = self.masker.calc_mask(sparsity=iter_sparsity, weight_group=weight_group) + + return masks + + def _calc_mask_global(self): + """ + Calculate mask for all groups using global information. + + Returns + ------- + masks_list : list + A list corresponding to the masks for each weight group in self.masking_groups. Each element in the + returned mask_list is a four-element list corresponding to the masks for each element in a four-element + weight group. + """ + if len(self.get_modules_wrapper()) == 0: + return [] + + overall_sparsity = self.get_modules_wrapper()[0].config['sparsity'] / self.num_iterations + n_heads_total = 0 + for group in self.masking_groups: + if len(group) != 0: + q_proj, _, _, _ = group + n_heads_total += int(q_proj.module.weight.size()[0] / self.head_hidden_dim) + n_heads_to_prune = int(n_heads_total * overall_sparsity) + + return self.masker.calc_mask_global(n_heads_to_prune) + + def calc_mask(self, wrapper, **kwargs): + raise RuntimeError("Applications should directly call TransformerHeadPruner's update_mask() method.") diff --git a/nni/algorithms/compression/pytorch/pruning/transformer_pruning_head_masker.py b/nni/algorithms/compression/pytorch/pruning/transformer_pruning_head_masker.py new file mode 100644 index 0000000000..fc7f6f5808 --- /dev/null +++ b/nni/algorithms/compression/pytorch/pruning/transformer_pruning_head_masker.py @@ -0,0 +1,444 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import torch +from .weight_masker import WeightMasker + +__all__ = ['L1WeightHeadMasker', 'L2WeightHeadMasker', 'L1ActivationHeadMasker', 'L2ActivationHeadMasker', + 'TaylorFOHeadMasker'] + +logger = logging.getLogger('transformer head pruner') + + +class AttentionHeadMasker(WeightMasker): + """ + A structured pruning masker base class that prunes attention heads in attention layers. + + Parameters + ---------- + model: nn.Module + model to be pruned + pruner: Pruner + A Pruner instance used to prune the model + head_hidden_dim: int + Hidden dimension for each attention head (e.g., 64 for BERT base) + """ + def __init__(self, model, pruner, head_hidden_dim=None): + super().__init__(model, pruner) + self.head_hidden_dim = head_hidden_dim + assert self.head_hidden_dim is not None, "head_hidden_dim must be specified." + + def reset(self): + """ + Derived classes can override this method to do preparations necessary for calculating importance scores. + This method is called during iterative pruning, before each iteration starts (except the first one). + """ + pass + + def calc_mask(self, sparsity, wrapper=None, wrapper_idx=None, weight_group=None, **kwargs): + """ + Calculate all the masks for a group of wrappers (specified in weight_group). + This function only utilizes local information for mask calculation. If global_sort is specified for the pruner, + the pruner should call calc_mask_global instead of this function. + + Parameters + ---------- + sparsity: float + The target (amount of increase of) sparsity of the wrapper list. + weight_group: list + A four-element list of module wrappers + wrapper: PrunerModuleWrapper/list of PrunerModuleWrappers + Should be None. Not used in this masker, just for consistency with the parent API. + wrapper_idx: int/list of int + Should be None. Not used in this masker, just for consistency with the parent API. + Returns + ------- + masks : list + masks for each element in the group. + Each element in the list masks is a dictionary for storing masks, keys of the dict: + 'weight_mask': weight mask tensor + 'bias_mask': bias mask tensor (optional) + """ + assert weight_group is not None + if len(weight_group) == 0: + return None + else: + num_total = weight_group[0].module.weight.data.size(0) // self.head_hidden_dim + if num_total < 2: + return None + num_prune = max(int(num_total * sparsity), 1) + return self.get_mask(num_prune, weight_group, **kwargs) + + def calc_mask_global(self, n_heads_to_prune): + """ + Calculate all the masks for all groups in the pruner. + + Parameters + ---------- + n_heads_to_prune : int + Total number of attention heads to prune. + Returns + ------- + all_masks : list + A list of masks for all groups, where each element is a list of masks for each module in the group. + """ + # calculate scores as normal (this step does not require global information) + head_importance_scores = [] + for group_idx, group in enumerate(self.pruner.masking_groups): + if len(group) != 0: + scores = self.get_head_importance_scores(group) + n_heads = group[0].module.weight.size(0) // self.head_hidden_dim + for head_idx in range(n_heads): + head_importance_scores.append([group_idx, head_idx, scores[head_idx]]) + + # determine which head to prune for each layer + n_selected = 0 + for group_idx, head_idx, _ in sorted(head_importance_scores, key=(lambda x: x[-1])): + n_heads_original = self.pruner.masking_groups[group_idx][0].module.weight.size(0) // self.head_hidden_dim + n_heads_remaining = n_heads_original - len(self.pruner.pruned_heads[group_idx]) + if n_heads_remaining > 1 and head_idx not in self.pruner.pruned_heads[group_idx]: + self.pruner.pruned_heads[group_idx].add(head_idx) + n_selected += 1 + if n_selected >= n_heads_to_prune: + break + + # generate masks + all_masks = [] + for group_idx, group in enumerate(self.pruner.masking_groups): + if len(group) == 0: + masks = None + else: + n_heads = group[0].module.weight.size(0) // self.head_hidden_dim + device = group[0].module.weight.device + head_level_mask = torch.tensor([i not in self.pruner.pruned_heads[group_idx] for i in range(n_heads)], device=device) # pylint: disable=not-callable + masks = self._get_layer_masks_from_head_mask(group, head_level_mask) + all_masks.append(masks) + + return all_masks + + def get_mask(self, num_prune, weight_group, **kwargs): + """ + Calculate the mask of given layer (weight_group). + + Parameters + ---------- + num_prune: int + Num of heads to prune + weight_group: list + A four-element list of module wrappers + Returns + ------- + masks : list + masks for each element in the group. + Each element in the list masks is a dictionary for storing masks, keys of the dict: + 'weight_mask': weight mask tensor + 'bias_mask': bias mask tensor (optional) + """ + raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__)) + + def _get_layer_masks_from_head_mask(self, weight_group, head_mask_bool, device=None): + q_proj, _, _, output_proj = weight_group + if device is None: + device = q_proj.module.weight.device + + n_heads = q_proj.module.weight.size()[0] // self.head_hidden_dim + weight_mask_shape = q_proj.module.weight.data.view([n_heads, -1]).size() + bias_mask_shape = q_proj.module.bias.data.view([n_heads, -1]).size() + + mask_weight = head_mask_bool.unsqueeze(-1).expand(weight_mask_shape).type_as(q_proj.module.weight) + mask_bias = head_mask_bool.unsqueeze(-1).expand(bias_mask_shape).type_as(q_proj.module.weight) + + mask_weight_proj = mask_weight.contiguous().view(q_proj.module.weight.size()).detach().to(device) + mask_bias_proj = mask_bias.contiguous().view(-1).detach().to(device) + masks_for_proj = {'weight_mask': mask_weight_proj.detach()} + if hasattr(q_proj.module, 'bias') and q_proj.module.bias is not None: + masks_for_proj['bias_mask'] = mask_bias_proj + + mask_weight_dense = mask_bias_proj.expand_as(output_proj.module.weight.data).detach().to(device) + mask_bias_dense = torch.ones_like(output_proj.module.bias.data).to(device) + masks_for_dense = {'weight_mask': mask_weight_dense.detach()} + if hasattr(output_proj.module, 'bias') and output_proj.module.bias is not None: + masks_for_dense['bias_mask'] = mask_bias_dense + + masks = [masks_for_proj, masks_for_proj, masks_for_proj, masks_for_dense] + + return masks + + def get_mask_by_importance_ranking(self, num_prune, weight_group): + """ + Calculate the mask of given layer by pruning out heads with lowest importance scores. + + Parameters + ---------- + num_prune: int + Num of heads to prune + weight_group: list + list of a group of weights for an attention layer + Returns + ------- + masks : list + masks for each element in the group. + Each element in the list masks is a dictionary for storing masks, keys of the dict: + 'weight_mask': weight mask tensor + 'bias_mask': bias mask tensor (optional) + """ + importance_scores = self.get_head_importance_scores(weight_group) + if importance_scores is None: + return None + + importance_scores = [[i, importance_scores[i]] for i in range(len(importance_scores))] + head_mask_bool = torch.ones(len(importance_scores)) + n_selected = 0 + for head_idx, _ in sorted(importance_scores, key=(lambda x: x[-1])): + head_mask_bool[head_idx] = 0 + if head_idx not in self.pruner.pruned_heads[weight_group[0].group_idx]: + n_selected += 1 + # update pruned_heads in pruner (mainly for iterative pruning) + self.pruner.pruned_heads[weight_group[0].group_idx].add(head_idx) + if n_selected == num_prune: + break + + return self._get_layer_masks_from_head_mask(weight_group, head_mask_bool) + + def get_head_importance_scores(self, weight_group): + """ + Calculate the importance score for each head. + Parameters + ---------- + weight_group: list + list of a group of weights for an attention layer + + Returns + ------- + importance_scores: tensor + Tensor that indicates the importance of each head + """ + raise NotImplementedError('{} get_channel_sum is not implemented'.format(self.__class__.__name__)) + + +class L1WeightHeadMasker(AttentionHeadMasker): + """ + A structured pruning algorithm that prunes the heads weight smallest weight magnitude for the query, head, + and key projection matrices. L1 norm is used for magnitude calculation. Note that in this implementation, weight + norms of q_proj, k_proj, v_proj from each head are summed as the final importance score for the head. + """ + def get_head_importance_scores(self, weight_group): + q_proj, k_proj, v_proj, _ = weight_group + + n_heads = q_proj.module.weight.size()[0] // self.head_hidden_dim + query_proj_weights = q_proj.module.weight.data.view([n_heads, -1]) + key_proj_weights = k_proj.module.weight.data.view([n_heads, -1]) + value_proj_weights = v_proj.module.weight.data.view([n_heads, -1]) + + query_norm_avg = torch.norm(query_proj_weights, 1, -1) + key_norm_avg = torch.norm(key_proj_weights, 1, -1) + value_norm_avg = torch.norm(value_proj_weights, 1, -1) + + return ((query_norm_avg + key_norm_avg + value_norm_avg) / 3).detach() + + def get_mask(self, num_prune, weight_group, **kwargs): + return self.get_mask_by_importance_ranking(num_prune, weight_group) + + +class L2WeightHeadMasker(AttentionHeadMasker): + """ + A structured pruning algorithm that prunes the heads weight smallest weight magnitude for the query, head, + and key projection matrices. L2 norm is used for magnitude calculation. Note that in this implementation, weight + norms of q_proj, k_proj, v_proj from each head are summed as the final importance score for the head. + """ + def get_head_importance_scores(self, weight_group): + q_proj, k_proj, v_proj, _ = weight_group + + n_heads = q_proj.module.weight.size()[0] // self.head_hidden_dim + query_proj_weights = q_proj.module.weight.data.view([n_heads, -1]) + key_proj_weights = k_proj.module.weight.data.view([n_heads, -1]) + value_proj_weights = v_proj.module.weight.data.view([n_heads, -1]) + + query_norm_avg = torch.norm(query_proj_weights, 2, -1) + key_norm_avg = torch.norm(key_proj_weights, 2, -1) + value_norm_avg = torch.norm(value_proj_weights, 2, -1) + + return ((query_norm_avg + key_norm_avg + value_norm_avg) / 3).detach() + + def get_mask(self, num_prune, weight_group, **kwargs): + return self.get_mask_by_importance_ranking(num_prune, weight_group) + + +class L1ActivationHeadMasker(AttentionHeadMasker): + """ + A structured pruning algorithm that prunes the heads with smallest final output value. + Note that this masker only relies on the output of the output layer of each attention layer. + The masker collects the L1 norm of the output of the last weight (output projection) in each group on the entire + train set, and prunes the heads producing the smallest output. + """ + def __init__(self, model, pruner, head_hidden_dim=None): + super().__init__(model, pruner, head_hidden_dim) + self.reset() + + def reset(self): + self.pruner.hook_id = self._add_activation_collector(self.pruner) + + def get_head_importance_scores(self, weight_group): + _, _, _, output_proj = weight_group + activations = torch.stack(self.pruner.collected_activation[output_proj.group_idx], -1) + activations = torch.sum(activations, -1) + n_heads = activations.size()[0] // self.head_hidden_dim + scores = torch.sum(activations.view([n_heads, -1]), -1).detach().cpu() + + # clean up hooks + if self.pruner.hook_id in self.pruner._fwd_hook_handles: + self.pruner.remove_activation_collector(self.pruner.hook_id) + + return scores + + def _add_activation_collector(self, pruner): + def collector(collected_activation): + def hook(module_, input_, output): + if type(input_) is tuple: + input_ = input_[0] + raw_activation = torch.abs(input_.detach().cpu()) # L1-norm + raw_activation_reduced = torch.sum(raw_activation, [0, 1]) + collected_activation.append(raw_activation_reduced) + return hook + pruner.collected_activation = {} + pruner._fwd_hook_id += 1 + pruner._fwd_hook_handles[pruner._fwd_hook_id] = [] + + for _, _, _, output_proj in pruner.masking_groups: + pruner.collected_activation[output_proj.group_idx] = [] + handle = output_proj.register_forward_hook(collector(pruner.collected_activation[output_proj.group_idx])) + + pruner._fwd_hook_handles[pruner._fwd_hook_id].append(handle) + + return pruner._fwd_hook_id + + def get_mask(self, num_prune, weight_group, **kwargs): + return self.get_mask_by_importance_ranking(num_prune, weight_group) + + +class L2ActivationHeadMasker(AttentionHeadMasker): + """ + A structured pruning algorithm that prunes the heads with smallest final output value. + Note that this masker only relies on the output of the output layer of each attention layer. + The masker collects the L2 norm of the output of the last weight (output projection) in each group on the entire + train set, and prunes the heads producing the smallest output. + """ + def __init__(self, model, pruner, head_hidden_dim=None): + super().__init__(model, pruner, head_hidden_dim) + self.reset() + + def reset(self): + self.pruner.hook_id = self._add_activation_collector(self.pruner) + + def get_head_importance_scores(self, weight_group): + _, _, _, output_proj = weight_group + activations = torch.stack(self.pruner.collected_activation[output_proj.group_idx], -1) + scores = torch.sum(activations, -1).detach().cpu() + # n_heads = activations.size()[0] // self.head_hidden_dim + # scores = torch.sum(activations.view([n_heads, -1]), -1).detach().cpu() + + # clean up hooks + if self.pruner.hook_id in self.pruner._fwd_hook_handles: + self.pruner.remove_activation_collector(self.pruner.hook_id) + + return scores + + def _add_activation_collector(self, pruner): + def collector(collected_activation, head_hidden_dim): + def hook(module_, input_, output): + if type(input_) is tuple: + input_ = input_[0] + raw_activation = input_.detach().cpu() ** 2 + n_heads = raw_activation.size(-1) // head_hidden_dim + raw_activation = raw_activation.view(raw_activation.size(0), raw_activation.size(1), n_heads, -1) + raw_activation = torch.norm(raw_activation, 2, -1) # (B, S, n_heads) + raw_activation_reduced = torch.sum(raw_activation, [0, 1]) # (n_heads,) + collected_activation.append(raw_activation_reduced) + + return hook + + pruner.collected_activation = {} + pruner._fwd_hook_id += 1 + pruner._fwd_hook_handles[pruner._fwd_hook_id] = [] + + for _, _, _, output_proj in pruner.masking_groups: + pruner.collected_activation[output_proj.group_idx] = [] + handle = output_proj.register_forward_hook(collector(pruner.collected_activation[output_proj.group_idx], + head_hidden_dim=self.head_hidden_dim)) + + pruner._fwd_hook_handles[pruner._fwd_hook_id].append(handle) + + return pruner._fwd_hook_id + + def get_mask(self, num_prune, weight_group, **kwargs): + return self.get_mask_by_importance_ranking(num_prune, weight_group) + + +class TaylorFOHeadMasker(AttentionHeadMasker): + """ + A structured pruning algorithm that prunes the heads with smallest final output contribution. + Note that this masker only relies on the output of the output layer of each attention layer. + The masker collects the output the last weight (output projection) in each group and the corresponding gradient + on the entire train set, and prunes the heads producing the smallest contribution as used in the following papers: + "Are Sixteen Heads Really Better than One?" (Michel et.al, 2019) + "Pruning convolutional neural networks for resource efficient inference." (Molchanov et. al., 2017) + """ + def __init__(self, model, pruner, head_hidden_dim=None): + super().__init__(model, pruner, head_hidden_dim) + self.reset() + + def reset(self): + self.pruner.hook_id = self._add_activation_collector() # forward hooks for collecting activation + self.backward_hooks = {} # backward hooks for collecting gradient + self._add_gradient_collector() + + def get_head_importance_scores(self, weight_group): + _, _, _, output_proj = weight_group + result = output_proj.head_importance_scores + + # clean up hooks and cached data + if self.pruner.hook_id in self.pruner._fwd_hook_handles: + self.pruner.remove_activation_collector(self.pruner.hook_id) + self.backward_hooks[output_proj.group_idx].remove() + for attr in ['forward_output_cached', 'head_importance_scores']: + output_proj.__dict__.pop(attr, None) + + return result + + def _add_activation_collector(self): + def forward_hook(md, inp, out): + if type(inp) is tuple: + inp = inp[0] + n_heads_per_layer = inp.size(-1) // self.head_hidden_dim + heads_output = inp.view([inp.size(0), inp.size(1), n_heads_per_layer, -1]).detach() + md.forward_output_cached = heads_output + + self.pruner._fwd_hook_id += 1 + self.pruner._fwd_hook_handles[self.pruner._fwd_hook_id] = [] + + for _, _, _, output_proj in self.pruner.masking_groups: + handle = output_proj.register_forward_hook(forward_hook) + self.pruner._fwd_hook_handles[self.pruner._fwd_hook_id].append(handle) + + return self.pruner._fwd_hook_id + + def _add_gradient_collector(self): + def grad_hook(md, grad_in, grad_out): + if type(grad_in) is tuple: + grad_in = grad_in[0] + n_heads_per_layer = grad_in.size(-1) // self.head_hidden_dim + heads_grad = grad_in.view([grad_in.size(0), grad_in.size(1), n_heads_per_layer, -1]) + heads_scores = torch.abs(heads_grad * md.forward_output_cached) + heads_scores = torch.sum(heads_scores, [0, 1, 3]).detach().cpu() + if hasattr(md, 'head_importance_scores'): + md.head_importance_scores += heads_scores + else: + md.head_importance_scores = heads_scores + + for _, _, _, output_proj in self.pruner.masking_groups: + handle = output_proj.register_backward_hook(grad_hook) + self.backward_hooks[output_proj.group_idx] = handle + + def get_mask(self, num_prune, weight_group, **kwargs): + return self.get_mask_by_importance_ranking(num_prune, weight_group) diff --git a/nni/compression/pytorch/utils/shape_dependency.py b/nni/compression/pytorch/utils/shape_dependency.py index b8e6dc896f..883b731d57 100644 --- a/nni/compression/pytorch/utils/shape_dependency.py +++ b/nni/compression/pytorch/utils/shape_dependency.py @@ -6,7 +6,8 @@ import numpy as np -__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency'] +__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency', 'AttentionWeightDependency'] + CONV_TYPE = 'aten::_convolution' ADD_TYPES = ['aten::add', 'aten::add_'] @@ -88,7 +89,6 @@ def __init__(self, model=None, dummy_input=None, traced_model=None): """ This model analyze the channel dependencies between the conv layers in a model. - Parameters ---------- model : torch.nn.Module @@ -105,12 +105,10 @@ def __init__(self, model=None, dummy_input=None, traced_model=None): def _get_parent_layers(self, node): """ Find the nearest father conv layers for the target node. - Parameters --------- node : torch._C.Node target node. - Returns ------- parent_layers: list @@ -182,7 +180,6 @@ def export(self, filepath): means the output channel(filters) numbers of these three layers should be same with each other, otherwise the model may has shape conflict. - Output example: Dependency Set,Convolutional Layers Set 1,layer1.1.conv2,layer1.0.conv2,conv1 @@ -219,7 +216,6 @@ def dependency_sets(self): dependency_sets : list list of the dependency sets. For example, [set(['conv1', 'conv2']), set(['conv3', 'conv4'])] - """ d_sets = [] visited = set() @@ -256,7 +252,6 @@ def __init__(self, model, dummy_input=None, traced_model=None): """ This model analyze the input channel dependencies between the conv layers in a model. - Parameters ---------- model : torch.nn.Module @@ -319,7 +314,6 @@ def __init__(self, model=None, dummy_input=None, traced_model=None): """ This model analyze the group dependencis between the conv layers in a model. - Parameters ---------- model : torch.nn.Module @@ -336,12 +330,10 @@ def __init__(self, model=None, dummy_input=None, traced_model=None): def _get_parent_convs(self, node): """ Find the nearest father conv layers for the target node. - Parameters --------- node : torch._C.Node target node. - Returns ------- parent_layers : list @@ -369,12 +361,10 @@ def _get_parent_convs(self, node): def _get_conv_groups(self, node_group): """ Get the number of groups for a convolutional layer. - Parameters ---------- node_group : NodePyGroup target node. - Returns ------- group : int @@ -401,7 +391,7 @@ def build_dependency(self): conv2 takes the output features of conv1 as input. Then we have to the filters of conv1 can still be divided into 4 groups after filter pruning, because - the input channels of conv2 shoule be divided into + the input channels of conv2 should be divided into 4 groups. Returns @@ -448,7 +438,6 @@ def export(self, filepath): line is the group count of the filters in this layer. Note that, the group count may be larger than this layers original group number. - output example: Conv layer, Groups Conv1, 1 @@ -468,7 +457,6 @@ def dependency_sets(self): return self.dependency - class ReshapeDependency(Dependency): def __init__(self, model=None, dummy_input=None, traced_model=None): """ @@ -573,3 +561,142 @@ def dependency_sets(self): d_sets.extend(self.dependency[reshape_node]) d_sets = list(set(d_sets)) return d_sets + + +class AttentionWeightDependency(Dependency): + def __init__(self, model=None, dummy_input=None, traced_model=None): + """ + Groups the linear layers belonging to the same attention layer in a model. + Currently, we only capture weights in attention layers with forward computations written + as four Linear layers (projections for Q, K, V, and output) and two matmul operations. + The method implemented here can work for Huggingface transformers but may not correctly + capture transformers written in other fashions (e.g., torch.nn.Transformer). + + Parameters + ---------- + model : torch.nn.Module + The model to be analyzed. + dummy_input : torch.Tensor + The example input data to trace the network architecture. + traced_model : torch._C.Graph + if we already have the traced graph of the target model, we do not + need to trace the model again. + """ + super(AttentionWeightDependency, self).__init__( + model, dummy_input, traced_model) + + def _get_parent_layers(self, node): + """ + Find the nearest parent linear layers for the target node. + + Parameters + --------- + node : torch._C.Node + target node. + + Returns + ------- + parent_layers: list + nearest parent linear layers for the target worknode. + """ + parent_layers = [] + queue = [] + queue.append(node) + while queue: + curnode = queue.pop(0) + if curnode.op_type == 'Linear': + if curnode.name not in parent_layers: + parent_layers.append(curnode.name) + continue + if curnode.op_type == 'LayerNorm': + continue + parents = self.graph.find_predecessors(curnode.unique_name) + parents = [self.graph.name_to_node[name] for name in parents] + for parent in parents: + queue.append(parent) + return parent_layers + + def _get_children_layers(self, node): + """ + Find the nearest children linear layers for the target node. + + Parameters + --------- + node : torch._C.Node + target node. + + Returns + ------- + children_layers: list + nearest children linear layers for the target worknode. + """ + children_layers = [] + queue = [] + queue.append(node) + while queue: + curnode = queue.pop(0) + if curnode.op_type == 'Linear': + if curnode.name not in children_layers: + children_layers.append(curnode.name) + continue + if curnode.op_type == 'LayerNorm': + continue + children = self.graph.find_successors(curnode.unique_name) + children = [self.graph.name_to_node[name] for name in children] + for child in children: + queue.append(child) + return children_layers + + def build_dependency(self): + """ + For every matmul operation, find the immediate parent and children Linear operations. + If we get three parents and one children, add these four weights as a dependecy group. + """ + self.graph.unpack_manually() + for node in self.graph.nodes_py.nodes_op: + layers = [] + if node.op_type == 'aten::matmul': + parent_layers = self._get_parent_layers(node) + children_layers = self._get_children_layers(node) + if len(parent_layers) == 3 and len(children_layers) == 1: + layers.extend(parent_layers) + layers.extend(children_layers) + + self.dependency[node.name] = layers + + @property + def dependency_sets(self): + """ + Get the list of the dependency set. + + Returns + ------- + dependency_sets : list + list of the dependency sets. + Each dependency set is a 4-element list of module names, with the first three elements being the projection + matrices for Q, K, V (in any order), and the last element being the dense matrix. + """ + d_sets = [] + for node in self.graph.nodes_py.nodes_op: + if node.op_type != 'aten::matmul' or node.name not in self.dependency or len(self.dependency[node.name]) != 4: + continue + d_sets.append(self.dependency[node.name]) + + return d_sets + + def export(self, filepath): + """ + Export the group dependency to a csv file. Each line describes an attention layer. + + Output example: + Attention layer matmul op, Group + """ + header = ['Attention layer matmul op', 'Group'] + with open(filepath, 'w') as csvf: + csv_w = csv.writer(csvf, delimiter=',') + csv_w.writerow(header) + for name in self.dependency: + group = self.dependency[name] + if len(group) > 0: + csv_w.writerow([name, group]) + diff --git a/test/ut/sdk/models/pytorch_models/transformer.py b/test/ut/sdk/models/pytorch_models/transformer.py new file mode 100644 index 0000000000..608d4ed93b --- /dev/null +++ b/test/ut/sdk/models/pytorch_models/transformer.py @@ -0,0 +1,190 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import copy + + +class PosEncoding(nn.Module): + def __init__(self, hidden_dim, max_seq_len=80): + super().__init__() + self.hidden_dim = hidden_dim + + pe = torch.zeros(max_seq_len, hidden_dim) + for pos in range(max_seq_len): + for i in range(0, hidden_dim, 2): + pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / hidden_dim))) + pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / hidden_dim))) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x * math.sqrt(self.hidden_dim) + x = x + torch.autograd.Variable(self.pe[:, :x.size(1)], requires_grad=False) + return x + + +def attention(query, key, value, mask=None, dropout=None): + d_k = query.size(-1) + logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + logits = logits.masked_fill(mask == 0, -1e9) + attention_map = F.softmax(logits, dim=-1) + if dropout is not None: + attention_map = dropout(attention_map) + return torch.matmul(attention_map, value) + + +class MultiHeadAttention(nn.Module): + def __init__(self, hidden_dim, n_heads, dropout=0.1): + super().__init__() + + self.hidden_dim = hidden_dim + self.head_dim = hidden_dim // n_heads + self.n_heads = n_heads + + self.q_proj = nn.Linear(hidden_dim, hidden_dim) + self.v_proj = nn.Linear(hidden_dim, hidden_dim) + self.k_proj = nn.Linear(hidden_dim, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.output_proj = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, query, key, value, mask=None): + batch_size = query.size(0) + + # project and reshaping + k_project = self.k_proj(key) + q_project = self.q_proj(query) + v_project = self.v_proj(value) + k_reshape = k_project.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) + q_reshape = q_project.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) + v_reshape = v_project.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) + + # merge heads and output + scores = attention(q_reshape, k_reshape, v_reshape, mask, self.dropout) + scores = scores.transpose(1, 2).contiguous() + scores = scores.view(batch_size, -1, self.hidden_dim) + + return self.output_proj(scores) + + +class FeedForwardLayer(nn.Module): + def __init__(self, hidden_dim, intermediate_dim=2048, dropout=0.1): + super().__init__() + self.dense1 = nn.Linear(hidden_dim, intermediate_dim) + self.dense2 = nn.Linear(intermediate_dim, hidden_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.dense2(self.dropout(F.relu(self.dense1(x)))) + + +class LayerNorm(nn.Module): + def __init__(self, hidden_dim, eps=1e-6): + super(LayerNorm, self).__init__() + + self.alpha = nn.Parameter(torch.ones(hidden_dim)) + self.beta = nn.Parameter(torch.zeros(hidden_dim)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.alpha * (x - mean) / (std + self.eps) + self.beta + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, n_heads, hidden_dim, dropout=0.1): + super().__init__() + + self.self_attn = MultiHeadAttention(hidden_dim, n_heads) + self.ff_layer = FeedForwardLayer(hidden_dim) + + self.norm1 = LayerNorm(hidden_dim) + self.dropout1 = nn.Dropout(dropout) + self.norm2 = LayerNorm(hidden_dim) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, inp, mask): + x = self.norm1(inp) + x = inp + self.dropout1(self.self_attn(x, x, x, mask)) + x = x + self.dropout2(self.ff_layer(self.norm2(x))) + return x + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, n_heads, hidden_dim, dropout=0.1): + super().__init__() + + self.self_attn = MultiHeadAttention(hidden_dim, n_heads) + self.cross_attn = MultiHeadAttention(hidden_dim, n_heads) + self.ff = FeedForwardLayer(hidden_dim) + + self.norm1 = LayerNorm(hidden_dim) + self.norm2 = LayerNorm(hidden_dim) + self.norm3 = LayerNorm(hidden_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + def forward(self, inp, mask, encoder_output, encoder_output_mask): + x = self.norm1(inp) + x = inp + self.dropout1(self.self_attn(x, x, x, mask)) + x = x + self.dropout2(self.cross_attn(self.norm2(x), encoder_output, encoder_output, encoder_output_mask)) + x = x + self.dropout3(self.ff(self.norm3(x))) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, vocab_size, n_layers, hidden_dim, n_heads): + super().__init__() + + self.n_layers = n_layers + self.embedding = nn.Embedding(vocab_size, hidden_dim) + self.posencoding = PosEncoding(hidden_dim) + self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderLayer(n_heads, hidden_dim)) for _ in range(n_layers)]) + self.layernorm = LayerNorm(hidden_dim) + + def forward(self, src, mask): + x = self.embedding(src) + x = self.posencoding(x) + for i in range(self.n_layers): + x = self.layers[i](x, mask) + return self.layernorm(x) + + +class TransformerDecoder(nn.Module): + def __init__(self, vocab_size, n_layers, hidden_dim, n_heads): + super().__init__() + + self.n_layers = n_layers + self.embedding = nn.Embedding(vocab_size, hidden_dim) + self.posencoding = PosEncoding(hidden_dim) + self.layers = nn.ModuleList([copy.deepcopy(TransformerDecoderLayer(n_heads, hidden_dim)) for _ in range(n_layers)]) + self.layernorm = LayerNorm(hidden_dim) + + def forward(self, inp, mask, encoder_output, encoder_output_mask): + x = self.embedding(inp) + x = self.posencoding(x) + for i in range(self.n_layers): + x = self.layers[i](x, mask, encoder_output, encoder_output_mask) + return self.layernorm(x) + + +class TransformerForSeq2Seq(nn.Module): + def __init__(self, src_vocab_size, tgt_vocab_size, n_layers, hidden_dim, n_heads): + super().__init__() + + self.encoder = TransformerEncoder(src_vocab_size, n_layers, hidden_dim, n_heads) + self.decoder = TransformerDecoder(tgt_vocab_size, n_layers, hidden_dim, n_heads) + self.output_dense = nn.Linear(hidden_dim, tgt_vocab_size) + + def forward(self, src, tgt, src_mask, tgt_mask): + encoder_outputs = self.encoder(src, src_mask) + decoder_outputs = self.decoder(tgt, tgt_mask, encoder_outputs, src_mask) + + return self.output_dense(decoder_outputs) diff --git a/test/ut/sdk/test_transformer_pruners.py b/test/ut/sdk/test_transformer_pruners.py new file mode 100644 index 0000000000..762a7bdf9b --- /dev/null +++ b/test/ut/sdk/test_transformer_pruners.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +import math +import sys +import unittest +from unittest import TestCase, main + +from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner + +sys.path.append(os.path.dirname(__file__)) +from models.pytorch_models.transformer import TransformerEncoder + + +def validate_sparsity(wrapper, sparsity, bias=False): + masks = [wrapper.weight_mask] + if bias and wrapper.bias_mask is not None: + masks.append(wrapper.bias_mask) + for m in masks: + actual_sparsity = (m == 0).sum().item() / m.numel() + msg = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(actual_sparsity, sparsity) + assert math.isclose(actual_sparsity, sparsity, abs_tol=0.1), msg + + +class Model(nn.Module): + """ + A binary classifier using a transformer encoder for contextual embedding. + """ + def __init__(self, n_layer, hidden_dim, n_head): + super(Model, self).__init__() + self.embedding = TransformerEncoder(vocab_size=100, hidden_dim=hidden_dim, n_layers=n_layer, n_heads=n_head) + self.classifier = nn.Linear(hidden_dim, 1) + + def forward(self, x, mask): + raw_output = self.embedding(x, mask) + pooled_output = raw_output[0] + prediction = F.sigmoid(self.classifier(pooled_output)).squeeze() + return prediction + + +def train(model, dataloader, criterion, optimizer): + model.train() + device = next(model.parameters()).device + for _ in range(2): + y = torch.ones(10).to(device) + out = model(torch.randint(0, 100, (4, 10)).to(device), torch.ones(10).to(device)) + loss = criterion(out, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + +def dry_run(model): + device = next(model.parameters()).device + for _ in range(2): + y = torch.ones(10).to(device) + _ = model(torch.randint(0, 100, (4, 10)).to(device), torch.ones(10).to(device)) + + +def head_pruner_tests(criterion, global_sort, use_graph, iterative): + print("Testing criterion {} with global_sort={} and use_graph={}".format(criterion, global_sort, use_graph)) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Build config list and arguments + config_list = [{'sparsity': 0.5, 'op_types': ['Linear']}] + + kwargs = {'ranking_criterion': criterion, 'head_hidden_dim': 64} + if global_sort: + kwargs['global_sort'] = True + else: + kwargs['global_sort'] = False + + if use_graph: + attention_name_groups = list(zip(['embedding.layers.{}.self_attn.q_proj'.format(i) for i in range(6)], + ['embedding.layers.{}.self_attn.k_proj'.format(i) for i in range(6)], + ['embedding.layers.{}.self_attn.v_proj'.format(i) for i in range(6)], + ['embedding.layers.{}.self_attn.output_proj'.format(i) for i in range(6)])) + kwargs['attention_name_groups'] = attention_name_groups + else: + dummy_input = (torch.randint(0, 100, (10, 32)).to(device), torch.ones(32).to(device)) + kwargs['dummy_input'] = dummy_input + + if iterative: + kwargs['num_iterations'] = 2 + kwargs['epochs_per_iteration'] = 1 + + n_layers = 6 + n_heads = 8 + hidden_dim = 512 + model = Model(n_layers, hidden_dim, n_heads) + model.to(device) + kwargs['optimizer'] = torch.optim.SGD(model.parameters(), lr=0.001) + + def trainer(model, optimizer, criterion, epoch): + return train(model, None, criterion, optimizer) + kwargs['trainer'] = trainer + kwargs['criterion'] = nn.BCELoss() + + def forward_runner(model): + return dry_run(model) + kwargs['forward_runner'] = forward_runner + + # create pruner and call compress() + pruner = TransformerHeadPruner(model, config_list, **kwargs) + pruner.compress() + + # test model and mask export + pruner.export_model('./model_tmp.pth', './mask_tmp.pth', device=device) + dummy_input = (torch.randint(0, 100, (10, 32)).to(device), torch.ones(32).to(device)) + pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', + dummy_input=dummy_input, opset_version=10) + + # validate sparsity + if not global_sort: + for wrapper in pruner.modules_wrapper: + validate_sparsity(wrapper, wrapper.config['sparsity']) + + +class PrunerTestCase(TestCase): + def test_head_pruner(self): + for criterion in ["l1_weight", "l2_weight", "l1_activation", "l2_activation", "taylorfo"]: + for global_sort in [False, True]: + for use_graph in [False, True]: + for iterative in [False, True]: + head_pruner_tests(criterion, global_sort, use_graph, iterative) + + file_paths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', + './search_result.json'] + for f in file_paths: + if os.path.exists(f): + os.remove(f) + + +if __name__ == '__main__': + main()