diff --git a/server/lorax_server/utils/convert.py b/server/lorax_server/utils/convert.py
index f911e5b5..f02703be 100644
--- a/server/lorax_server/utils/convert.py
+++ b/server/lorax_server/utils/convert.py
@@ -8,6 +8,8 @@
 from typing import List, Dict
 from collections import defaultdict
 
+from lorax_server.utils.errors import InfWeightsError, NanWeightsError
+
 
 def _remove_duplicate_names(
     state_dict: Dict[str, torch.Tensor],
@@ -90,6 +92,13 @@ def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
         pt_tensor = loaded[k]
         sf_tensor = reloaded[k]
         if not torch.equal(pt_tensor, sf_tensor):
+            # NaN never compares equal to itself, so a tensor containing NaN
+            # fails torch.equal even after a faithful round trip; report the
+            # bad weights explicitly instead of a generic mismatch.
+            if torch.any(torch.isnan(pt_tensor)):
+                raise NanWeightsError(f"Weights unusable as param {k} in file {pt_file} contains NaN values")
+            if torch.any(torch.isinf(pt_tensor)):
+                raise InfWeightsError(f"Weights unusable as param {k} in file {pt_file} contains inf values")
             raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
diff --git a/server/lorax_server/utils/errors.py b/server/lorax_server/utils/errors.py
new file mode 100644
index 00000000..a02a1957
--- /dev/null
+++ b/server/lorax_server/utils/errors.py
@@ -0,0 +1,6 @@
+class NanWeightsError(RuntimeError):
+    """Raised when a weight tensor contains NaN values."""
+
+
+class InfWeightsError(RuntimeError):
+    """Raised when a weight tensor contains inf values."""
diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py
index 62c3cb01..3c1ef049 100644
--- a/server/tests/utils/test_convert.py
+++ b/server/tests/utils/test_convert.py
@@ -1,3 +1,6 @@
+from pathlib import Path
+import pytest
+import torch
 from lorax_server.utils.sources.hub import (
     download_weights,
     weight_hub_files,
@@ -5,6 +8,7 @@
 )
 
 from lorax_server.utils.convert import convert_files
+from lorax_server.utils.errors import NanWeightsError
 
 
 def test_convert_files():
@@ -19,3 +23,33 @@ def test_convert_files():
     found_st_files = weight_files(model_id)
 
     assert all([p in found_st_files for p in local_st_files])
+
+
+def test_convert_files_nan_error(tmpdir):
+    """Converting a checkpoint whose weights contain NaN raises NanWeightsError."""
+    model_id = "bigscience/bloom-560m"
+    pt_filenames = weight_hub_files(model_id, extension=".bin")
+    local_pt_files = download_weights(pt_filenames, model_id)
+    # Emit the converted files under tmpdir: the conversion output is compared
+    # against its reloaded copy, so a corrupted .safetensors would otherwise
+    # be left next to the downloaded weights.
+    local_st_files = [
+        Path(tmpdir) / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
+    ]
+
+    # Introduce NaN to the first tensor in the first file
+    pt_file = local_pt_files[0]
+    with open(pt_file, "rb") as f:
+        state_dict = torch.load(f, map_location="cpu")
+    state_dict[next(iter(state_dict))].fill_(float("nan"))
+
+    # Write the corrupted state to a new temporary file
+    pt_file = Path(tmpdir) / pt_file.name
+    with open(pt_file, "wb") as f:
+        torch.save(state_dict, f)
+
+    # Replace the first file with the corrupted file
+    local_pt_files[0] = pt_file
+
+    with pytest.raises(NanWeightsError):
+        convert_files(local_pt_files, local_st_files, discard_names=[])