Add hf_hub_subdir to loading fns
EricLBuehler committed Feb 10, 2024
1 parent 8517309 commit 145b000
Showing 2 changed files with 25 additions and 6 deletions.
18 changes: 14 additions & 4 deletions src/xlora/xlora.py
@@ -1,4 +1,5 @@
import json
+ import os
from typing import Dict, List, Optional, Union

import peft
@@ -230,6 +231,7 @@ def from_pretrained(
device: str,
verbose: bool = False,
from_safetensors: bool = True,
+ hf_hub_subdir: Optional[str] = None,
) -> xLoRAModel:
"""
Loads a pretrained classifier and potentially adapters from the specified folder while initializing the model. This is the counterpart to `save_pretrained`.
@@ -252,12 +254,15 @@
Device of the model, used to load the classifier.
from_safetensors (`bool`, *optional*, defaults to True):
Whether to load the classifier weights from a .pt or .safetensors file.
+ hf_hub_subdir (`str`, *optional*, defaults to None):
+ If `xlora_path` is a HF model repo ID, specify a subdirectory where the weights may be found.
Returns:
model (`xLoRAModel`):
The new model.
"""

- with open(xlora_utils._get_file_path(load_directory, "xlora_config.json"), "r") as f:
+ with open(xlora_utils._get_file_path(load_directory, "xlora_config.json", hf_hub_subdir), "r") as f:
conf = json.load(f)
conf["device"] = torch.device(device)

@@ -267,7 +272,12 @@ def from_pretrained(

if use_trainable_adapters:
adapters_dict: Dict[str, str] = {
- name: xlora_utils._get_file_path_dir(load_directory, name, "adapters") for name in adapters
+ name: xlora_utils._get_file_path(
+     load_directory,
+     name,
+     os.path.join(hf_hub_subdir, "adapters") if hf_hub_subdir is not None else "adapters",
+ )
+ for name in adapters
}
else:
assert isinstance(adapters, dict)
@@ -278,11 +288,11 @@ def from_pretrained(
if from_safetensors:
state_dict = load_model(
classifier,
- xlora_utils._get_file_path(load_directory, "xlora_classifier.safetensors"),
+ xlora_utils._get_file_path(load_directory, "xlora_classifier.safetensors", hf_hub_subdir),
)
classifier.to(device)
else:
- state_dict = torch.load(xlora_utils._get_file_path(load_directory, "xlora_classifier.pt"))
+ state_dict = torch.load(xlora_utils._get_file_path(load_directory, "xlora_classifier.pt", hf_hub_subdir))
classifier.load_state_dict(state_dict) # type: ignore

return model_peft
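
For reference, when `hf_hub_subdir` is given and `load_directory` is a Hugging Face Hub repo ID, each lookup above routes through `xlora_utils._get_file_path`, which (per the helper changes in the next file) forwards the subdirectory to `huggingface_hub.hf_hub_download`. A minimal sketch of what the three lookups resolve to; the repo ID, subdirectory, and adapter name below are hypothetical:

import os

import huggingface_hub

# Hypothetical values for illustration only.
load_directory = "someuser/my-xlora-checkpoint"  # Hub repo ID passed as `load_directory`
hf_hub_subdir = "xlora_ckpt"                     # the new `hf_hub_subdir` argument
adapter_name = "adapter_1"                       # one entry of `adapters`

# The config and the classifier weights resolve to files directly under the subdirectory.
config_path = huggingface_hub.hf_hub_download(
    load_directory, filename="xlora_config.json", subfolder=hf_hub_subdir
)
classifier_path = huggingface_hub.hf_hub_download(
    load_directory, filename="xlora_classifier.safetensors", subfolder=hf_hub_subdir
)

# Trainable adapters resolve under "<hf_hub_subdir>/adapters", mirroring the branch added above.
adapter_subfolder = os.path.join(hf_hub_subdir, "adapters") if hf_hub_subdir is not None else "adapters"
adapter_path = huggingface_hub.hf_hub_download(
    load_directory, filename=adapter_name, subfolder=adapter_subfolder
)
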
13 changes: 11 additions & 2 deletions src/xlora/xlora_utils.py
@@ -14,7 +14,7 @@
from .xlora import from_pretrained, xLoRAModel # type: ignore


- def _get_file_path(
+ def _get_file_path_single(
load_directory: str,
name: str,
) -> str:
@@ -29,6 +29,12 @@ def _get_file_path_dir(load_directory: str, name: str, dir: str) -> str:
return huggingface_hub.hf_hub_download(load_directory, filename=name, subfolder=dir)


+ def _get_file_path(load_directory: str, name: str, dir: Optional[str]) -> str:
+     if dir is not None:
+         return _get_file_path_dir(load_directory, name, dir)
+     return _get_file_path_single(load_directory, name)


def load_model(
model_name: str,
xlora_path: Optional[str],
@@ -39,6 +45,7 @@ def load_model(
load_xlora: bool = True,
verbose: bool = False,
from_safetensors: bool = True,
+ hf_hub_subdir: Optional[str] = None,
) -> Tuple[Union[AutoModelForCausalLM, xLoRAModel], Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
"""
Convenience function to load a model, converting it to xLoRA if specified.
@@ -47,7 +54,7 @@
model_name (`str`):
AutoModelForCausalLM pretrained model name or path
xlora_path (`str`, *optional*):
- Directory to load the xLoRAClassifier from.
+ Directory or HF model repo ID to load the xLoRAClassifier from.
device (`str`):
Device to load the base model and the xLoRA model to.
dtype (`torch.dtype`):
@@ -63,6 +70,8 @@
Enable verbose loading.
from_safetensors (`bool`, *optional*, defaults to True):
Whether to load the classifier weights from a .pt or .safetensors file.
+ hf_hub_subdir (`str`, *optional*, defaults to None):
+ If `xlora_path` is a HF model repo ID, specify a subdirectory where the weights may be found.
Returns:
Tuple whose elements are respectively:
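
At the call site, the new argument just needs to be threaded through `load_model`. A rough sketch of how a caller might use it, assuming the import path implied by this repo's layout; every value is a placeholder, and the `load_model` parameters that are collapsed in this diff are deliberately left out:

import torch

# Import path assumed from this repo's layout (src/xlora/xlora_utils.py).
from xlora.xlora_utils import load_model

# Keyword arguments grounded in this diff; the values are placeholders.
xlora_kwargs = dict(
    model_name="mistralai/Mistral-7B-v0.1",     # base model (placeholder)
    xlora_path="someuser/my-xlora-checkpoint",  # a Hub repo ID rather than a local directory
    device="cuda:0",
    dtype=torch.bfloat16,
    from_safetensors=True,
    hf_hub_subdir="xlora_ckpt",                 # new: classifier and adapters live under this subfolder
)

# load_model takes further parameters that are collapsed in this diff; they would be added
# to xlora_kwargs from the full signature in xlora_utils.py before making the call:
# model, tokenizer = load_model(**xlora_kwargs)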
