From f6ab0632e8bf084de84619fa9722de072d2c739b Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 26 Oct 2023 18:31:12 -0400
Subject: [PATCH] Layerwise Sparsity Support for SparseGPT (#1777)

---
 src/sparseml/modifiers/obcq/base.py           | 31 +++++++++--
 src/sparseml/modifiers/obcq/pytorch.py        | 15 ++++--
 .../modifiers/obcq/utils/layer_compressor.py  | 39 ++------------
 tests/sparseml/modifiers/conf.py              |  8 ++-
 .../pytorch/modifiers/obcq/test_pytorch.py    | 52 +++++++++++++++++--
 5 files changed, 97 insertions(+), 48 deletions(-)

diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index 59aa076768b..f9e7a1c2955 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -53,7 +53,7 @@ class SparseGPTModifier(Modifier):
         model.decoder for OPT or just model for Llama
     """
 
-    sparsity: float
+    sparsity: Union[float, List[float]]
     block_size: int
     quantize: Union[bool, Dict]
     dampening_frac: Optional[float] = 0.01
@@ -63,7 +63,6 @@
     targets: Union[str, List[str], None] = ALL_TOKEN
     target_ids: Optional[List[str]] = None
    layer_prefix: Optional[str] = None
-
     compressible_layers_: List = None
     quantization_modifier_: Any = None
 
@@ -76,7 +75,7 @@ def compressible_layers(self) -> List:
         compressible_dict = self.model.get_layers(self.targets)
         return [v for _, v in compressible_dict.items()]
 
-    def pre_initialize_structure(self, state: State, **kwargs):
+    def on_initialize_structure(self, state: State, **kwargs):
         quantization_already_active = state.model.qat_active()
         if isinstance(self.quantize, bool):
             if not self.quantize and quantization_already_active:
@@ -123,7 +122,7 @@
                 self.quantize = True
 
         if self.quantization_modifier_:
-            self.quantization_modifier_.pre_initialize_structure(state, **kwargs)
+            self.quantization_modifier_.on_initialize_structure(state, **kwargs)
 
     def _build_quant_modifier_from_dict(self, quant_config, framework):
         modifier_type = list(quant_config.keys())[0]
@@ -135,3 +134,27 @@
             allow_experimental=True,
             **modifier_args,
         )
+
+    def _validate_layerwise_sparsity(self):
+        if isinstance(self.sparsity, float):
+            return  # single sparsity will be applied to all layers
+
+        if not isinstance(self.targets, List):
+            raise ValueError(
+                "Layer targets must be a list when specifying layer-wise"
+                f" sparsity. Got {self.targets}"
+            )
+
+        if len(self.targets) != len(self.sparsity):
+            raise ValueError(
+                "Number of layer targets must match the number of "
+                f"sparsities. Got {len(self.targets)} layers and "
+                f"{len(self.sparsity)} sparsities"
+            )
+
+        for layer_name in self.targets:
+            if layer_name.startswith("re:"):
+                raise ValueError(
+                    "Using regular expressions for layer-wise sparsity "
+                    f"profiles is not permitted. Found {layer_name}"
+                )
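The checks above run from `on_initialize` in the PyTorch implementation (next file), so a bad pairing of `sparsity` and `targets` is rejected before any calibration or compression work starts. A minimal sketch of what passes and what raises, mirroring the new tests at the bottom of this patch (the `seq.fc1`/`seq.fc2` names come from the LinearNet test model and are illustrative):

```python
from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch

# Passes validation: one sparsity per explicitly named layer target
ok = SparseGPTModifierPyTorch(
    sparsity=[0.5, 0.2],
    block_size=128,
    quantize=False,
    targets=["seq.fc1", "seq.fc2"],
)

# Fails once initialize() runs: list-valued sparsity requires a list of
# targets, with matching lengths and no "re:" regex patterns
bad = SparseGPTModifierPyTorch(
    sparsity=[0.5, 0.2],
    block_size=128,
    quantize=False,
    targets="__ALL__",
)
```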
diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index afdf530985f..70530f29c8d 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -54,8 +54,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         :param state: session state storing input model and calibration data
         """
+        self._validate_layerwise_sparsity()
+
         if not self.initialized_structure_:
-            self.pre_initialize_structure(state, **kwargs)
+            self.on_initialize_structure(state, **kwargs)
         if self.quantization_modifier_:
             self.quantization_modifier_.initialize(state, **kwargs)
         self.finalization_kwargs_ = {}
@@ -118,10 +120,17 @@ def apply_obcq(
                     "The 'outputs' key is expected but not found from the "
                     "return of the bottom compressor"
                 )
+
             inputs = accum_kwargs["outputs"]
-            _LOGGER.info(f"\n===== Compressing layer {idx}/{num_layers-1} =====")
+            layer_sparsity = (
+                self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
+            )
+            _LOGGER.info(
+                f"\n===== Compressing layer {idx+1}/{num_layers} "
+                f"to sparsity {layer_sparsity} ====="
+            )
             args = {
-                "sparsity": self.sparsity,
+                "sparsity": layer_sparsity,
                 "prunen": self.prunen,
                 "prunem": self.prunem,
                 "blocksize": self.block_size,
diff --git a/src/sparseml/modifiers/obcq/utils/layer_compressor.py b/src/sparseml/modifiers/obcq/utils/layer_compressor.py
index 7a0197d07d2..892e9defb67 100644
--- a/src/sparseml/modifiers/obcq/utils/layer_compressor.py
+++ b/src/sparseml/modifiers/obcq/utils/layer_compressor.py
@@ -17,11 +17,11 @@
 from typing import Dict, List
 
 import torch
-import torch.nn as nn
 from torch.nn import Module
 
 from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT
 from sparseml.pytorch.utils.helpers import get_dependency_order
+from sparseml.utils.pytorch.module import get_prunable_layers
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -60,14 +60,8 @@ def compressible_modules(self) -> Dict:
 
         :return: dictionary of compressible modules
         """
-        quantize = self.args.get("quantize", False)
-        if quantize:
-            # The layer names are changed due to quantization modifiers, therefore
-            # we need a slightly different func to retrieve layers
-            modules = _find_quant_layers(self.layer)
-        else:
-            modules = _find_layers(self.layer)
-        return modules
+        compressible_layers = get_prunable_layers(self.layer)
+        return compressible_layers
 
     def pre_compress_parallel(self, **kwargs) -> Dict:
         """
@@ -217,30 +211,3 @@ def tmp(_, inp, out):
                 blocksize=self.args["blocksize"],
             )
             gpts.free()
-
-
-def _find_quant_layers(
-    module, layers=[torch.nn.qat.Conv2d, torch.nn.qat.Linear], name=""
-):
-    res = {}
-    # search for QAT versions of layers
-    for name1, child in module.named_children():
-        res.update(
-            _find_layers(
-                child, layers=layers, name=name + "." + name1 if name != "" else name1
-            )
-        )
-    return res
-
-
-def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
-    if type(module) in layers:
-        return {name: module}
-    res = {}
-    for name1, child in module.named_children():
-        res.update(
-            _find_layers(
-                child, layers=layers, name=name + "." + name1 if name != "" else name1
-            )
-        )
-    return res
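Dropping the hand-rolled `_find_layers`/`_find_quant_layers` pair works because QAT layer classes (`torch.nn.qat.Linear`, `torch.nn.qat.Conv2d`) subclass their float counterparts, so one isinstance-based walk covers both the quantized and unquantized cases. An illustrative stand-in for `get_prunable_layers` (the real helper lives in `sparseml.utils.pytorch.module` and may differ in detail):

```python
import torch.nn as nn


def find_prunable_layers(module: nn.Module) -> dict:
    """Map dotted submodule names to the Linear/Conv2d layers SparseGPT can prune."""
    return {
        name: submodule
        for name, submodule in module.named_modules()
        if isinstance(submodule, (nn.Linear, nn.Conv2d))
    }
```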
diff --git a/tests/sparseml/modifiers/conf.py b/tests/sparseml/modifiers/conf.py
index 90ab0c34be6..653f3d5c53b 100644
--- a/tests/sparseml/modifiers/conf.py
+++ b/tests/sparseml/modifiers/conf.py
@@ -25,9 +25,13 @@ def setup_modifier_factory():
 
 
 class LifecyleTestingHarness:
-    def __init__(self, model=None, optimizer=None, framework=Framework.pytorch):
+    def __init__(
+        self, model=None, optimizer=None, framework=Framework.pytorch, device="cpu"
+    ):
         self.state = State(framework=framework)
-        self.state.update(model=model, optimizer=optimizer, start=0, steps_per_epoch=1)
+        self.state.update(
+            model=model, device=device, optimizer=optimizer, start=0, steps_per_epoch=1
+        )
 
         self.event_lifecycle = CallbacksEventLifecycle(
             type_first=EventType.BATCH_START, start=self.state.start_event
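The new `device` argument defaults to `"cpu"`, so existing tests are unaffected; callers that want GPU placement thread it through the same `State.update` call the modifiers read from. A usage sketch (the CUDA availability check is illustrative):

```python
import torch

from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import LinearNet

setup_modifier_factory()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
harness = LifecyleTestingHarness(model=LinearNet(), device=device)
state = harness.get_state()  # the state now carries the requested device
```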
diff --git a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
index 709a932e3cb..01b5afb886f 100644
--- a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
+++ b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
+
+from sparseml.core.framework import Framework
+from sparseml.core.model import ModifiableModel
 from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
 from sparseml.modifiers.quantization import QuantizationModifier
 from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
@@ -19,6 +23,48 @@
 from tests.sparseml.pytorch.helpers import LinearNet
 
 
+@pytest.mark.parametrize(
+    "sparsity,targets",
+    [
+        ([0.5, 0.2], "__ALL__"),  # type mismatch
+        ([0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]),  # length mismatch
+        ([0.3, 0.4], ["re:.*fc1", "re:.*fc2"]),  # regex not supported
+    ],
+)
+def test_invalid_layerwise_recipes_raise_exceptions(sparsity, targets):
+    setup_modifier_factory()
+    model = LinearNet()
+
+    kwargs = dict(
+        sparsity=sparsity,
+        block_size=128,
+        quantize=False,
+        targets=targets,
+    )
+    modifier = SparseGPTModifierPyTorch(**kwargs)
+    testing_harness = LifecyleTestingHarness(model=model)
+
+    # confirm invalid layerwise recipes fail at initialization
+    with pytest.raises(ValueError):
+        modifier.initialize(testing_harness.get_state())
+
+
+def test_successful_layerwise_recipe():
+    setup_modifier_factory()
+    model = LinearNet()
+
+    sparsities = [0.5, 0.2]
+    targets = ["seq.fc1", "seq.fc2"]
+    kwargs = dict(sparsity=sparsities, block_size=128, quantize=False, targets=targets)
+    modifier = SparseGPTModifierPyTorch(**kwargs)
+    modifier._validate_layerwise_sparsity()
+    modifier.model = ModifiableModel(framework=Framework.pytorch, model=model)
+    found_compressible_layers = modifier.compressible_layers()
+
+    # ensure layer names successfully match up with the model
+    assert len(found_compressible_layers) == len(targets)
+
+
 def test_create_default_quant_modifier():
     setup_modifier_factory()
     kwargs = dict(sparsity=0.5, block_size=128, quantize=True)
@@ -27,7 +73,7 @@ def test_create_default_quant_modifier():
     assert modifier.quantization_modifier_ is None
 
     testing_harness = LifecyleTestingHarness(model=LinearNet())
-    modifier.pre_initialize_structure(testing_harness.get_state())
+    modifier.on_initialize_structure(testing_harness.get_state())
     assert modifier.quantize
     assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
 
@@ -59,7 +105,7 @@ def test_set_quant_if_modifer_already_exists():
     kwargs = dict(sparsity=0.5, block_size=128, quantize=False)
     modifier = SparseGPTModifierPyTorch(**kwargs)
     assert not modifier.quantize
-    modifier.pre_initialize_structure(testing_harness.get_state())
+    modifier.on_initialize_structure(testing_harness.get_state())
 
     # quantization modifier not owned by SparseGPT
     assert modifier.quantization_modifier_ is None
@@ -95,7 +141,7 @@ def test_set_quant_in_sparsegpt():
     assert modifier.quantization_modifier_ is None
 
     testing_harness = LifecyleTestingHarness(model=LinearNet())
-    modifier.pre_initialize_structure(testing_harness.get_state())
+    modifier.on_initialize_structure(testing_harness.get_state())
     assert modifier.quantize
     assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
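End to end, the feature is a positional pairing: layer `targets[i]` is compressed to `sparsity[i]`, while a bare float keeps the old uniform behavior. A minimal sketch of that dispatch (hypothetical helper; the in-tree logic lives inline in `apply_obcq` above):

```python
from typing import List, Union


def resolve_sparsity(sparsity: Union[float, List[float]], idx: int) -> float:
    """Return the sparsity target for the idx-th compressible layer."""
    return sparsity[idx] if isinstance(sparsity, list) else sparsity


assert resolve_sparsity(0.5, 3) == 0.5         # scalar fans out to every layer
assert resolve_sparsity([0.5, 0.2], 1) == 0.2  # list entries map to layers in order
```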