
[Model] Add Gemma 2 #5908

Merged (6 commits) on Jun 27, 2024
Changes from 2 commits
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -55,6 +55,10 @@ Alongside each architecture, we include some popular models that use it.
     - Gemma
     - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
     - ✅︎
+  * - :code:`Gemma2ForCausalLM`
+    - Gemma2
+    - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc.
+    - ✅︎
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
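With this row added, Gemma 2 checkpoints load through vLLM's standard offline API. A minimal usage sketch (not part of the diff; the prompt and sampling settings are illustrative):

```python
from vllm import LLM, SamplingParams

# Loads google/gemma-2-9b via the Gemma2ForCausalLM implementation added by this PR.
llm = LLM(model="google/gemma-2-9b")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```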
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -6,7 +6,7 @@ numpy < 2.0.0
 requests
 tqdm
 py-cpuinfo
-transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
+transformers >= 4.42.0 # Required for Gemma 2.
 tokenizers >= 0.19.1 # Required for Llama 3.
 fastapi
 aiohttp
29 changes: 22 additions & 7 deletions vllm/config.py
@@ -14,7 +14,7 @@
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_tpu, is_xpu,
+                        is_hip, is_neuron, is_tpu, is_xpu, print_warning_once,
                         update_environment_variables)
 
 if TYPE_CHECKING:
@@ -257,8 +257,17 @@ def verify_with_parallel_config(
                 "BitAndBytes quantization with TP or PP is not supported yet.")
 
     def get_hf_config_sliding_window(self) -> Optional[int]:
-        """Get the sliding window size, or None if disabled.
-        """
+        """Get the sliding window size, or None if disabled."""
+
+        if (self.hf_text_config.model_type == "gemma2"
+                and self.hf_text_config.sliding_window is not None):
+            print_warning_once(
+                "While Gemma 2 uses sliding window attention for every odd "
+                "layer, vLLM currently ignores it and uses global attention "
+                "for all layers. This might affect the model's behavior when "
+                "the context length is larger than the sliding window size "
+                f"({self.hf_text_config.sliding_window}).")
+            return None
 
         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
         # addition to sliding window size. We check if that field is present
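To make the warning above concrete, here is a small standalone sketch (not from this PR) contrasting a full causal mask with a sliding-window causal mask; the sequence length and window size are made up for illustration.

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True means the query position may attend to the key position.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # Each query attends to at most the `window` most recent keys (itself included).
    idx = torch.arange(seq_len)
    within_window = (idx.unsqueeze(1) - idx.unsqueeze(0)) < window
    return causal_mask(seq_len) & within_window

# Once the context is longer than the window, the two masks diverge, which is
# exactly the regime the warning above refers to.
print(causal_mask(6).int())
print(sliding_window_causal_mask(6, window=3).int())
```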
@@ -1252,10 +1261,16 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            # Following the common practice, we use float16 for float32
-            # models.
-            logger.info("Casting torch.float32 to torch.float16.")
-            torch_dtype = torch.float16
+            if config.model_type == "gemma2":
+                logger.info(
+                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    "of float16 by default. Please specify `dtype` if you "
+                    "want to use float16.")
+                torch_dtype = torch.bfloat16
+            else:
+                # Following the common practice, we use float16 for float32
+                # models.
+                torch_dtype = torch.float16
         else:
             torch_dtype = config_dtype
     else:
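Usage note (not part of the diff): the new default only applies when the checkpoint config reports float32, and it can be overridden explicitly, as the log message suggests.

```python
from vllm import LLM

# dtype="auto" now downcasts a float32 Gemma 2 config to bfloat16; pass dtype
# explicitly if you want the old float16 behavior instead.
llm = LLM(model="google/gemma-2-9b", dtype="float16")
```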
46 changes: 46 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -95,3 +95,49 @@ def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
         return s
+
+
+class GemmaRMSNorm(CustomOp):
+    """RMS normalization for Gemma.
+
+    Two differences from the above RMSNorm:
+        1. x * (1 + w) instead of x * w.
+        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+        return self.forward_native(x, residual)

Review thread on forward_native (resolved):
Collaborator: We should try decorating this with @torch.compile, similar to what we do in Command R.
Author (WoosukKwon): Yeah I was thinking about it or writing a CUDA kernel. Let's discuss this in another PR?
Collaborator: Sounds good -- agree it should be in a different PR :)
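To illustrate the two differences called out in GemmaRMSNorm's docstring, a standalone comparison (not from the PR) of the Llama-style and Gemma-style orderings; shapes and values are arbitrary.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8, dtype=torch.float16)
w = torch.randn(8)  # GemmaRMSNorm stores the weight as an offset from 1.0

def rms(v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return v * torch.rsqrt(v.pow(2).mean(dim=-1, keepdim=True) + eps)

normed = rms(x.float())

# Difference 1: Gemma scales by (1 + w), so the zero-initialized weight is an identity scale.
# Difference 2: Gemma multiplies in float32 and casts last; Llama casts first, then multiplies.
llama_style = normed.to(x.dtype) * (1.0 + w).to(x.dtype)
gemma_style = (normed * (1.0 + w)).to(x.dtype)

# The two orderings agree only up to half-precision rounding.
print((llama_style - gemma_style).abs().max())
```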
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
@@ -22,7 +22,8 @@ def __init__(self,
                  vocab_size: int,
                  org_vocab_size: Optional[int] = None,
                  scale: float = 1.0,
-                 logits_as_input: bool = False) -> None:
+                 logits_as_input: bool = False,
+                 soft_cap: Optional[float] = None) -> None:
         """
         Args:
             scale: A scaling factor to apply to the logits.
@@ -34,6 +35,8 @@ def __init__(self,
         self.logits_as_input = logits_as_input
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size
+        # Soft cap the logits. Used in Gemma 2.
+        self.soft_cap = soft_cap
 
     def forward(
         self,
@@ -52,6 +55,11 @@ def forward(
         logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
         if logits is not None:
+            if self.soft_cap is not None:
+                logits = logits / self.soft_cap
+                logits = torch.tanh(logits)
+                logits = logits * self.soft_cap
+
             if self.scale != 1.0:
                 logits *= self.scale
 
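For intuition, the three added lines compute `soft_cap * tanh(logits / soft_cap)`, which smoothly bounds logits to (-soft_cap, soft_cap) while leaving values well below the cap nearly unchanged. A standalone check (not from the PR); the cap of 30.0 is illustrative, on the order of the final-logit soft cap shipped in Gemma 2's config.

```python
import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Same computation as the three-step form above: divide, tanh, multiply.
    return soft_cap * torch.tanh(logits / soft_cap)

logits = torch.tensor([-1000.0, -10.0, 0.0, 10.0, 1000.0])
capped = soft_cap_logits(logits, soft_cap=30.0)
print(capped)                    # squashed into (-30, 30)
assert capped.abs().max() < 30.0
```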
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -610,6 +610,16 @@ def forward(
         return query.flatten(-2), key.flatten(-2)
 
 
+class GemmaRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
+        inv_freq = 1.0 / (base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
+            self.rotary_dim))
+        return inv_freq
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
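The subclass only changes how `inv_freq` is computed, mirroring the linked Hugging Face Gemma code (an int64 `arange` cast to float), presumably so the frequencies match the reference implementation exactly. A standalone sketch of how those frequencies become the cached cos/sin tables (not from the PR; `rotary_dim`, `base`, and the helper name are illustrative):

```python
import torch

def gemma_inv_freq(rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Mirrors _compute_inv_freq above: int64 arange, cast to float, divide by rotary_dim.
    exponent = torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim
    return 1.0 / (base ** exponent)

rotary_dim = 8                             # real models use the attention head size
inv_freq = gemma_inv_freq(rotary_dim)
positions = torch.arange(4).float()
angles = torch.outer(positions, inv_freq)  # [num_positions, rotary_dim // 2]
cos, sin = angles.cos(), angles.sin()      # the tables a RotaryEmbedding caches
print(cos.shape, sin.shape)
```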
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -23,6 +23,7 @@
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
+    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
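For background, this registry maps the `architectures` entry of a Hugging Face config to a (module, class) pair under `vllm.model_executor.models`. A simplified sketch of the lookup (not vLLM's actual loader code; `resolve_model_class` is a made-up helper):

```python
import importlib

# Trimmed mirror of the registry above: architecture -> (module name, class name).
_MODELS = {
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
}

def resolve_model_class(architecture: str):
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

# google/gemma-2-9b reports architectures=["Gemma2ForCausalLM"] in its config,
# so the new entry routes it to vLLM's Gemma2ForCausalLM implementation.
```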