Merge pull request #22 from pytorch-labs/jcaip/sparsity
[sparse] add sparsity, add wanda sparsifier to ao
Showing 4 changed files with 285 additions and 0 deletions.
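
The diff adds a WandaSparsifier and a supporting PerChannelNormObserver. As a quick orientation before the individual files, here is a minimal usage sketch of the new API; the toy model and calibration tensor are illustrative and not part of the diff:

    import torch
    from torch import nn
    from torchao.sparsity import WandaSparsifier

    # toy model and calibration data (illustrative values only)
    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
    calibration_batch = torch.randn(32, 128)

    sparsifier = WandaSparsifier(sparsity_level=0.5)
    sparsifier.prepare(model, config=None)  # attaches observers and FakeSparsity parametrizations
    model(calibration_batch)                # forward passes record per-channel activation norms
    sparsifier.step()                       # computes the Wanda metric and updates the masks
    sparsifier.squash_mask()                # bakes masks into weights, removes observers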
@@ -0,0 +1,114 @@
import logging
import unittest

import torch
from torch import nn
from torchao.sparsity import WandaSparsifier
from torch.ao.pruning import FakeSparsity
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import SimpleLinear
from torch.testing._internal.common_utils import TestCase

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestWandaSparsifier(TestCase):
    """
    Test Wanda Sparsifier
    """

    def test_prepare(self):
        model = SimpleLinear()
        sparsifier = WandaSparsifier()
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            module = g["module"]
            # check mask exists
            assert hasattr(module.parametrizations["weight"][0], "mask")
            # check parametrization exists and is correct
            assert is_parametrized(module, "weight")
            assert type(module.parametrizations.weight[0]) == FakeSparsity
            # check activation observer is present
            assert hasattr(module, "activation_post_process")

    def test_squash_mask(self):
        # check that observers and parametrizations are removed
        model = SimpleLinear()
        sparsifier = WandaSparsifier()
        sparsifier.prepare(model, config=None)
        sparsifier.squash_mask()
        for g in sparsifier.groups:
            module = g["module"]
            assert not is_parametrized(module, "weight")
            assert not hasattr(module, "mask")
            assert not hasattr(module, "activation_post_process")

    def test_one_layer_mlp_2x4(self):
        model = nn.Sequential(nn.Linear(8, 1))
        weights = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
        model[0].weight.data.copy_(weights.data)
        X = torch.ones(1, 8)

        sparsifier = WandaSparsifier(semi_structured_block_size=4)
        sparsifier.prepare(model, config=None)

        model(X)

        sparsifier.step()
        sparsifier.squash_mask()

        sparsity = (model[0].weight == 0).float().mean()
        assert sparsity == 0.5

        expected_fc = torch.tensor([[0, 0, 3, 4, 0, 0, 7, 8]], dtype=torch.float32)
        assert torch.allclose(model[0].weight.data, expected_fc, rtol=1e-05, atol=1e-07)

    def test_one_layer_mlp_unstructured(self):
        model = nn.Sequential(nn.Linear(4, 1))
        weights = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
        model[0].weight.data.copy_(weights.data)
        X = torch.tensor([[100, 10, 1, 0.1]], dtype=torch.float32)

        sparsifier = WandaSparsifier(sparsity_level=0.5)
        sparsifier.prepare(model, config=None)

        model(X)

        sparsifier.step()
        sparsifier.squash_mask()

        sparsity = (model[0].weight == 0).float().mean()
        assert sparsity == 0.5

        expected_fc = torch.tensor([[1, 2, 0, 0]], dtype=torch.float32)
        assert torch.allclose(model[0].weight.data, expected_fc, rtol=1e-05, atol=1e-07)

    def test_two_layer_mlp_unstructured(self):
        model = nn.Sequential(
            nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
        )  # C_in by C_out
        X1 = torch.randn(100, 128)  # B1 by C_in
        X2 = torch.randn(50, 128)  # B2 by C_in

        sparsifier = WandaSparsifier(sparsity_level=0.5)
        sparsifier.prepare(model, config=None)

        model(X1)
        model(X2)
        sparsifier.step()

        cnt = 0
        for m in model.modules():
            if isinstance(m, nn.Linear):
                cnt += 1
                sparsity_level = (m.weight == 0).float().mean()
                assert (
                    sparsity_level == 0.5
                ), f"sparsity for linear layer {cnt} should be 0.5"

        sparsifier.squash_mask()


if __name__ == "__main__":
    unittest.main()
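
To see where the expected tensors in these tests come from, the metric can be traced by hand from the PerChannelNormObserver and WandaSparsifier.update_mask code added below:

    # Unstructured test: the observer accumulates the squared L2 norm of each input channel,
    #   X = [[100, 10, 1, 0.1]] -> norms [100^2, 10^2, 1^2, 0.1^2] = [10000, 100, 1, 0.01]
    # Wanda metric = |W| * channel norm = [1, 2, 3, 4] * [10000, 100, 1, 0.01] = [10000, 200, 3, 0.04]
    # With sparsity_level=0.5 the two smallest-metric weights (indices 2 and 3) are masked,
    # leaving [[1, 2, 0, 0]].
    # In the 2:4 test the input is all ones, so every channel norm is equal, the metric reduces
    # to |W|, and the two smallest weights in each block of 4 are pruned: [[0, 0, 3, 4, 0, 0, 7, 8]].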
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .wanda import WandaSparsifier  # noqa: F401
from .utils import PerChannelNormObserver  # noqa: F401

__all__ = [
    "WandaSparsifier",
    "PerChannelNormObserver",
]
@@ -0,0 +1,48 @@
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase

__all__ = ["PerChannelNormObserver"]


# Observers
class PerChannelNormObserver(UniformQuantizationObserverBase):
    """
    A custom observer that accumulates the squared L2 norm of each channel
    of the observed activations in a buffer.
    """

    def __init__(self, **kwargs) -> None:
        # init with fixed qparams for quantization flow
        super().__init__(
            dtype=torch.quint8,
            qscheme=torch.per_channel_affine,
            reduce_range=False,
            quant_min=None,
            quant_max=None,
            eps=torch.finfo(torch.float32).eps,
            **kwargs,
        )
        # set averaging constant so quantization flow knows observer is memoryless.
        self.averaging_constant = 1.0
        self.register_buffer("norm", torch.tensor([]))

    def forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape

        # the channel axis is always the last dimension; move it to the front
        # and flatten the remaining dimensions so each row holds one channel
        new_axis_list = [i for i in range(x.dim())]  # noqa: C416
        new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
        y = x.permute(new_axis_list)
        y = torch.flatten(y, start_dim=1)
        norm = torch.norm(y, dim=1) ** 2

        # accumulate squared per-channel norms across all observed batches
        if self.norm.numel() == 0:
            self.norm.resize_(norm.shape)
            self.norm.copy_(norm)
        else:
            self.norm += norm

        return x_orig

    def calculate_qparams(self):
        raise NotImplementedError(
            "PerChannelNormObserver only stores activation statistics and does not compute qparams."
        )
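
A short standalone sketch of what the observer accumulates over two forward passes; the values are illustrative, and in normal use the observer is attached automatically by WandaSparsifier.prepare:

    import torch
    from torchao.sparsity import PerChannelNormObserver

    obs = PerChannelNormObserver()
    obs(torch.tensor([[3.0, 0.0], [0.0, 4.0]]))  # squared per-column norms: [9, 16]
    obs(torch.tensor([[1.0, 2.0]]))              # accumulated: [9 + 1, 16 + 4]
    print(obs.norm)                              # tensor([10., 20.])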
@@ -0,0 +1,110 @@
import warnings

from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier
from torch.ao.quantization import default_placeholder_observer, QConfig
from torch.ao.quantization.quantize import _remove_qconfig
from .utils import PerChannelNormObserver

__all__ = ["WandaSparsifier"]


class WandaSparsifier(BaseSparsifier):
    r"""Wanda sparsifier

    Wanda (Pruning by Weights and activations), proposed in https://arxiv.org/abs/2306.11695,
    is an activation-aware pruning method. The sparsifier removes weights based on the product
    of the input activation norm and the weight magnitude.

    This sparsifier is controlled by two variables:
    1. `sparsity_level` defines the fraction of weights that are zeroed out (unstructured sparsity);
    2. `semi_structured_block_size`, when set, switches to semi-structured (N:M) pruning.

    Args:
        sparsity_level: The target level of sparsity;
        semi_structured_block_size: The block size M for semi-structured sparsity
            (e.g. 4 for 2:4 sparsity); when set, `sparsity_level` is fixed to 50%.
    """

    def __init__(
        self,
        sparsity_level: float = 0.5,
        semi_structured_block_size: Optional[int] = None,
    ):
        defaults = {
            "sparsity_level": sparsity_level,
            "semi_structured_block_size": semi_structured_block_size,
        }
        if semi_structured_block_size is not None:
            m = semi_structured_block_size
            warnings.warn(
                f"WandaSparsifier got semi_structured_block_size={m}, sparsity_level fixed to 50% ({m // 2}:{m}) sparsity"
            )
        super().__init__(defaults=defaults)

    def prepare(self, model: nn.Module, config: List[Dict]) -> None:
        # activation: use PerChannelNormObserver
        # use no-op placeholder weight observer
        model.qconfig = QConfig(
            activation=PerChannelNormObserver, weight=default_placeholder_observer
        )  # type: ignore[assignment]
        torch.ao.quantization.prepare(model, inplace=True)

        # call superclass prepare
        super().prepare(model, config)

    def update_mask(  # type: ignore[override]
        self, module: nn.Module, tensor_name: str, sparsity_level: float, **kwargs
    ) -> None:
        r"""Pruning function for WandaSparsifier

        The accumulated per-channel activation norms are read from the attached
        `PerChannelNormObserver`, the Wanda pruning metric is computed, and the
        weight matrix is pruned by comparing this metric across the current layer.
        """

        # Step 1: get the tensor and the mask from the parametrizations
        mask = getattr(module.parametrizations, tensor_name)[0].mask
        tensor = getattr(module.parametrizations, tensor_name).original
        activation_norm_per_channel = module.activation_post_process.norm

        # Step 2: calculate the pruning metric |W| * activation norm
        pruning_metric = torch.abs(tensor) * activation_norm_per_channel

        # defaults for unstructured sparsity: one block spanning the whole tensor
        block_size = pruning_metric.numel()
        num_specified = int(block_size * sparsity_level)
        # if set to use semi-structured, ignore sparsity_level
        if kwargs.get("semi_structured_block_size", None) is not None:
            block_size = kwargs["semi_structured_block_size"]
            num_specified = block_size // 2

        # get indices to prune: the `num_specified` smallest entries in each block
        pruning_inds = pruning_metric.view(-1, block_size).argsort(dim=1)[
            :, :num_specified
        ]
        # update mask
        mask.data.view(-1, block_size).scatter_(
            1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype)
        )

    def squash_mask(
        self,
        params_to_keep: Optional[Tuple[str, ...]] = None,
        params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
        *args,
        **kwargs,
    ):
        # remove quantization config and observers
        for config in self.groups:
            module = config["module"]
            _remove_qconfig(module)

        # remove parametrizations
        super().squash_mask(
            params_to_keep=params_to_keep,
            params_to_keep_per_layer=params_to_keep_per_layer,
        )
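
For reference, a minimal standalone trace of the semi-structured branch of update_mask, using the same weights as test_one_layer_mlp_2x4; this uses plain tensors rather than the sparsifier itself:

    import torch

    metric = torch.tensor([[1., 2., 3., 4., 5., 6., 7., 8.]])  # |W| * norm, norms all equal
    block_size, num_specified = 4, 2                            # 2:4 semi-structured
    mask = torch.ones_like(metric)

    # smallest `num_specified` metric entries in each block of `block_size` are zeroed
    pruning_inds = metric.view(-1, block_size).argsort(dim=1)[:, :num_specified]
    mask.view(-1, block_size).scatter_(
        1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype)
    )
    print(mask)  # tensor([[0., 0., 1., 1., 0., 0., 1., 1.]])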