Skip to content

Commit

Permalink
Revert "fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for prequantized 4bit (#32276)" (#32477)
Browse files Browse the repository at this point in the history

* Revert "fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for prequantized 4bit (#32276)"

This reverts commit 62c60a3.

We uncovered an issue with this change that caused our training runs to hang.

* `is_torchdynamo_compiling` -- cast a wide exception net (#32476)

* cast a wide net

* make fix-copies with a few manual changes

* add copied from

---------

Co-authored-by: Joao Gante <[email protected]>
  • Loading branch information
matthewdouglas and gante authored Aug 6, 2024
1 parent 4fdc702 commit ac2707e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 12 deletions.
7 changes: 1 addition & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,6 @@ def _load_state_dict_into_meta_model(
)
)
):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
Expand All @@ -945,10 +943,7 @@ def _load_state_dict_into_meta_model(
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
value = type(value)(value.data.to(param_to), **value.__dict__)
value = type(value)(value.data.to("cpu"), **value.__dict__)
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return

Expand Down
6 changes: 0 additions & 6 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from packaging import version
Expand Down Expand Up @@ -208,16 +207,11 @@ def create_quantized_param(
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)

param_kwargs = {}
sig = inspect.signature(bnb.nn.Params4bit.from_prequantized)
if "module" in sig.parameters:
param_kwargs["module"] = module
new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
**param_kwargs,
)
else:
new_value = param_value.to("cpu")
Expand Down

0 comments on commit ac2707e

Please sign in to comment.