Skip to content

Commit

Permalink
fix: Handle NaN values during weight conversion (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Jan 9, 2024
1 parent d88ffed commit 112aeec
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
6 changes: 6 additions & 0 deletions server/lorax_server/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import List, Dict
from collections import defaultdict

from lorax_server.utils.errors import InfWeightsError, NanWeightsError


def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
Expand Down Expand Up @@ -90,6 +92,10 @@ def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
if torch.any(torch.isnan(pt_tensor)):
raise NanWeightsError(f"Weights unusuable as param {k} in file {pt_file} contains NaN values")
if torch.any(torch.isinf(pt_tensor)):
raise InfWeightsError(f"Weights unusuable as param {k} in file {pt_file} contains inf values")
raise RuntimeError(f"The output tensors do not match for key {k}")


Expand Down
6 changes: 6 additions & 0 deletions server/lorax_server/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class NanWeightsError(RuntimeError):
    """Signal that model weights contain NaN values and cannot be used."""


class InfWeightsError(RuntimeError):
    """Signal that model weights contain infinite values and cannot be used."""
30 changes: 30 additions & 0 deletions server/tests/utils/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from pathlib import Path
import pytest
import torch
from lorax_server.utils.sources.hub import (
download_weights,
weight_hub_files,
weight_files,
)

from lorax_server.utils.convert import convert_files
from lorax_server.utils.errors import NanWeightsError


def test_convert_files():
Expand All @@ -19,3 +23,29 @@ def test_convert_files():
found_st_files = weight_files(model_id)

assert all([p in found_st_files for p in local_st_files])


def test_convert_files_nan_error(tmpdir):
    """Check that converting a checkpoint with NaN weights raises NanWeightsError.

    The first tensor of the first downloaded ``.bin`` file is overwritten with
    NaN and saved as a copy under ``tmpdir``. Converting that corrupted copy
    is expected to fail with ``NanWeightsError``.

    Args:
        tmpdir: pytest-provided temporary directory; holds both the corrupted
            checkpoint copy and any converted output files.
    """
    model_id = "bigscience/bloom-560m"
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)

    scratch_dir = Path(tmpdir)

    # Derive the output names by dropping a literal "pytorch_" prefix.
    # str.lstrip("pytorch_") would strip any leading run of those *characters*
    # rather than the prefix, so the prefix is removed explicitly instead.
    # Outputs go under tmpdir so that converting the corrupted checkpoint
    # cannot leave a bad .safetensors file beside the downloaded weights
    # (presumably a shared download cache -- confirm against download_weights).
    prefix = "pytorch_"
    local_st_files = []
    for p in local_pt_files:
        stem = p.stem
        if stem.startswith(prefix):
            stem = stem[len(prefix):]
        local_st_files.append(scratch_dir / f"{stem}.safetensors")

    # Introduce NaN to the first tensor in the first file
    original_pt_file = local_pt_files[0]
    with open(original_pt_file, "rb") as f:
        state_dict = torch.load(f, map_location="cpu")
    first_key = next(iter(state_dict))
    state_dict[first_key].fill_(float("nan"))

    # Write the corrupted state to a new temporary file
    corrupted_pt_file = scratch_dir / original_pt_file.name
    with open(corrupted_pt_file, "wb") as f:
        torch.save(state_dict, f)

    # Replace the first file with the corrupted file
    local_pt_files[0] = corrupted_pt_file

    with pytest.raises(NanWeightsError):
        convert_files(local_pt_files, local_st_files, discard_names=[])

0 comments on commit 112aeec

Please sign in to comment.