Skip to content

Commit

Permalink
[Model] Add Gemma 2 (vllm-project#5908)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored and prashantgupta24 committed Jun 28, 2024
1 parent 087025d commit 2213a43
Show file tree
Hide file tree
Showing 9 changed files with 499 additions and 9 deletions.
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ Alongside each architecture, we include some popular models that use it.
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
- ✅︎
* - :code:`Gemma2ForCausalLM`
- Gemma2
- :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc.
- ✅︎
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
Expand Down
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
transformers >= 4.42.0 # Required for Gemma 2.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
Expand Down
30 changes: 23 additions & 7 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_tpu, is_xpu,
is_hip, is_neuron, is_tpu, is_xpu, print_warning_once,
update_environment_variables)

if TYPE_CHECKING:
Expand Down Expand Up @@ -141,6 +141,17 @@ def __init__(
code_revision, rope_scaling, rope_theta)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True

self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
Expand Down Expand Up @@ -257,8 +268,7 @@ def verify_with_parallel_config(
"BitAndBytes quantization with TP or PP is not supported yet.")

def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
"""Get the sliding window size, or None if disabled."""

# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
Expand Down Expand Up @@ -1256,10 +1266,16 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32
# models.
logger.info("Casting torch.float32 to torch.float16.")
torch_dtype = torch.float16
if config.model_type == "gemma2":
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
Expand Down
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,10 @@ def vocab_size(self):
def scale(self):
return self.base_layer.scale

@property
def soft_cap(self):
return self.base_layer.soft_cap

@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
Expand Down
46 changes: 46 additions & 0 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,49 @@ def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
return s


class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
Two differences from the above RMSNorm:
1. x * (1 + w) instead of x * w.
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""

def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
x = x + residual
residual = x

x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + self.weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
return self.forward_native(x, residual)
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: float = 1.0,
logits_as_input: bool = False) -> None:
logits_as_input: bool = False,
soft_cap: Optional[float] = None) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
Expand All @@ -34,6 +35,8 @@ def __init__(self,
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap

def forward(
self,
Expand All @@ -52,6 +55,11 @@ def forward(
logits = self._get_logits(hidden_states, embedding, embedding_bias)

if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
logits = torch.tanh(logits)
logits = logits * self.soft_cap

if self.scale != 1.0:
logits *= self.scale

Expand Down
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,16 @@ def forward(
return query.flatten(-2), key.flatten(-2)


class GemmaRotaryEmbedding(RotaryEmbedding):

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq = 1.0 / (base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
self.rotary_dim))
return inv_freq


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
Expand Down
Loading

0 comments on commit 2213a43

Please sign in to comment.