[Update] layer prefix to be set at model level (#1778)
* - Update `src/sparseml/modifiers/obcq/pytorch.py`
    to use the layer prefix from the model
  - Remove `layer_prefix` from the `SparseGPTModifier` base class
  - Update `ModelMetaData` to include `layer_prefix`
  - Add a convenience function to update missing values
    in a `RecipeMetaData` instance from another `RecipeMetaData` instance
  - Update `simplify_recipe` to also include metadata
  - Update `simplify_combine_recipes` to include metadata
  - Add a `layer_prefix` property to `ModifiableModel`
  - Propagate `layer_prefix` to the superclass
  - Update session.py to set the layer prefix on the model
    before initializing modifiers
  - Update the example recipe to include `layer_prefix` in metadata

* Add missing docstring

* - address review comment
- update docstring
- add test for `update_missing_metadata`

* Add test

* Style

* Fix tests

* Style

* [modifier refactor] Add constant pruning tests (#1752)

* Initial commit

* Add end to end tests

* Add e2e tests for constant pruning modifier

* Move imports inside the test functions so
that torch isn't imported unless running the tests

* Update setup.py to not run modifier tests unless pytorch is specified

* [Bugfix] .dict() method on Recipe (#1753)

* Bugfix .dict() method on Recipe

* Remove extraneous local test, [faulty commit]

* [modifier refactor] Add serialization tests (#1755)

* Add serialization tests

* Clean up

* Keep original stage and group names
Clean up _get_yaml_dict

* fix comment

* Typo

* [Unit Tests][Modifier Refactor] (#1756)

* Move valid recipes to a helper file
Add tests for session.py

* Increase test coverage of src/sparseml/core/session.py
to 100%
Run Style
Add logs to .gitignore

* Increase coverage of tests/sparseml/core/test_state.py
to 100%

* add tests for lifecycle/event.py

* Increase code coverage of lifecycle/event to
100%

* increase lifecycle/session.py code coverage to 93%

* Address review comments from @Satrat

* Address review comments on 1752 (#1772)

Update makefile to only ignore *pytorch.py files in modifier dir
Fix order in test
Add regex to makefile
Add helper function to determine if torch tests should be run
Check masks
Make transformers import optional in sparsegpt.py

* Fix merge conflict

* Add more tests to check valid modifiers are created (#1774)

* [Bug][ConstantPruningModifier] Fix mask de register bug (#1773)

* Fix mask de-register logic

* Forgot to remove commented-out line

* Move tests inside pytorch directory as requested

* Fix session reset (#1790)

* fix datasets version to be compatible with fsspec (#1797)

* Add kvcache config for Mistral (#1766)

* Add kvcache config for Mistral

* Update configs.py

* Update configs.py

* Fix reset logic

* Style after resolving merge conflicts

---------

Co-authored-by: Sara Adkins <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
3 people authored and Benjamin committed Nov 16, 2023
1 parent 312f2f0 commit aa74932
Showing 10 changed files with 204 additions and 12 deletions.
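The net effect of the change: the layer prefix now travels with the recipe's metadata instead of the OBCQ modifier config, and session initialization pushes it onto the model before any modifier runs. A minimal usage sketch, mirroring the new test added below (the identity model stands in for a real network):

import sparseml.core.session as sml
from sparseml.core import Framework

recipe = """
metadata:
    target_model:
        layer_prefix: "decoder"
        architecture: "opt"
test_stage:
    pruning_modifiers:
        ConstantPruningModifier:
            targets: __ALL_PRUNABLE__
            start: 0
            end: 5
"""

session = sml.active_session()
session.initialize(framework=Framework.general, model=lambda x: x, recipe=recipe)
assert session.state.model.layer_prefix == "decoder"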
14 changes: 13 additions & 1 deletion src/sparseml/core/lifecycle/session.py
@@ -42,7 +42,7 @@ class SparsificationLifecycle:
 
     def reset(self):
        for mod in self.modifiers:
-            if not mod.initialized_ or mod.finalized:
+            if not mod.initialized or mod.finalized:
                continue
 
            try:
@@ -87,6 +87,7 @@ def initialize(self, framework: Framework = None, **kwargs) -> List[Any]:
         extras = self.recipe_container.update(**extras)
 
         self._check_compile_recipe()
+        self._set_model_layer_prefix()
         mod_data = []
         for mod in self.modifiers:
             data = mod.initialize(state=self.state, **extras)
@@ -208,3 +209,14 @@ def _check_setup_event_lifecycle(self, event_type: EventType):
             )
         else:
             raise ValueError(f"invalid event type {event_type}")
+
+    def _set_model_layer_prefix(self):
+        if (
+            (compiled_recipe := self.recipe_container.compiled_recipe) is None
+            or (metadata := compiled_recipe.metadata) is None
+            or (model_metadata := metadata.target_model) is None
+        ):
+            return False
+
+        self.state.model.layer_prefix = model_metadata.layer_prefix
+        return True
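For readers unfamiliar with the chained assignment expressions (walrus operators) above, an equivalent long-form of the same check, assuming only the attribute layout visible in this diff:

def set_model_layer_prefix(lifecycle) -> bool:
    # walk recipe_container -> compiled_recipe -> metadata -> target_model,
    # returning False as soon as any link in the chain is missing
    compiled_recipe = lifecycle.recipe_container.compiled_recipe
    if compiled_recipe is None:
        return False
    metadata = compiled_recipe.metadata
    if metadata is None:
        return False
    model_metadata = metadata.target_model
    if model_metadata is None:
        return False
    lifecycle.state.model.layer_prefix = model_metadata.layer_prefix
    return True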
26 changes: 25 additions & 1 deletion src/sparseml/core/model/base.py
@@ -57,13 +57,21 @@ class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject):
         to be searchable by the MultiFrameworkObject factory method.
 
     :param framework: the framework the model is in
+    :param layer_prefix: name of model attribute that contains the list of layers, i.e.
+        model.decoder for OPT or just model for Llama
     :param model: the model object
     """
 
     model: MT = None
 
-    def __init__(self, framework: Optional[Framework] = None, model=None):
+    def __init__(
+        self,
+        framework: Optional[Framework] = None,
+        model=None,
+        layer_prefix: Optional[str] = None,
+    ):
         self.model = model
+        self._layer_prefix = layer_prefix
 
     def get_layers_params(
         self, targets: Union[str, List[str]]
@@ -117,6 +125,22 @@ def set_param(self, target: str, param: PT):
         """
         raise NotImplementedError()
 
+    @property
+    def layer_prefix(self) -> Optional[str]:
+        """
+        :return: the name of model attribute that contains the list of layers, i.e.
+            model.decoder for OPT or just model for Llama
+        """
+        return self._layer_prefix
+
+    @layer_prefix.setter
+    def layer_prefix(self, value: Optional[str]):
+        """
+        :param value: the name of model attribute that contains the list of layers, i.e.
+            model.decoder for OPT or just model for Llama
+        """
+        self._layer_prefix = value
+
     def get_matching_layer(
         self, target: str, name_to_match: str, model: LT
     ) -> Optional[Tuple[str, LT]]:
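The prefix is a dotted attribute path under the wrapped model ("decoder" for OPT, empty for Llama). A hypothetical helper, not part of this diff, showing how such a prefix could be resolved to the module that owns the layers:

from functools import reduce

def resolve_layer_container(root, layer_prefix=None):
    # hypothetical illustration: an empty prefix means the root module
    # itself owns the layers; "decoder" walks to root.decoder
    if not layer_prefix:
        return root
    return reduce(getattr, layer_prefix.split("."), root)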
9 changes: 7 additions & 2 deletions src/sparseml/core/model/pytorch.py
@@ -40,12 +40,17 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]):
 
     :param framework: the framework the model is in
     :param model: the model object
+    :param layer_prefix: name of model attribute that contains the list of layers, i.e.
+        model.decoder for OPT or just model for Llama
     """
 
     def __init__(
-        self, framework: Optional[Framework] = None, model: Optional[Module] = None
+        self,
+        framework: Optional[Framework] = None,
+        model: Optional[Module] = None,
+        layer_prefix: Optional[str] = None,
     ):
-        super().__init__(framework=framework, model=model)
+        super().__init__(framework=framework, model=model, layer_prefix=layer_prefix)
 
     def get_layers_params(
         self, targets: Union[str, List[str]]
18 changes: 17 additions & 1 deletion src/sparseml/core/recipe/metadata.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field
 
@@ -69,6 +69,7 @@ class ModelMetaData(BaseModel):
     input_shapes: List[List[int]] = None
     output_shapes: List[List[int]] = None
     layers: List[LayerMetaData] = Field(default_factory=list)
+    layer_prefix: Optional[str] = None
 
 
 class RecipeMetaData(BaseModel):
@@ -79,3 +80,18 @@ class RecipeMetaData(BaseModel):
     tags: List[str] = None
     target_dataset: DatasetMetaData = None
     target_model: ModelMetaData = None
+
+    def update_missing_metadata(self, other: "RecipeMetaData"):
+        """
+        Update recipe metadata with missing values from another
+        recipe metadata instance
+
+        :param other: the recipe metadata to update with
+        """
+        self.domain = self.domain or other.domain
+        self.task = self.task or other.task
+        self.versions = self.versions or other.versions
+        self.requirements = self.requirements or other.requirements
+        self.tags = self.tags or other.tags
+        self.target_dataset = self.target_dataset or other.target_dataset
+        self.target_model = self.target_model or other.target_model
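A quick usage sketch of update_missing_metadata, using only fields present in this diff; existing values win and gaps are filled from the other instance:

from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData

base = RecipeMetaData(domain="cv", task="classification")
other = RecipeMetaData(
    task="segmentation",  # ignored: base already sets task
    target_model=ModelMetaData(layer_prefix="decoder"),  # fills a gap
)

base.update_missing_metadata(other)
assert base.task == "classification"
assert base.target_model.layer_prefix == "decoder"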
20 changes: 20 additions & 0 deletions src/sparseml/core/recipe/recipe.py
@@ -135,6 +135,9 @@ def simplify_recipe(
         simplified.args = RecipeArgs(args)
         simplified.stages = stages
         simplified.evaluate(args=args, shift=shift)
+        simplified.metadata = (
+            recipe.metadata if isinstance(recipe, Recipe) else recipe.recipe.metadata
+        )
 
         return simplified
 
@@ -185,6 +188,7 @@ def simplify_combine_recipes(
             combined.version = simplified.version
             combined.stages.extend(simplified.stages)
             combined.args.update(simplified.args)
+            combined.combine_metadata(simplified.metadata)
 
         return combined
 
@@ -388,6 +392,22 @@ def extract_dict_stages(values: Dict[str, Any]) -> List[Dict[str, Any]]:
 
         return stages
 
+    def combine_metadata(self, metadata: Optional[RecipeMetaData]):
+        """
+        Combines the metadata of the recipe with the supplied metadata
+        If the recipe already has metadata, the supplied metadata will
+        be used to update missing metadata
+
+        :param metadata: The metadata to combine with the recipe
+        """
+        if metadata is None:
+            return
+
+        if self.metadata is None:
+            self.metadata = metadata
+        else:
+            self.metadata.update_missing_metadata(metadata)
+
     def dict(self, *args, **kwargs) -> Dict[str, Any]:
         """
         >>> recipe_str = '''
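A standalone restatement of the merge policy above, for illustration only and assuming nothing beyond the two methods shown in this diff:

from typing import Optional

from sparseml.core.recipe.metadata import RecipeMetaData

def combined_metadata(
    current: Optional[RecipeMetaData], incoming: Optional[RecipeMetaData]
) -> Optional[RecipeMetaData]:
    # mirrors Recipe.combine_metadata: keep current, fill gaps from incoming
    if incoming is None:
        return current
    if current is None:
        return incoming
    current.update_missing_metadata(incoming)
    return current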
2 changes: 0 additions & 2 deletions src/sparseml/modifiers/obcq/base.py
@@ -49,8 +49,6 @@ class SparseGPTModifier(Modifier):
     :param targets: list of layer names to compress during OBCQ, or '__ALL__'
         to compress every layer in the model
     :param target_ids: list of keys in model output to cache
-    :param layer_prefix: name of model attribute that contains the list of layers, i.e.
-        model.decoder for OPT or just model for Llama
     """
 
     sparsity: Union[float, List[float]]
13 changes: 9 additions & 4 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -47,6 +47,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
     model: Any = None
     device_: str = "cuda:0"
     finalization_kwargs_: Dict = None
+    layer_prefix_: Optional[str] = None
 
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
@@ -85,6 +86,7 @@ def initialize_obcq(
         """
         self.model = model
         self.compressible_layers_ = self.compressible_layers()
+        self.layer_prefix_ = model.layer_prefix
         self.model = self.model.model
         self._set_device(device)
 
@@ -106,7 +108,7 @@ def apply_obcq(
         extras = self.compress_bottom(
             dev=self.device_,
             target_ids=self.target_ids,
-            layer_prefix=self.layer_prefix,
+            layer_prefix=self.layer_prefix_,
             **accum_kwargs,
         )
         accum_kwargs.update(extras)
@@ -166,17 +168,20 @@ def compress_bottom(
         nsamples: int = None,
         dev: str = "cuda:0",
         target_ids: List[str] = None,
-        layer_prefix: str = None,
+        layer_prefix: Optional[str] = None,
     ) -> Dict:
         """
         Runs calibration data through the bottom part of the network (everything up
         to the first decoder layer) and return the captured outputs
 
         :param dataloader: calibration data to pass through the model
-        :nsamples: number of samples to use for calibration, or None to use it all
-        :dev: device to use
+        :param nsamples: number of samples to use for calibration, or None to use it all
+        :param dev: device to use
+        :param layer_prefix: name of model attribute that contains the list of layers,
+            i.e. model.decoder for OPT or just model for Llama
         :return: outputs from bottom part of network, attention mask, and kv-cache state
         """
+        layer_prefix = layer_prefix or self.layer_prefix_
         cached_inputs = cache_attention_inputs(
             self.model, dataloader, dev, nsamples, target_ids, layer_prefix
         )
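The behavioral point above: initialize_obcq caches model.layer_prefix before unwrapping the inner torch module, and compress_bottom falls back to that cached value when no prefix is passed explicitly. A minimal restatement of the pattern with placeholder names:

from typing import Optional

class PrefixCachingSketch:
    # illustration only: mirrors the caching/fallback pattern in the diff
    layer_prefix_: Optional[str] = None

    def initialize(self, wrapped_model) -> None:
        # cache the prefix from the ModifiableModel wrapper before it is
        # unwrapped to the raw torch module (self.model = self.model.model)
        self.layer_prefix_ = wrapped_model.layer_prefix

    def pick_prefix(self, layer_prefix: Optional[str] = None) -> Optional[str]:
        # an explicit argument wins; otherwise use the value cached at init
        return layer_prefix or self.layer_prefix_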
6 changes: 5 additions & 1 deletion src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,3 +1,8 @@
+metadata:
+  target_model:
+    layer_prefix: "decoder"
+    architecture: "opt"
+
 test_stage:
   obcq_modifiers:
     SmoothQuantModifier:
@@ -52,4 +57,3 @@ test_stage:
         "model.decoder.layers.23"
       ]
       target_ids: ["attention_mask"]
-      layer_prefix: "decoder"
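With the recipe updated, the prefix can be read back from the parsed metadata rather than from the modifier. A hedged sketch, assuming Recipe.create_instance accepts a recipe path as elsewhere in this refactor:

from sparseml.core.recipe import Recipe

recipe = Recipe.create_instance(
    "src/sparseml/transformers/sparsification/obcq/example.yaml"
)
# the prefix now lives on the recipe's model metadata, not on SparseGPTModifier
assert recipe.metadata.target_model.layer_prefix == "decoder"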
53 changes: 53 additions & 0 deletions tests/sparseml/core/lifecycle/test_session.py
@@ -17,6 +17,7 @@
 
 import pytest
 
+import sparseml.core.session as sml
 from sparseml.core import Framework
 from sparseml.core.event import Event, EventType
 from sparseml.core.lifecycle.event import CallbacksEventLifecycle
@@ -25,6 +26,58 @@
 from sparseml.core.state import State
 
 
+def recipe_with_layer_prefix():
+    layer_prefix = "decoder"
+    recipe = f"""
+    metadata:
+        target_model:
+            layer_prefix: {layer_prefix}
+            architecture: "opt"
+    test_stage:
+        pruning_modifiers:
+            ConstantPruningModifier:
+                targets: __ALL_PRUNABLE__
+                start: 0
+                end: 5
+    """
+    return recipe, layer_prefix
+
+
+def recipe_without_layer_prefix():
+    recipe = """
+    test_stage:
+        pruning_modifiers:
+            ConstantPruningModifier:
+                targets: __ALL_PRUNABLE__
+                start: 0
+                end: 5
+    """
+    return recipe, None
+
+
+@pytest.fixture
+def model():
+    # identity model
+    return lambda x: x
+
+
+@pytest.mark.parametrize(
+    "recipe, expected_layer_prefix",
+    [
+        recipe_without_layer_prefix(),
+        recipe_with_layer_prefix(),
+    ],
+)
+def test_session_initialize_propagates_layer_prefix_to_model(
+    recipe, expected_layer_prefix, model
+):
+    session = sml.active_session()
+    session.initialize(framework=Framework.general, model=model, recipe=recipe)
+    print(f"{session.state.model.layer_prefix=}, {expected_layer_prefix=}")
+    assert session.state.model.layer_prefix == expected_layer_prefix
+
+
 class ModifierMock(ModifierInterface):
     initialized_ = False
 
55 changes: 55 additions & 0 deletions tests/sparseml/core/recipe/test_metadata.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData
+
+
+class TestRecipeMetaData:
+    @pytest.mark.parametrize(
+        "self_metadata",
+        [
+            dict(domain="cv", task="classification"),
+            dict(),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "other_metadata",
+        [
+            dict(domain="domain", task="segmentation", requirements=["torch>=1.6.0"]),
+            dict(
+                domain="cv",
+                task="task",
+                target_model=ModelMetaData(layer_prefix="something"),
+            ),
+        ],
+    )
+    def test_update_missing_metadata(self, self_metadata, other_metadata):
+
+        metadata_a = RecipeMetaData(**self_metadata)
+        metadata_b = RecipeMetaData(**other_metadata)
+
+        metadata_a.update_missing_metadata(metadata_b)
+
+        all_keys = set(self_metadata.keys()).union(other_metadata.keys())
+
+        # keys should not be overwritten
+        # if they already exist
+        for key in all_keys:
+            if key in self_metadata:
+                assert getattr(metadata_a, key) == self_metadata[key]
+            elif key in other_metadata:
+                assert getattr(metadata_a, key) == other_metadata[key]