diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py
index 602682004a2..387114ad96a 100644
--- a/src/sparseml/core/model/base.py
+++ b/src/sparseml/core/model/base.py
@@ -116,3 +116,11 @@ def set_param(self, target: str, param: PT):
         :param param: the param instance to set
         """
         raise NotImplementedError()
+
+    def qat_active(self) -> bool:
+        """
+        Checks if quantization aware training is set up in the model
+
+        :return: True if QAT is active in any layer, False otherwise
+        """
+        raise NotImplementedError()
diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py
index 670d164900c..258675115ba 100644
--- a/src/sparseml/core/model/pytorch.py
+++ b/src/sparseml/core/model/pytorch.py
@@ -24,6 +24,7 @@
     get_layers_params,
     get_param,
     get_params,
+    qat_active,
     set_layer,
     set_param,
 )
@@ -94,3 +95,11 @@ def set_param(self, target: str, param: Parameter):
         :param param: the parameter to set
         """
         return set_param(target, param, self.model)
+
+    def qat_active(self) -> bool:
+        """
+        Checks if quantization aware training is set up in the model
+
+        :return: True if QAT is active in any layer, False otherwise
+        """
+        return qat_active(self.model)
diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index a739806ee43..bbbf3ee3281 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -12,16 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from typing import List, Optional, Union
+import logging
+from typing import Any, Dict, List, Optional, Union
 
 from sparseml.core import Modifier
+from sparseml.core.factory import ModifierFactory
 from sparseml.core.state import State
 from sparseml.utils import ALL_TOKEN
 
 
 __all__ = ["SparseGPTModifier"]
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class SparseGPTModifier(Modifier):
     """
@@ -34,7 +37,9 @@ class SparseGPTModifier(Modifier):
 
     :param sparsity: Sparsity to compress model to
    :param block_size: Used to determine number of columns to compress in one pass
-    :param quantize: Whether or not model is quantized (affects layer names)
+    :param quantize: Whether or not to quantize weights during SparseGPT. Set to True
+        to quantize using an existing quantization modifier, or pass in the configuration
+        for a quantization modifier if one does not already exist in the recipe
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
     :param sequential_update: Whether or not to update weights sequentially by layer,
@@ -50,7 +55,7 @@ class SparseGPTModifier(Modifier):
 
     sparsity: Union[float, List[float]]
     block_size: int
-    quantize: bool
+    quantize: Union[bool, Dict]
     dampening_frac: Optional[float] = 0.01
     sequential_update: Optional[bool] = True
     prunen: Optional[int] = 0
@@ -58,9 +63,77 @@ class SparseGPTModifier(Modifier):
     targets: Union[str, List[str], None] = ALL_TOKEN
     target_ids: Optional[List[str]] = None
     layer_prefix: Optional[str] = None
+    compressible_layers_: List = None
+    quantization_modifier_: Any = None
+
+    def compressible_layers(self) -> List:
+        """
+        Retrieves the modules corresponding to a list of compressible layer names
+
+        :return: list of Pytorch modules to compress
+        """
+        compressible_dict = self.model.get_layers(self.targets)
+        return [v for _, v in compressible_dict.items()]
+
+    def on_initialize_structure(self, state: State, **kwargs):
+        quantization_already_active = state.model.qat_active()
+        if isinstance(self.quantize, bool):
+            if not self.quantize and quantization_already_active:
+                _LOGGER.warning(
+                    "SparseGPT quantization is set to False, but a "
+                    "quantization modifier is already active on the model, "
+                    "resetting quantize to True"
+                )
+                self.quantize = True
+            elif self.quantize and not quantization_already_active:
+                _LOGGER.warning(
+                    "SparseGPT quantization is set to True without an "
+                    "active quantization modifier. Creating a default "
+                    "8-bit quantization modifier"
+                )
+                default_quant_config = {"QuantizationModifier": {}}
+                self._build_quant_modifier_from_dict(
+                    default_quant_config, state.framework
+                )
+            return  # use existing quantization modifier if there is one
+        else:
+            if not isinstance(self.quantize, Dict):
+                raise ValueError(
+                    "SparseGPTModifier.quantize accepts only a single "
+                    "quantization modifier or a boolean. Found "
+                    f"type {type(self.quantize)}"
+                )
+            if len(self.quantize) != 1:
+                raise ValueError(
+                    "SparseGPTModifier.quantize accepts only a single "
+                    "quantization modifier or a boolean. Found "
+                    f"{len(self.quantize)} modifiers"
+                )
+            if quantization_already_active:
+                _LOGGER.warning(
+                    "Attempting to initialize quantization for SparseGPT "
+                    "but a quantization modifier has already been applied. "
+                    "The quantization configuration defined under the "
+                    "SparseGPT modifier will be ignored."
+                )
+                self.quantize = True
+                return
+            self._build_quant_modifier_from_dict(self.quantize, state.framework)
+            self.quantize = True
+
+        if self.quantization_modifier_:
+            self.quantization_modifier_.on_initialize_structure(state, **kwargs)
 
-    def on_initialize_structure(self, state: "State", **kwargs):
-        pass  # nothing needed for this modifier
+    def _build_quant_modifier_from_dict(self, quant_config, framework):
+        modifier_type = list(quant_config.keys())[0]
+        modifier_args = quant_config[modifier_type]
+        self.quantization_modifier_ = ModifierFactory.create(
+            modifier_type,
+            framework=framework,
+            allow_registered=True,
+            allow_experimental=True,
+            **modifier_args,
+        )
 
     def _validate_layerwise_sparisity(self):
         if isinstance(self.sparsity, float):
diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index 998550c7876..1e4e99a3b3f 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -17,7 +17,6 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
-from torch.nn import Module
 
 from sparseml.core.model import ModifiableModel
 from sparseml.core.state import State
@@ -46,31 +45,9 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
     """
 
     model: Any = None
