Update gemma for trt-llm 0.9 #8974

Merged · 8 commits · Apr 24, 2024
89 changes: 2 additions & 87 deletions nemo/export/trt_llm/decoder/gemma.py
@@ -15,10 +15,8 @@
from typing import Optional

from tensorrt_llm.functional import non_gated_version
-from tensorrt_llm.layers import Attention, AttentionMaskType, GatedMLP, PositionEmbeddingType, RmsNorm
+from tensorrt_llm.models.gemma.model import GemmaDecoderLayer, QuantConfig
from tensorrt_llm.models.modeling_utils import PretrainedConfig
-from tensorrt_llm.module import Module
-from tensorrt_llm.quantization import QuantMode
from typing_extensions import override

from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -32,88 +30,6 @@
)


class GemmaDecoderLayer(Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        self.input_layernorm = RmsNorm(
            normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype
        )

        self.attention = Attention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attention_head_size=config.head_size,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=config.rotary_base,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode,
        )

        mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

        self.mlp = GatedMLP(
            hidden_size=config.hidden_size,
            ffn_hidden_size=mlp_hidden_size,
            hidden_act=config.hidden_act,
            dtype=config.dtype,
            bias=config.mlp_bias,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode,
        )
        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        medusa_packed_mask=None,  # For Medusa support
        medusa_position_offsets=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        lora_layer_params=None,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            medusa_packed_mask=medusa_packed_mask,  # For Medusa support
            medusa_position_offsets=medusa_position_offsets,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params,
        )

        if use_cache:
            attention_output, presents = attention_output

        hidden_states = residual + attention_output

        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states, lora_layer_params=lora_layer_params)

        hidden_states = residual + hidden_states
        if use_cache:
            return (hidden_states, presents)
        return hidden_states


class GemmaDecoderLayerConfigBuilder(DecoderLayerConfigBuilder):
"""The LLAMA implementation of the DecoderLayerConfigBuilder."""

@@ -200,8 +116,7 @@ def build_decoder(self, layer):
            world_size=self.tensor_parallel,
            tp_size=self.tensor_parallel,
            pp_size=1,
-            quant_mode=QuantMode(0),
-            quant_kwargs=None,
+            quantization=QuantConfig(),
            max_lora_rank=layer.max_lora_rank,
        )

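Taken together, the gemma.py changes swap NeMo's local copy of the decoder layer for the one that ships with TRT-LLM 0.9 and describe quantization with a QuantConfig object instead of the removed quant_mode/quant_kwargs arguments. The snippet below is a minimal sketch of that new-style usage, not NeMo's builder code or part of this PR; it assumes tensorrt_llm 0.9 is installed, and pretrained_config is a hypothetical stand-in for the PretrainedConfig that the builder assembles.

# Minimal sketch (illustrative only), assuming tensorrt_llm 0.9.
from tensorrt_llm.models.gemma.model import GemmaDecoderLayer, QuantConfig

# A default QuantConfig corresponds to "no quantization" here; it takes the
# place of the removed quant_mode=QuantMode(0) / quant_kwargs=None arguments
# in build_decoder above.
default_quant = QuantConfig()
print(default_quant)

# The upstream layer is built per layer index from the shared config;
# pretrained_config is a placeholder for what NeMo's builder constructs.
# layer = GemmaDecoderLayer(pretrained_config, layer_idx=0)

Reusing the upstream layer keeps the NeMo exporter from drifting out of sync as TRT-LLM's Gemma implementation evolves.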
2 changes: 1 addition & 1 deletion tests/export/run.sh
@@ -48,4 +48,4 @@ python tests/export/test_nemo_export.py --model_name FALCON-7B-base --existing_t
python tests/export/test_nemo_export.py --model_name FALCON-40B-base --existing_test_models --min_gpus 2 --max_gpus 8
python tests/export/test_nemo_export.py --model_name FALCON-180B-base --existing_test_models --min_gpus 8 --max_gpus 8
python tests/export/test_nemo_export.py --model_name STARCODER1-15B-base --existing_test_models --min_gpus 1 --max_gpus 1
-python tests/export/test_nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1 --run_accuracy --test_deployment True
+python tests/export/test_nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1