diff --git a/src/sparseml/transformers/sparsification/obcq/obcq.py b/src/sparseml/transformers/sparsification/obcq/obcq.py
index 7a60b14a5b8..8a973cdb803 100644
--- a/src/sparseml/transformers/sparsification/obcq/obcq.py
+++ b/src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -19,6 +19,7 @@
 from typing import Optional
 
 from torch.nn import Module
+from transformers import AutoConfig
 
 import sparseml.core.session as session_manager
 from sparseml.core.framework import Framework
@@ -36,7 +37,7 @@
 _LOGGER = logging.getLogger(__name__)
 
 SUPPORTED_DATASETS = ["wikitext2", "ptb", "c4", "open_platypus"]
-SUPPORTED_MODELS = ["opt", "llama"]
+SUPPORTED_MODELS = ["opt", "llama", "mistral"]
 
 
 def one_shot(
@@ -70,14 +71,21 @@ def one_shot(
     if deploy_dir.exists():
         raise RuntimeError(f"deploy_dir={deploy_dir} already exists")
 
+    # Load the configuration from the model path
+    config = AutoConfig.from_pretrained(model_path)
+    model_type = config.model_type.lower()
+
     model_loader_fn = None
     forward_fn = None
-    if "opt" in model_path.lower():
+    if "opt" in model_type:
         model_loader_fn = SparseCasualLM.opt_model_from_pretrained
         forward_fn = opt_forward
-    elif "llama" in model_path.lower():
+    elif "llama" in model_type:
         model_loader_fn = SparseCasualLM.llama_model_from_pretrained
         forward_fn = llama_forward
+    elif "mistral" in model_type:
+        model_loader_fn = SparseCasualLM.auto_model_from_pretrained
+        forward_fn = llama_forward
     else:
         raise ValueError(f"model_path={model_path} should be one of {SUPPORTED_MODELS}")
 
     model = model_loader_fn(model_path)
diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py
index 3f89f7b4127..4e23d91c1e7 100644
--- a/src/sparseml/transformers/utils/model.py
+++ b/src/sparseml/transformers/utils/model.py
@@ -462,6 +462,19 @@ def llama_model_from_pretrained(model_path: str) -> torch.nn.Module:
         model.seqlen = model.config.max_position_embeddings
         return model
 
+    @staticmethod
+    def auto_model_from_pretrained(model_path: str) -> torch.nn.Module:
+        """
+        Load a pretrained model using AutoModelForCausalLM from the
+        specified Hugging Face path
+
+        :param model_path: Hugging Face path to the model
+        :return: loaded pretrained model
+        """
+        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")
+        model.eval()
+        model.seqlen = model.config.max_position_embeddings
+        return model
+
 
 def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str:
     """
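
Reviewer note: beyond adding Mistral, this patch changes how the loader is chosen. Dispatch now keys on `config.model_type` read from the checkpoint's `config.json` rather than on substrings of `model_path`, so a fine-tuned Mistral saved under an arbitrary directory name still resolves correctly. A minimal standalone sketch of that behavior (the `resolve_forward` helper and the example path are illustrative, not part of this patch):

```python
# Illustrative sketch only -- resolve_forward is not part of this patch.
from transformers import AutoConfig


def resolve_forward(model_path: str) -> str:
    # AutoConfig fetches only config.json, not the weights, so the
    # dispatch is cheap and keyed on architecture rather than path name.
    model_type = AutoConfig.from_pretrained(model_path).model_type.lower()
    if "opt" in model_type:
        return "opt_forward"
    if "llama" in model_type or "mistral" in model_type:
        # mistral reuses the llama forward pass in this patch
        return "llama_forward"
    raise ValueError(f"unsupported model_type={model_type}")


# A checkpoint saved at e.g. "./finetuned-7b" with a Mistral config now
# dispatches correctly even though "mistral" never appears in the path.
```

Also worth flagging: `torch_dtype="auto"` in the new `auto_model_from_pretrained` tells `AutoModelForCausalLM` to load weights in the dtype recorded in the checkpoint config (fp16/bf16 for most Mistral releases) instead of upcasting to fp32, which roughly halves memory for the one-shot pass.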