-    compressible_layers_: List = None
     device_: str = "cuda:0"
     finalization_kwargs_: Dict = None
 
-    def compressible_layers(self) -> List[Module]:
-        """
-        Retrieves the modules corresponding to a list of compressible layer names
-
-        :return: list of Pytorch modules to compress
-        """
-        compressible_dict = self.model.get_layers(self.targets)
-
-        # Compare length to sparsities again in case one of the provided layers was
-        # invalid or not compressible
-        if isinstance(self.sparsity, List) and len(self.sparsity) != len(
-            compressible_dict
-        ):
-            raise ValueError(
-                "Number of compressible layers must match the number of "
-                f"sparsities. Got {len(compressible_dict)} layers and "
-                f"{len(self.sparsity)} sparsities"
-            )
-
-        return [v for _, v in compressible_dict.items()]
-
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         Initialize and run the OBCQ algorithm on the current state
@@ -79,6 +56,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         self._validate_layerwise_sparisity()
 
+        if not self.initialized_structure_:
+            self.on_initialize_structure(state, **kwargs)
+        if self.quantization_modifier_:
+            self.quantization_modifier_.initialize(state, **kwargs)
         self.finalization_kwargs_ = {}
         modifiable_model = state.model
         calibration_dataloader = state.data.calib
@@ -172,9 +153,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool:
         :param state: un-used, for matching spec of Modifier base class
         """
         use_cache = self.finalization_kwargs_.get("use_cache", False)
-        self.model.apply(torch.quantization.disable_observer)
         self.model.config.use_cache = use_cache
 
+        if self.quantization_modifier_:
+            self.quantization_modifier_.finalize(state, **kwargs)
+
         return True
 
     def compress_bottom(
diff --git a/src/sparseml/modifiers/obcq/utils/layer_compressor.py b/src/sparseml/modifiers/obcq/utils/layer_compressor.py
index 7a0197d07d2..892e9defb67 100644
--- a/src/sparseml/modifiers/obcq/utils/layer_compressor.py
+++ b/src/sparseml/modifiers/obcq/utils/layer_compressor.py
@@ -17,11 +17,11 @@
 from typing import Dict, List
 
 import torch
-import torch.nn as nn
 from torch.nn import Module
 
 from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT
 from sparseml.pytorch.utils.helpers import get_dependency_order
+from sparseml.utils.pytorch.module import get_prunable_layers
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -60,14 +60,8 @@ def compressible_modules(self) -> Dict:
 
         :return: dictionary of compressible modules
         """
