diff --git a/src/xlora/xlora.py b/src/xlora/xlora.py
index c9cc63d..97ffc71 100644
--- a/src/xlora/xlora.py
+++ b/src/xlora/xlora.py
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import Dict, List, Optional, Union
 
 import peft
@@ -230,6 +231,7 @@ def from_pretrained(
     device: str,
     verbose: bool = False,
     from_safetensors: bool = True,
+    hf_hub_subdir: Optional[str] = None,
 ) -> xLoRAModel:
     """
     Loads a pretrained classifier and potentially adapters from the specified folder while initializing the model. This is the counterpart to `save_pretrained`.
@@ -252,12 +254,15 @@ def from_pretrained(
             Device of the model, used to load the classifier.
         from_safetensors (`bool`, *optional*, defaults to True):
             Whether to load the classifier weights from a .pt or .safetensors file.
 
+        hf_hub_subdir (`str`, *optional*, defaults to None):
+            If `xlora_path` is a HF model repo ID, specify a subdirectory where the weights may be found.
+
     Returns:
         model (`xLoRAModel`): The new model.
     """
 
-    with open(xlora_utils._get_file_path(load_directory, "xlora_config.json"), "r") as f:
+    with open(xlora_utils._get_file_path(load_directory, "xlora_config.json", hf_hub_subdir), "r") as f:
         conf = json.load(f)
         conf["device"] = torch.device(device)
 
@@ -267,7 +272,12 @@ def from_pretrained(
 
     if use_trainable_adapters:
         adapters_dict: Dict[str, str] = {
-            name: xlora_utils._get_file_path_dir(load_directory, name, "adapters") for name in adapters
+            name: xlora_utils._get_file_path(
+                load_directory,
+                name,
+                os.path.join(hf_hub_subdir, "adapters") if hf_hub_subdir is not None else "adapters",
+            )
+            for name in adapters
         }
     else:
         assert isinstance(adapters, dict)
@@ -278,11 +288,11 @@ def from_pretrained(
     if from_safetensors:
         state_dict = load_model(
             classifier,
-            xlora_utils._get_file_path(load_directory, "xlora_classifier.safetensors"),
+            xlora_utils._get_file_path(load_directory, "xlora_classifier.safetensors", hf_hub_subdir),
         )
         classifier.to(device)
     else:
-        state_dict = torch.load(xlora_utils._get_file_path(load_directory, "xlora_classifier.pt"))
+        state_dict = torch.load(xlora_utils._get_file_path(load_directory, "xlora_classifier.pt", hf_hub_subdir))
     classifier.load_state_dict(state_dict)  # type: ignore
 
     return model_peft
diff --git a/src/xlora/xlora_utils.py b/src/xlora/xlora_utils.py
index 77700bc..8c2c3b4 100644
--- a/src/xlora/xlora_utils.py
+++ b/src/xlora/xlora_utils.py
@@ -14,7 +14,7 @@
 from .xlora import from_pretrained, xLoRAModel  # type: ignore
 
 
-def _get_file_path(
+def _get_file_path_single(
     load_directory: str,
     name: str,
 ) -> str:
@@ -29,6 +29,12 @@ def _get_file_path_dir(load_directory: str, name: str, dir: str) -> str:
     return huggingface_hub.hf_hub_download(load_directory, filename=name, subfolder=dir)
 
 
+def _get_file_path(load_directory: str, name: str, dir: Optional[str]) -> str:
+    if dir is not None:
+        return _get_file_path_dir(load_directory, name, dir)
+    return _get_file_path_single(load_directory, name)
+
+
 def load_model(
     model_name: str,
     xlora_path: Optional[str],
@@ -39,6 +45,7 @@ def load_model(
     load_xlora: bool = True,
     verbose: bool = False,
     from_safetensors: bool = True,
+    hf_hub_subdir: Optional[str] = None,
 ) -> Tuple[Union[AutoModelForCausalLM, xLoRAModel], Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
     """
     Convenience function to load a model, converting it to xLoRA if specified.
@@ -46,8 +53,8 @@ def load_model(
     Args:
         model_name (`str`):
            AutoModelForCausalLM pretrained model name or path
        xlora_path (`str`, *optional*):
-            Directory to load the xLoRAClassifier from.
+            Directory or HF model repo ID to load the xLoRAClassifier from.
        device (`str`):
            Device to load the base model and the xLoRA model to.
        dtype (`torch.dtype`):
@@ -63,6 +70,8 @@ def load_model(
            Enable verbose loading.
        from_safetensors (`bool`, *optional*, defaults to True):
            Whether to load the classifier weights from a .pt or .safetensors file.
+        hf_hub_subdir (`str`, *optional*, defaults to None):
+            If `xlora_path` is a HF model repo ID, specify a subdirectory where the weights may be found.
 
     Returns:
        Tuple whose elements are respectively:
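
For reviewers, a minimal usage sketch of the new `hf_hub_subdir` argument. The base model, repo ID, and subfolder names below are hypothetical placeholders, and any `load_model` arguments this diff does not touch are elided:

```python
import torch

from xlora.xlora_utils import load_model

# Load the xLoRA classifier straight from a Hugging Face Hub repo instead of
# a local directory. "myuser/my-xlora" and "checkpoints/step_1000" are
# hypothetical; point them at a repo/subfolder that actually contains
# xlora_config.json and xlora_classifier.safetensors.
model, tokenizer = load_model(
    model_name="mistralai/Mistral-7B-Instruct-v0.1",  # example base model
    xlora_path="myuser/my-xlora",                     # HF Hub repo ID (hypothetical)
    device="cuda:0",
    dtype=torch.bfloat16,
    # ...any other required load_model arguments, unchanged by this diff...
    hf_hub_subdir="checkpoints/step_1000",            # subfolder inside the repo (hypothetical)
)
```

With `hf_hub_subdir=None` (the default), the new `_get_file_path` dispatches to `_get_file_path_single` (the renamed original helper), so existing callers keep the old behavior.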