Skip to content

Commit

Permalink
[Model] Support Nemotron models (Nemotron-3, Nemotron-4, Minitron) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Jul 26, 2024
1 parent 85ad7e2 commit 07278c3
Show file tree
Hide file tree
Showing 9 changed files with 776 additions and 1 deletion.
11 changes: 11 additions & 0 deletions .buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
model_name: "nvidia/Minitron-4B-Base"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.252
- name: "exact_match,flexible-extract"
value: 0.252
limit: 1000
num_fewshot: 5
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,21 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
relu_applied = nn.functional.relu(x)
squared = torch.square(relu_applied)
return squared

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
Expand Down Expand Up @@ -207,6 +222,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(),
}

Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
rotary_percent: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
Expand All @@ -786,6 +787,8 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if rotary_percent < 1.0:
rotary_dim = int(rotary_dim * rotary_percent)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
if key in _ROPE_DICT:
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
Expand Down
Loading

0 comments on commit 07278c3

Please sign in to comment.