-        quantize = self.args.get("quantize", False)
-        if quantize:
-            # The layer names are changed due to quantization modifiers, therefore
-            # we need a slightly different func to retrieve layers
-            modules = _find_quant_layers(self.layer)
-        else:
-            modules = _find_layers(self.layer)
-        return modules
+        compressible_layers = get_prunable_layers(self.layer)
+        return compressible_layers
 
     def pre_compress_parallel(self, **kwargs) -> Dict:
         """
@@ -217,30 +211,3 @@ def tmp(_, inp, out):
                 blocksize=self.args["blocksize"],
             )
             gpts.free()
-
-
-def _find_quant_layers(
-    module, layers=[torch.nn.qat.Conv2d, torch.nn.qat.Linear], name=""
-):
-    res = {}
-    # search for QAT versions of layers
-    for name1, child in module.named_children():
-        res.update(
-            _find_layers(
-                child, layers=layers, name=name + "." + name1 if name != "" else name1
-            )
-        )
-    return res
-
-
-def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
-    if type(module) in layers:
-        return {name: module}
-    res = {}
-    for name1, child in module.named_children():
-        res.update(
-            _find_layers(
-                child, layers=layers, name=name + "." + name1 if name != "" else name1
-            )
-        )
-    return res
diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml
index 619b76d7122..479e9c15b7d 100644
--- a/src/sparseml/transformers/sparsification/obcq/example.yaml
+++ b/src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,20 +1,20 @@
 test_stage:
   obcq_modifiers:
-    QuantizationModifier:
-      ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
-      post_oneshot_calibration: True
-      scheme_overrides:
-        ReLU:
-          input_activations: null
-          output_activations: null
-        LayerNorm:
-          input_activations: null
-          output_activations: null
     SparseGPTModifier:
       sparsity: 0.5
       block_size: 128
       sequential_update: False
-      quantize: True
+      quantize:
+        QuantizationModifier:
+          ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
+          post_oneshot_calibration: True
+          scheme_overrides:
+            ReLU:
+              input_activations: null
+              output_activations: null
+            LayerNorm:
+              input_activations: null
+              output_activations: null
       percdamp: 0.01
       prunen: 0
       prunem: 0
diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py
index 05fae4a174a..fb3956016b0 100644
--- a/src/sparseml/utils/pytorch/module.py
+++ b/src/sparseml/utils/pytorch/module.py
@@ -65,6 +65,7 @@
     "get_terminal_layers",
     "get_prunable_layers",
     "get_quantizable_layers",
+    "qat_active",
     "get_layers_params",
 ]
 
@@ -241,6 +242,21 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]:
     return quantizable
 
 
+def qat_active(module: Module) -> bool:
+    """
+    Determines if any layers in the model have quantization enabled by checking for
+    weight_fake_quant attributes
+
+    :param module: PyTorch model to check for quantization
+    :return: True if quantization is active anywhere in the model, False otherwise
+    """
+    for _, layer in module.named_modules():
+        if isinstance(layer, torch.quantization.FakeQuantize):
+            return True
+
+    return False
+
+
 def get_layers_params(
     targets: Union[str, List[str]], module: Module
 ) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]:
diff --git a/tests/sparseml/pytorch/modifiers/obcq/__init__.py b/tests/sparseml/pytorch/modifiers/obcq/__init__.py
new file mode 100644
index 00000000000..0c44f887a47
--- /dev/null
+++ b/tests/sparseml/pytorch/modifiers/obcq/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
index 09048904546..1da948db45a 100644
--- a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
+++ b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
@@ -17,6 +17,8 @@
 from sparseml.core.framework import Framework
 from sparseml.core.model import ModifiableModel
 from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
+from sparseml.modifiers.quantization import QuantizationModifier
+from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
 from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
 from tests.sparseml.pytorch.helpers import LinearNet
 
@@ -61,3 +63,91 @@ def test_successful_layerwise_recipe():
 
     # ensure layers names successfully match up with model
     assert len(found_compressible_layers) == len(targets)
+
+
+def test_create_default_quant_modifier():
+    setup_modifier_factory()
+    kwargs = dict(sparsity=0.5, block_size=128, quantize=True)
+
+    modifier = SparseGPTModifierPyTorch(**kwargs)
+    assert modifier.quantization_modifier_ is None
+
+    testing_harness = LifecyleTestingHarness(model=LinearNet())
+    modifier.pre_initialize_structure(testing_harness.get_state())
+    assert modifier.quantize
+    assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
+
+    should_be_default_quant_scheme = modifier.quantization_modifier_.scheme
+    assert should_be_default_quant_scheme.input_activations.num_bits == 8
+    assert not should_be_default_quant_scheme.input_activations.symmetric
+    assert should_be_default_quant_scheme.weights.num_bits == 8
+    assert should_be_default_quant_scheme.weights.symmetric
+
+
+def test_set_quant_if_modifier_already_exists():
+    setup_modifier_factory()
+
+    model = LinearNet()
+    kwargs = dict(
+        scheme=dict(
+            input_activations=dict(num_bits=8, symmetric=True),
+            weights=dict(num_bits=4, symmetric=False),
+        ),
+    )
+
+    modifier = QuantizationModifierPyTorch(**kwargs)
+    testing_harness = LifecyleTestingHarness(model=model)
+
+    assert not testing_harness.get_state().model.qat_active()
+    modifier.initialize(testing_harness.get_state())
+    assert testing_harness.get_state().model.qat_active()
+
+    kwargs = dict(sparsity=0.5, block_size=128, quantize=False)
+    modifier = SparseGPTModifierPyTorch(**kwargs)
+    assert not modifier.quantize
+    modifier.pre_initialize_structure(testing_harness.get_state())
+
+    # quantization modifier not owned by SparseGPT
+    assert modifier.quantization_modifier_ is None
+
+    # since quantization modifier is already applied, quantization must be set in OBCQ
+    assert modifier.quantize
+
+
+def test_set_quant_in_sparsegpt():
+    setup_modifier_factory()
+
+    quant_kwargs = {
+        "scheme": {
+            "input_activations": {
+                "num_bits": 8,
+                "symmetric": False,
+                "strategy": "tensor",
+                "kwargs": {},
+            },
+            "weights": {
+                "num_bits": 4,
+                "symmetric": True,
+                "strategy": "channel",
+                "kwargs": {},
+            },
+        }
+    }
+    quant_config = {"QuantizationModifier": quant_kwargs}
+
+    kwargs = dict(sparsity=0.5, block_size=128, quantize=quant_config)
+
+    modifier = SparseGPTModifierPyTorch(**kwargs)
+    assert modifier.quantization_modifier_ is None
+
+    testing_harness = LifecyleTestingHarness(model=LinearNet())
+    modifier.pre_initialize_structure(testing_harness.get_state())
+    assert modifier.quantize
+    assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
+
+    dict_scheme = dict(modifier.quantization_modifier_.scheme)
+    assert dict(dict_scheme["weights"]) == quant_kwargs["scheme"]["weights"]
quant_kwargs["scheme"]["weights"] + assert ( + dict(dict_scheme["input_activations"]) + == quant_kwargs["scheme"]["input_activations"] + )