Add mistral to obcq (#1798)
* Add mistral to obcq

* Update with model_type from config
mgoin authored and bfineran committed Nov 16, 2023
1 parent f8b9372 commit 0eea408
Showing 2 changed files with 24 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -19,6 +19,7 @@
 from typing import Optional
 
 from torch.nn import Module
+from transformers import AutoConfig
 
 import sparseml.core.session as session_manager
 from sparseml.core.framework import Framework
@@ -36,7 +37,7 @@
 
 _LOGGER = logging.getLogger(__name__)
 SUPPORTED_DATASETS = ["wikitext2", "ptb", "c4", "open_platypus"]
-SUPPORTED_MODELS = ["opt", "llama"]
+SUPPORTED_MODELS = ["opt", "llama", "mistral"]
 
 
 def one_shot(
@@ -70,14 +71,21 @@ def one_shot(
     if deploy_dir.exists():
         raise RuntimeError(f"deploy_dir={deploy_dir} already exists")
 
+    # Load the configuration from the model path
+    config = AutoConfig.from_pretrained(model_path)
+    model_type = config.model_type.lower()
+
     model_loader_fn = None
     forward_fn = None
-    if "opt" in model_path.lower():
+    if "opt" in model_type:
         model_loader_fn = SparseCasualLM.opt_model_from_pretrained
         forward_fn = opt_forward
-    elif "llama" in model_path.lower():
+    elif "llama" in model_type:
         model_loader_fn = SparseCasualLM.llama_model_from_pretrained
         forward_fn = llama_forward
+    elif "mistral" in model_type:
+        model_loader_fn = SparseCasualLM.auto_model_from_pretrained
+        forward_fn = llama_forward
     else:
         raise ValueError(f"model_path={model_path} should be one of {SUPPORTED_MODELS}")
     model = model_loader_fn(model_path)
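The substance of the obcq.py change: the model loader is now chosen from config.model_type, read via AutoConfig, rather than by substring-matching the checkpoint path, so a Mistral checkpoint is recognized even when "mistral" does not appear in its path or repo name. A minimal standalone sketch of that lookup, assuming only the transformers library (the helper name resolve_model_type is illustrative, not part of the commit):

from transformers import AutoConfig

def resolve_model_type(model_path: str) -> str:
    # Read the checkpoint's config.json and use its model_type field,
    # instead of guessing the architecture from the path string
    config = AutoConfig.from_pretrained(model_path)
    return config.model_type.lower()

# e.g. resolve_model_type("mistralai/Mistral-7B-v0.1") returns "mistral"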
13 changes: 13 additions & 0 deletions src/sparseml/transformers/utils/model.py
@@ -462,6 +462,19 @@ def llama_model_from_pretrained(model_path: str) -> torch.nn.Module:
         model.seqlen = model.config.max_position_embeddings
         return model
 
+    @staticmethod
+    def auto_model_from_pretrained(model_path: str) -> torch.nn.Module:
+        """
+        Load a pretrained model using auto from the specified hugging face path
+
+        :param model_path: hugging face path to model
+        :return: loaded pretrained model
+        """
+        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")
+        model.eval()
+        model.seqlen = model.config.max_position_embeddings
+        return model
+
 
 def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str:
     """
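For reference, a standalone sketch of what the new auto_model_from_pretrained helper does, using transformers directly rather than calling into sparseml; the Mistral checkpoint name is an example, and the snippet assumes the model config defines max_position_embeddings:

from transformers import AutoModelForCausalLM

# Load with the dtype stored in the checkpoint, switch to eval mode, and
# record on model.seqlen the sequence length used downstream
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype="auto"
)
model.eval()
model.seqlen = model.config.max_position_embeddings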
