[draft] proposed fix for incorrect mask application in FSDP (#1807)
* [draft] proposed fix for incorrect mask application in FSDP

* fix for multi-gpu

* fix for hanging model save

* clean up

---------

Co-authored-by: Sara Adkins <[email protected]>
bfineran and Sara Adkins authored Nov 14, 2023
1 parent ab69c11 commit f220740
Showing 5 changed files with 41 additions and 23 deletions.
19 changes: 11 additions & 8 deletions src/sparseml/modifiers/pruning/constant/pytorch.py
@@ -14,9 +14,11 @@

from typing import Dict

import torch

from sparseml.core import Event, EventType, ModelParameterizedLayer, State
from sparseml.modifiers.pruning.constant.base import ConstantPruningModifier
from sparseml.modifiers.pruning.utils.pytorch import LayerParamMasking
from sparseml.modifiers.pruning.utils.pytorch import LayerParamMasking, param_mask_name


class ConstantPruningModifierPyTorch(ConstantPruningModifier, LayerParamMasking):
@@ -59,18 +61,19 @@ def on_start(self, state: State, event: Event, **kwargs):

self.enable_masks()

@torch.no_grad()
def on_update(self, state: State, event: Event, **kwargs):
if self._use_hooks:
# hooks are used to update, so nothing to do here
return
if event.type_ == EventType.OPTIM_POST_STEP:

def apply_masks(module):
mask_name = param_mask_name()
if hasattr(module, mask_name):
module.weight *= getattr(module, mask_name)

if event.type_ == EventType.OPTIM_PRE_STEP:
for layer_param_name, _ in self.parameterized_layers_.items():
self.apply_mask_gradient(layer_param_name)
elif event.type_ == EventType.OPTIM_POST_STEP:
for layer_param_name, _ in self.parameterized_layers_.items():
self.apply_mask_weight(layer_param_name)
state.model.model.apply(apply_masks)

def on_end(self, state: State, event: Event, **kwargs):
# print(self._masked_layer_params['model.layers.5.self_attn.q_proj'].param.data[0][:5])
self.disable_masks()
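
For context, a minimal standalone sketch of the pattern introduced above (the toy model and MASK_NAME constant are illustrative, not the SparseML API): each pruned module carries a buffer named simply "mask", and torch.nn.Module.apply() re-applies it to the module's weight after the optimizer step, which is what state.model.model.apply(apply_masks) does in the new on_update.

import torch
from torch import nn

MASK_NAME = "mask"  # mirrors the new param_mask_name(), which now returns a fixed name

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# register a ~50% random sparsity mask as a buffer on every Linear layer
for module in model.modules():
    if isinstance(module, nn.Linear):
        module.register_buffer(MASK_NAME, (torch.rand_like(module.weight) > 0.5).float())

@torch.no_grad()
def apply_masks(module: nn.Module) -> None:
    # look the mask up by its fixed attribute name on the module itself,
    # not by the fully qualified parameter name
    if hasattr(module, MASK_NAME):
        module.weight *= getattr(module, MASK_NAME)

model.apply(apply_masks)  # visits every submodule, zeroing pruned weights again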
16 changes: 6 additions & 10 deletions src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py
@@ -23,12 +23,11 @@
from sparseml.core import ModelParameterizedLayer


__all__ = ["LayerParamMasking"]
__all__ = ["LayerParamMasking", "param_mask_name"]


def param_mask_name(param_name: str) -> str:
valid_name = param_name.replace(".", "_")
return f"{valid_name}_mask"
def param_mask_name() -> str:
return "mask"


def setup_mask_for_param(param: Parameter, mask: torch.Tensor) -> torch.Tensor:
@@ -70,7 +69,7 @@ def add_mask(
if layer_param_name in self._masked_layer_params:
raise ValueError(f"Layer param {layer_param_name} already has a mask")

mask_name = param_mask_name(parameterized_layer.param_name)
mask_name = param_mask_name()

try:
parameterized_layer.layer.get_buffer(mask_name)
@@ -126,7 +125,7 @@ def update_mask(
mask: torch.Tensor,
):
parameterized_layer = self._masked_layer_params[layer_param_name]
mask_name = param_mask_name(parameterized_layer.param_name)
mask_name = param_mask_name()
mask_tensor = parameterized_layer.layer.get_buffer(mask_name)
mask_tensor[:] = mask

@@ -137,7 +136,7 @@ def remove_mask(self, layer_param_name: str):
if not mask_settings.persistent:
delattr(
parameterized_layer.layer,
param_mask_name(parameterized_layer.param_name),
param_mask_name(),
)

del self._masked_layer_params[layer_param_name]
@@ -155,9 +154,6 @@ def apply_mask_weight(self, layer_param_name: str):
return

parameterized_layer = self._masked_layer_params[layer_param_name]
if layer_param_name == "model.layers.5.mlp.down_proj.weight":
print(parameterized_layer.param.data)

mask_name = param_mask_name(parameterized_layer.param_name)
mask = parameterized_layer.layer.get_buffer(mask_name)
parameterized_layer.param.data = parameterized_layer.param.data * mask
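
A short sketch of the simplified buffer naming, assuming a bare nn.Linear stands in for parameterized_layer.layer: the mask is always registered and retrieved under the single name returned by param_mask_name(), rather than a name derived from the parameter's dotted path.

import torch
from torch import nn

def param_mask_name() -> str:
    # new behavior: a fixed name, no param_name argument
    return "mask"

layer = nn.Linear(4, 4)
mask_name = param_mask_name()
layer.register_buffer(mask_name, torch.ones_like(layer.weight), persistent=False)

# update_mask-style in-place write, then retrieval by the same fixed name
layer.get_buffer(mask_name)[:] = torch.bernoulli(torch.full_like(layer.weight, 0.5))
print(layer.get_buffer(mask_name).mean())  # fraction of weights kept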
18 changes: 18 additions & 0 deletions src/sparseml/transformers/finetune/callbacks.py
@@ -32,6 +32,24 @@


class PostOptimCallback(TrainerCallback):
def __init__(self, trainer, *args, **kwargs):
super().__init__(*args, **kwargs)
self.trainer = trainer

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""
Event called at the beginning of training.
"""
super().on_train_begin(args, state, control, **kwargs)
session = sml.active_session()
session.state.model.model = self.trainer.model

def on_step_end(
self,
args: TrainingArguments,
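
The callback change above can be read as the following standalone pattern (ModelPublishingCallback and shared_state are hypothetical names, not the SparseML API): the callback keeps a reference to its trainer so that, once training begins and the model has been wrapped for distributed training, the wrapped instance can be published to shared session state.

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class ModelPublishingCallback(TrainerCallback):
    def __init__(self, trainer, shared_state: dict):
        super().__init__()
        self.trainer = trainer
        self.shared_state = shared_state  # stand-in for sml.active_session().state

    def on_train_begin(self, args: TrainingArguments, state: TrainerState,
                       control: TrainerControl, **kwargs):
        # by the time training begins, trainer.model is the (possibly FSDP-wrapped) model
        self.shared_state["model"] = self.trainer.model

# wiring, mirroring session_mixin: trainer.add_callback(ModelPublishingCallback(trainer, {}))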
7 changes: 4 additions & 3 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -82,7 +82,7 @@ def __init__(
sml.create_session()

super().__init__(model=model, **kwargs)
self.optim_callbacks = PostOptimCallback()
self.optim_callbacks = PostOptimCallback(self)
self.callback_handler.add_callback(self.optim_callbacks)
self.callback_disable_fp16 = DisableHalfPrecisionCallback(self)
self.callback_handler.add_callback(self.callback_disable_fp16)
@@ -255,8 +255,9 @@ def train(self, *args, **kwargs):
self.finalize_session()

self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process:
print("logging sparsification")
from torch.distributed.fsdp import FullyShardedDataParallel

with FullyShardedDataParallel.summon_full_params(self.model):
self.log_model_sparsification()

return output
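
A hedged sketch of why the guard around sparsification logging changes (toy function, not the mixin's code): FullyShardedDataParallel.summon_full_params is a collective gather, so every rank must enter it; running it only on the main process is a plausible cause of the hanging save mentioned in the commit message, while gathering on all ranks and reporting from rank 0 avoids that.

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def log_global_sparsity(fsdp_model: FSDP) -> None:
    # every rank participates in the gather; only rank 0 reports
    with FSDP.summon_full_params(fsdp_model):
        total = sum(p.numel() for p in fsdp_model.parameters())
        zeros = sum((p == 0).sum().item() for p in fsdp_model.parameters())
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(f"global sparsity: {zeros / total:.2%}")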
@@ -104,7 +104,7 @@ def test_constant_pruning_modifier_e2e(model, optimizer):
# check mask is added and has correct sparsity

for _, parameterized_layer in modifier.parameterized_layers_.items():
mask_name = param_mask_name(parameterized_layer.param_name)
mask_name = param_mask_name()
mask_tensor = parameterized_layer.layer.get_buffer(mask_name)
data_tensor = parameterized_layer.param.data
# check mask and data tensors have 0 in the same places
@@ -134,7 +134,7 @@ def test_constant_pruning_modifier_e2e(model, optimizer):

# check mask is removed
for layer_param_name, parameterized_layer in modifier.parameterized_layers_.items():
mask_name = param_mask_name(parameterized_layer.param_name)
mask_name = param_mask_name()

if not old_mask_settings[layer_param_name].persistent:
assert not hasattr(parameterized_layer.layer, mask_name)
