Skip to content

Commit

Permalink
Merge pull request #10 from CrazyBoyM/support_subfolder
Browse files Browse the repository at this point in the history
fix resolve_weight_file_from_hf_hub
  • Loading branch information
JunnYu committed Dec 27, 2023
2 parents 046f20d + 84dec4e commit a1ccabf
Showing 1 changed file with 103 additions and 111 deletions.
214 changes: 103 additions & 111 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from paddle.utils.download import is_url as is_remote_url
from tqdm.auto import tqdm

from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from paddlenlp.utils.env import (
CONFIG_NAME,
LEGACY_CONFIG_NAME,
Expand Down Expand Up @@ -366,50 +366,28 @@ def resolve_weight_file_from_hf_hub(
subfolder (str, optional) An optional value corresponding to a folder inside the repo.
"""
is_sharded = False

if use_safetensors:
# SAFE WEIGHTS
if hf_file_exists(repo_id, SAFE_WEIGHTS_INDEX_NAME, subfolder=subfolder):
file_name = SAFE_WEIGHTS_INDEX_NAME
is_sharded = True
elif hf_file_exists(repo_id, SAFE_WEIGHTS_NAME, subfolder=subfolder):
file_name = SAFE_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the safetensors weight file from: https://huggingface.co/{repo_id}",
response=None,
)
file_name_list = [
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
]
else:
if convert_from_torch:
# TORCH WEIGHTS
if hf_file_exists(repo_id, PYTORCH_WEIGHTS_INDEX_NAME, subfolder=subfolder):
file_name = PYTORCH_WEIGHTS_INDEX_NAME
is_sharded = True
elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder):
file_name = PYTORCH_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the pytorch weight file from: https://huggingface.co/{repo_id}",
response=None,
)
else:
if hf_file_exists(repo_id, PADDLE_WEIGHTS_INDEX_NAME, subfolder=subfolder):
file_name = PADDLE_WEIGHTS_INDEX_NAME
is_sharded = True
elif hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder):
file_name = PADDLE_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the paddle weight file from: https://huggingface.co/{repo_id}",
response=None,
)

file_name_list = [file_name]
file_name_list = [
PYTORCH_WEIGHTS_INDEX_NAME,
PADDLE_WEIGHTS_INDEX_NAME,
PYTORCH_WEIGHTS_NAME,
PADDLE_WEIGHTS_NAME,
SAFE_WEIGHTS_NAME, # (NOTE,lxl): fallback kept for compatibility in extreme cases
]
resolved_file = None
for fn in file_name_list:
resolved_file = cached_file_for_hf_hub(
repo_id, fn, cache_dir, subfolder, _raise_exceptions_for_missing_entries=False
)
if resolved_file is not None:
if resolved_file.endswith(".json"):
is_sharded = True
break

if resolved_file is None:
Expand Down Expand Up @@ -1458,6 +1436,30 @@ def _resolve_model_file_path(
is_sharded = False
sharded_metadata = None

# -1. when it's from HF
if from_hf_hub or convert_from_torch:
resolved_archive_file, is_sharded = resolve_weight_file_from_hf_hub(
pretrained_model_name_or_path,
cache_dir=cache_dir,
convert_from_torch=convert_from_torch,
subfolder=subfolder,
use_safetensors=use_safetensors,
)
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
resolved_sharded_files = None
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_sharded_files, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
from_aistudio=from_aistudio,
from_hf_hub=from_hf_hub,
cache_dir=cache_dir,
subfolder=subfolder,
)

return resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded

if pretrained_model_name_or_path is not None:
# the following code uses os.path.join a lot, so set subfolder to an empty str if it is None
if subfolder is None:
Expand Down Expand Up @@ -1561,95 +1563,85 @@ def get_file_path(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, v
filename = pretrained_model_name_or_path
resolved_archive_file = get_path_from_url_with_filelock(pretrained_model_name_or_path)
else:
# -1. when it's from HF
if from_hf_hub:
resolved_archive_file, is_sharded = resolve_weight_file_from_hf_hub(
pretrained_model_name_or_path,

# set correct filename
if use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)

try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
convert_from_torch=convert_from_torch,
subfolder=subfolder,
use_safetensors=use_safetensors,
from_aistudio=from_aistudio,
_raise_exceptions_for_missing_entries=False,
)
else:
resolved_archive_file = None
if pretrained_model_name_or_path in cls.pretrained_init_configuration:
# fetch the weight url from the `pretrained_resource_files_map`
resource_file_url = cls.pretrained_resource_files_map["model_state"][
pretrained_model_name_or_path
]
resolved_archive_file = cached_file(
resource_file_url, _add_variant(PADDLE_WEIGHTS_NAME, variant), **cached_file_kwargs
)

if resolved_archive_file is None:
resolved_archive_file = cached_file(
pretrained_model_name_or_path, filename, **cached_file_kwargs
)

# set correct filename
if use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
# xxx.pdparams in pretrained_resource_files_map renamed model_state.pdparams
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)

try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
subfolder=subfolder,
from_aistudio=from_aistudio,
_raise_exceptions_for_missing_entries=False,
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
resolved_archive_file = None
if pretrained_model_name_or_path in cls.pretrained_init_configuration:
# fetch the weight url from the `pretrained_resource_files_map`
resource_file_url = cls.pretrained_resource_files_map["model_state"][
pretrained_model_name_or_path
]
resolved_archive_file = cached_file(
resource_file_url, _add_variant(PADDLE_WEIGHTS_NAME, variant), **cached_file_kwargs
)

if resolved_archive_file is None:
resolved_archive_file = cached_file(
pretrained_model_name_or_path, filename, **cached_file_kwargs
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
)

else:
# xxx.pdparams in pretrained_resource_files_map renamed model_state.pdparams
# This repo has no safetensors file of any kind, so we fall back to the Paddle weights file.
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)

# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
)
else:
# This repo has no safetensors file of any kind, so we fall back to the Paddle weights file.
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)
resolved_archive_file = cached_file(
pretrained_model_name_or_path, filename, **cached_file_kwargs
)
if resolved_archive_file is None and filename == _add_variant(PADDLE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
# raise ValueError(resolved_archive_file)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {_add_variant(PADDLE_WEIGHTS_NAME, variant)}."
pretrained_model_name_or_path, filename, **cached_file_kwargs
)
except Exception as e:
logger.info(e)
# For any other exception, we throw a generic error.
if resolved_archive_file is None and filename == _add_variant(PADDLE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
# raise ValueError(resolved_archive_file)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://paddlenlp.bj.bcebos.com'"
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {_add_variant(PADDLE_WEIGHTS_NAME, variant)}."
)
except Exception as e:
logger.info(e)
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://paddlenlp.bj.bcebos.com'"
)

if is_local:
logger.info(f"Loading weights file {archive_file}")
Expand Down

0 comments on commit a1ccabf

Please sign in to comment.