From 69ef228205daa66b39959c6df30fcc9d11f8fec7 Mon Sep 17 00:00:00 2001 From: quic-akuruvil Date: Tue, 28 May 2024 14:31:27 +0530 Subject: [PATCH] Add support for new MOE model mistralai/Mixtral-8x7B-v0.1 (#8) * Add support for new MOE model mistralai/Mixtral-8x7B-v0.1 Signed-off-by: akuruvil * Update cache utils Signed-off-by: akuruvil * Updated modeling files Signed-off-by: akuruvil * Updated modeling files Signed-off-by: akuruvil * RMSNorm moved to common utils file Signed-off-by: akuruvil * Restructuring code Signed-off-by: akuruvil * modeling file changes Signed-off-by: akuruvil * Updating test files Signed-off-by: akuruvil * Added logger warning Signed-off-by: akuruvil * Updated utils Signed-off-by: akuruvil * Update test_modeling_mixtral.py Updated model card to mistralai/Mixtral-8x7B-Instruct-v0.1. Signed-off-by: quic-amitraj <168538872+quic-amitraj@users.noreply.github.com> --------- Signed-off-by: akuruvil Signed-off-by: quic-amitraj <168538872+quic-amitraj@users.noreply.github.com> Co-authored-by: quic-amitraj <168538872+quic-amitraj@users.noreply.github.com> --- QEfficient/transformers/cache_utils.py | 62 ++ QEfficient/transformers/modeling_outputs.py | 92 +++ QEfficient/transformers/modeling_utils.py | 33 + .../models/mixtral_moe/__init__.py | 6 + .../models/mixtral_moe/modeling_mixtral.py | 660 ++++++++++++++++++ QEfficient/utils/constants.py | 3 + QEfficient/utils/device_utils.py | 43 +- scripts/Jenkinsfile | 2 +- .../mixtral_moe/test_modeling_mixtral.py | 83 +++ tests/utils.py | 33 +- 10 files changed, 1008 insertions(+), 9 deletions(-) create mode 100644 QEfficient/transformers/cache_utils.py create mode 100644 QEfficient/transformers/models/mixtral_moe/__init__.py create mode 100644 QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py create mode 100644 tests/transformer/models/mixtral_moe/test_modeling_mixtral.py diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py new file mode 100644 index 00000000..53a231fd --- /dev/null +++ b/QEfficient/transformers/cache_utils.py @@ -0,0 +1,62 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +from typing import Any, Dict, List, Optional, Tuple +import torch +from transformers.cache_utils import ( + DynamicCache +) + + +class QEffDynamicCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + """ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. 
+ """ + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + kv_indices = torch.arange(key_states.shape[2]) + cache_kwargs['cache_index'] + self.key_cache[layer_idx][:,:, kv_indices] = key_states + self.value_cache[layer_idx][:,:, kv_indices] = value_states + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + \ No newline at end of file diff --git a/QEfficient/transformers/modeling_outputs.py b/QEfficient/transformers/modeling_outputs.py index f0eb7267..1ade8a3b 100644 --- a/QEfficient/transformers/modeling_outputs.py +++ b/QEfficient/transformers/modeling_outputs.py @@ -179,3 +179,95 @@ class QEffCausalLMOutputWithPast(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None attention_mask_RetainedState: Optional[torch.BoolTensor] = None + + +@dataclass +class QEffMoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. 
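# --- Illustrative sketch (not part of this patch) ---------------------------
# The update() above departs from the stock DynamicCache: instead of
# concatenating new key/value states, it scatter-writes them into the retained
# cache tensors at the position passed in cache_kwargs["cache_index"], so the
# cache keeps a fixed shape across decode steps.  Minimal standalone
# illustration of that write; the buffer sizes and the helper name are
# hypothetical.
import torch

def scatter_update(cache: torch.Tensor, new_states: torch.Tensor, cache_index: torch.Tensor) -> torch.Tensor:
    # cache:      [batch, heads, ctx_len, head_dim], retained across steps
    # new_states: [batch, heads, seq_len, head_dim], this step's keys or values
    kv_indices = torch.arange(new_states.shape[2]) + cache_index
    cache[:, :, kv_indices] = new_states
    return cache

buf = torch.zeros(1, 8, 128, 64)                    # pre-allocated KV buffer
step = torch.rand(1, 8, 1, 64)                      # one decode step's key (or value) states
buf = scatter_update(buf, step, torch.tensor(5))    # writes position 5 in place
# -----------------------------------------------------------------------------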
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + attention_mask_RetainedState: Optional[torch.BoolTensor] = None + +@dataclass +class QEffMoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
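# --- Illustrative sketch (not part of this patch) ---------------------------
# The modeling_utils.py hunk further below registers every Mixtral module in
# TransformersToQEffModulesDict.  The transform that consumes this mapping is
# not shown in the diff; a dict of this shape is typically applied by walking
# the module tree and re-pointing each matching module's __class__, roughly as
# sketched here (an assumption about QEfficient.transform, not its source).
import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    for module in model.modules():
        target = mapping.get(type(module))
        if target is not None:
            module.__class__ = target   # swap in the QEff forward() in place
    return model
# -----------------------------------------------------------------------------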
+ """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + attention_mask_RetainedState: Optional[torch.BoolTensor] = None diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 433475c7..5ad29ef3 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -32,6 +32,16 @@ MistralRMSNorm, MistralRotaryEmbedding, ) +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralForCausalLM, + MixtralModel, + MixtralDecoderLayer, + MixtralSparseMoeBlock, + MixtralBLockSparseTop2MLP, + MixtralRotaryEmbedding, + MixtralRMSNorm, +) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel from QEfficient.customop import CustomRMSNormAIC @@ -47,6 +57,8 @@ QEffBaseModelOutputWithPastAndCrossAttentions, QEffCausalLMOutputWithCrossAttentions, QEffCausalLMOutputWithPast, + QEffMoeCausalLMOutputWithPast, + QEffMoeModelOutputWithPast, ) from .models.codegen.modeling_codegen import ( QEffCodeGenAttention, @@ -68,6 +80,15 @@ QEffMistralModel, QEffMistralRotaryEmbedding, ) +from .models.mixtral_moe.modeling_mixtral import ( + QEffMixtralModel, + QEffMixtralRotaryEmbedding, + QEffMixtralAttention, + QEffMixtralForCausalLM, + QEffMixtralDecoderLayer, + QEffMixtralSparseMoeBlock, + QEffMixtralBLockSparseTop2MLP, +) from .models.mpt.modeling_mpt import QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel # Define a named tuple for ModelArchitectures @@ -81,6 +102,7 @@ CodeGenForCausalLM.__name__, LlamaForCausalLM.__name__, MistralForCausalLM.__name__, + MixtralForCausalLM.__name__, ] ) @@ -115,6 +137,15 @@ MistralForCausalLM: QEffMistralForCausalLM, MistralRotaryEmbedding: QEffMistralRotaryEmbedding, MistralRMSNorm: CustomRMSNormAIC, + # Mixtral model layers + MixtralAttention: QEffMixtralAttention, + MixtralModel: QEffMixtralModel, + MixtralDecoderLayer: QEffMixtralDecoderLayer, + MixtralForCausalLM: QEffMixtralForCausalLM, + MixtralRotaryEmbedding: QEffMixtralRotaryEmbedding, + MixtralRMSNorm: CustomRMSNormAIC, + MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, + MixtralBLockSparseTop2MLP:QEffMixtralBLockSparseTop2MLP, } @@ -186,6 +217,8 @@ def transform(model: nn.Module, form_factor: str = "cloud") -> nn.Module: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions = QEffCausalLMOutputWithCrossAttentions transformers.modeling_outputs.BaseModelOutputWithPast = QEffBaseModelOutputWithPast transformers.modeling_outputs.CausalLMOutputWithPast = QEffCausalLMOutputWithPast + transformers.modeling_outputs.MoeCausalLMOutputWithPast = QEffMoeCausalLMOutputWithPast + transformers.modeling_outputs.MoeModelOutputWithPast = QEffMoeModelOutputWithPast # Replace the modeling attn util classes and functions transformers.modeling_attn_mask_utils.AttentionMaskConverter = QEffAttentionMaskConverter diff --git a/QEfficient/transformers/models/mixtral_moe/__init__.py b/QEfficient/transformers/models/mixtral_moe/__init__.py new file mode 100644 index 00000000..8694aa93 --- /dev/null +++ b/QEfficient/transformers/models/mixtral_moe/__init__.py @@ -0,0 +1,6 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py new file mode 100644 index 00000000..195d9b9c --- /dev/null +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -0,0 +1,660 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +import math +import warnings +from typing import List, Optional, Tuple, Union +import onnxscript +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN + +from transformers.models.mixtral.modeling_mixtral import ( + logger, + MixtralAttention, + MixtralForCausalLM, + MixtralModel, + MixtralConfig, + MixtralDecoderLayer, + MixtralSparseMoeBlock, + MixtralBLockSparseTop2MLP, + MixtralRotaryEmbedding, + MixtralRMSNorm, + load_balancing_loss_func, + apply_rotary_pos_emb, + rotate_half, + repeat_kv, + _get_unpad_data, +) +from transformers.cache_utils import Cache +from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_outputs import QEffMoeModelOutputWithPast, QEffMoeCausalLMOutputWithPast +from QEfficient.transformers.modeling_attn_mask_utils import _qeff_prepare_4d_causal_attention_mask + + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral +class QEffMixtralRotaryEmbedding(MixtralRotaryEmbedding): + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached.to(dtype=x.dtype), + self.sin_cached.to(dtype=x.dtype), + ) + + + + +class QEffMixtralAttention(MixtralAttention): + """ + Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cache_index: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0: + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs["cache_index"] = cache_index + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + + +MIXTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, +} + +class QEffMixtralBLockSparseTop2MLP(MixtralBLockSparseTop2MLP): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = 
ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + + +class QEffMixtralSparseMoeBlock(MixtralSparseMoeBlock): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([QEffMixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) + current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) + current_hidden_states = torch.where( + (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], + current_hidden_states, + torch.tensor(0.0), + ) + final_hidden_states = final_hidden_states + current_hidden_states + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class QEffMixtralDecoderLayer(MixtralDecoderLayer): + """ + Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.block_sparse_moe = QEffMixtralSparseMoeBlock(config) + self.input_layernorm = CustomRMSNormAIC(config.hidden_size, eps=config.rms_norm_eps) + 
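# --- Illustrative sketch (not part of this patch) ---------------------------
# QEffMixtralSparseMoeBlock.forward above replaces the upstream gather/
# index_add_ routing with a shape-static formulation: every expert runs on
# every token and the result is masked with torch.where, so no tensor shape
# depends on the routing decision.  The per-expert step, with hypothetical
# argument names matching the tensors used above:
import torch

def masked_expert_contribution(expert_out, routing_weights, expert_mask_tr):
    # expert_out:      [tokens, hidden]  this expert applied to *all* tokens
    # routing_weights: [tokens, top_k]   normalized top-k gate probabilities
    # expert_mask_tr:  [tokens, top_k]   1 where this expert was selected
    gate = (routing_weights * expert_mask_tr).sum(1)[:, None]    # [tokens, 1]
    # Zero out tokens that were not routed to this expert; summing these
    # contributions over all experts reproduces the top-k weighted mixture.
    return torch.where(gate.to(torch.bool), expert_out * gate, torch.tensor(0.0))
# -----------------------------------------------------------------------------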
self.post_attention_layernorm = CustomRMSNormAIC(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_index: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + cache_index=cache_index, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class QEffMixtralModel(MixtralModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
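# --- Illustrative sketch (not part of this patch) ---------------------------
# The two layer norms above are constructed as CustomRMSNormAIC so the
# compiler can lower them to a single custom RMSNorm op.  Its definition is
# not part of this diff; functionally it is expected to match standard
# RMSNorm, reproduced here for reference (an assumption, not the QEff source):
import torch
import torch.nn as nn

class ReferenceRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
# -----------------------------------------------------------------------------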
Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QEffMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = CustomRMSNormAIC(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + # @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_index: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QEffMoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if cache_index is not None: + attention_mask[:, cache_index + seq_length - 1] = True + attention_mask_RetainedState = attention_mask + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _qeff_prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + cache_index=cache_index, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + cache_index=cache_index, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return QEffMoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + 
router_logits=all_router_logits, + attention_mask_RetainedState=attention_mask_RetainedState if cache_index is not None else None, + ) + + +class QEffMixtralForCausalLM(MixtralForCausalLM): + """ + Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args cache idx for the kv retention + """ + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QEffMixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_index: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QEffMoeCausalLMOutputWithPast]: + """ + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
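# --- Illustrative sketch (not part of this patch) ---------------------------
# How the new arguments are meant to be wired across iterations (a hypothetical
# driver, not QEfficient's ApiRunner): prefill passes the whole prompt with no
# cache_index; each decode step then feeds back past_key_values, the returned
# attention_mask_RetainedState and an advancing cache_index, and reads logits
# for the single last position produced by the forward body below.  This
# assumes the cache and attention mask have been padded to the full context
# length (as the test harness does via `padding_shape`) so the scatter write
# at cache_index stays in bounds.
import torch

def greedy_decode_sketch(model, input_ids, attention_mask, steps: int = 8):
    out = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
    past, mask = out.past_key_values, attention_mask
    cache_index = torch.tensor(input_ids.shape[1])
    next_token = out.logits.argmax(-1)                 # [batch, 1]
    generated = [next_token]
    for _ in range(steps):
        out = model(
            input_ids=next_token,
            attention_mask=mask,
            position_ids=cache_index.view(1, 1),
            cache_index=cache_index,
            past_key_values=past,
            use_cache=True,
        )
        past = out.past_key_values
        mask = out.attention_mask_RetainedState        # extended by one position
        cache_index = cache_index + 1
        next_token = out.logits.argmax(-1)
        generated.append(next_token)
    return torch.cat(generated, dim=-1)
# -----------------------------------------------------------------------------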
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cache_index=cache_index, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states[:, -1:]) + logits = logits.float() + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return QEffMoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + attention_mask_RetainedState=outputs.attention_mask_RetainedState, + ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 6ee5e56f..236afc40 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -24,3 +24,6 @@ class Constants: INPUT_STRING = ["My name is"] CACHE_DIR = os.path.join(ROOT_DIR, "cache_dir") + + GB = 2**30 + MAX_QPC_LIMIT = 30 diff --git a/QEfficient/utils/device_utils.py b/QEfficient/utils/device_utils.py index 5e4737a4..74d7e3fc 100644 --- a/QEfficient/utils/device_utils.py +++ b/QEfficient/utils/device_utils.py @@ -6,7 +6,9 @@ # ----------------------------------------------------------------------------- import subprocess - +import math +from QEfficient.utils.logging_utils import logger +from QEfficient.utils.constants import Constants def get_available_device_id(): device_id = 0 @@ -27,3 +29,42 @@ def get_available_device_id(): elif "Failed to find requested device ID" in result.stdout: print("Failed to find requested device ID") return None + +def is_qpc_size_gt_32gb(params: int, mxfp6: bool) -> bool: + if mxfp6: + qpc_size = math.ceil((params * 1) / Constants.GB) + else: + qpc_size = math.ceil((params * 2) / Constants.GB) + + logger.warning(f"Approximate QPC size is: {qpc_size} GB") + num_devices = math.ceil(qpc_size / Constants.MAX_QPC_LIMIT) + logger.warning(f"Number of 
Devices required: {num_devices}" ) + return qpc_size > Constants.MAX_QPC_LIMIT + + +def is_multi_qranium_setup_available(): + result = None + command = ["/opt/qti-aic/tools/qaic-util", "-q"] + try: + result = subprocess.run(command, stdout=subprocess.PIPE, universal_newlines=True) + filtered_result = subprocess.run( + ["grep", "Device Capabilities"], input=result.stdout, stdout=subprocess.PIPE, text=True + ) + except OSError: + print("Command not found", command) + return None + + lines = filtered_result.stdout.split("\n") + + # to count the number of devices in MQ enabled set up + hybridboot_mdp_count = 0 + for line in lines: + if ("HybridBoot+" in line) and ("MDP+" in line): + hybridboot_mdp_count = hybridboot_mdp_count + 1 + + if hybridboot_mdp_count > 0: + print("No: of Devices with MQ enabled available: ", hybridboot_mdp_count) + return True + else: + print("Device in MQ set up not available") + return False diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 8bafa5a2..d7cc6aff 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -50,7 +50,7 @@ pipeline steps { //todo(ochougul): Increase when MQ tests are enabled - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 420, unit: 'MINUTES') { sh ''' . preflight_qeff/bin/activate export TOKENIZERS_PARALLELISM=false diff --git a/tests/transformer/models/mixtral_moe/test_modeling_mixtral.py b/tests/transformer/models/mixtral_moe/test_modeling_mixtral.py new file mode 100644 index 00000000..ace4eeb1 --- /dev/null +++ b/tests/transformer/models/mixtral_moe/test_modeling_mixtral.py @@ -0,0 +1,83 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +import unittest +import pytest + +import transformers +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + +from QEfficient.utils.constants import Constants +from QEfficient.utils.device_utils import get_available_device_id +from tests.utils import get_cloud_ai_100_tokens, set_up, skip_if_mq_not_enabled + + +def get_config(): + """ + Function to get config info from transformers.AutoConfig + :param: None + :return model_config - Dict + """ + + model_config = {} + model_config["model_name"] = "mistralai/Mixtral-8x7B-Instruct-v0.1" + config = transformers.AutoConfig.from_pretrained(model_config["model_name"]) + config._attn_implementation = "eager" + n_heads = config.num_attention_heads + d_head = config.hidden_size // n_heads + model_config["model_class"] = MixtralForCausalLM + model_config["n_layer"] = config.num_hidden_layers + model_config["padding_shape"] = [1, config.num_key_value_heads, Constants.CTX_LEN, d_head] + + return model_config + + + +class TestQEfficientMixtral(unittest.TestCase): + @classmethod + def setUpClass(self): + """ + Set up function to set up the test environment for TestQEfficientOPt class + :param None + """ + + self.model_config = get_config() + self.device_group = [0, 1, 2, 3] + self.setup_info = set_up(self.model_config, self.device_group) + + # TODO: Check this test case @Ann: Working when single model runs + @pytest.mark.skip(reason="Seeing issue in HF Model, Maybe due to tokenizer") + def test_qefficient_mixtral_torch(self): + """ + Test function to validate the mixtral model before and after KV changes on Pytorch + :param None + """ + assert ( + self.setup_info["pytorch_hf_tokens"] 
== self.setup_info["pytorch_kv_tokens"] + ).all(), "tokens aren't matching for hf pytorch model output and KV pytorch model output." + + def test_qefficient_mixtral_onnx(self): + """ + Test function to validate the mixtral model before and after KV changes on ONNXRT + # :param None + #""" + assert ( + self.setup_info["pytorch_kv_tokens"] == self.setup_info["ort_tokens"] + ).all(), "tokens aren't matching for onnxrt output and Pytorch output." + + @pytest.mark.skipif(not get_available_device_id(), reason="No available devices to run model on AIC") + @skip_if_mq_not_enabled + def test_qefficient_mixtral_cloud_ai_100(self): + """ + Test function to validate the mixtral model before and after KV changes on Cloud AI 100 + :param None + """ + cloud_ai_100_tokens = get_cloud_ai_100_tokens(self.setup_info) + assert ( + self.setup_info["ort_tokens"] == cloud_ai_100_tokens + ).all(), "tokens aren't matching for onnxrt output and Cloud AI 100 output." diff --git a/tests/utils.py b/tests/utils.py index d12921d4..d1c3a65f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,8 @@ import os import shutil - +import unittest +import functools import transformers import QEfficient @@ -15,10 +16,24 @@ from QEfficient.exporter.export_utils import compile_kv_model_on_cloud_ai_100 from QEfficient.utils import hf_download from QEfficient.utils.constants import QEFF_MODELS_DIR, ROOT_DIR, Constants -from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.device_utils import get_available_device_id, is_qpc_size_gt_32gb, is_multi_qranium_setup_available from QEfficient.utils.run_utils import ApiRunner +def skip_if_mq_not_enabled(test_method): + """ + Wrapper function to skip test if MQ setup not enabled + """ + + @functools.wraps(test_method) + def wrapper(self): + if self.setup_info["qpc_gt_32gb"] and (not is_multi_qranium_setup_available()): + raise unittest.SkipTest("Skip because MQ set up not available") + + return test_method(self) + + return wrapper + def prepare_work_dir(work_dir): """ Function to create the work directory location @@ -68,8 +83,9 @@ def load_pytorch_model(model_name, model_class): repo_id=model_name, ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf"] ) model_hf = model_class.from_pretrained(model_path, use_cache=True) + params = sum(p.numel() for p in model_hf.parameters()) model_hf.eval() - return model_hf + return model_hf,params def transform_pt_model_with_qeff(model_hf): @@ -103,7 +119,7 @@ def export_onnx(model_kv, tokenizer, model_name, model_class): return base_path, onnx_model_path -def set_up(model_config): +def set_up(model_config, device_group=[0]): """ Set up function to set up the test environment for TestQEfficientModel class :param None @@ -115,8 +131,9 @@ def set_up(model_config): Constants.PROMPT_LEN, Constants.CTX_LEN, ) - - model_hf = load_pytorch_model(model_config["model_name"], model_config["model_class"]) + mxfp6 = False + model_hf, params = load_pytorch_model(model_config["model_name"], model_config["model_class"]) + qpc_gt_32gb = is_qpc_size_gt_32gb(params, mxfp6) try: pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) except Exception as e: @@ -149,7 +166,9 @@ def set_up(model_config): setup_info = {} setup_info["model_config"] = model_config + setup_info["device_group"] = device_group setup_info["api_runner"] = api_runner + setup_info["qpc_gt_32gb"] = qpc_gt_32gb setup_info["pytorch_hf_tokens"] = pytorch_hf_tokens setup_info["pytorch_kv_tokens"] = pytorch_kv_tokens setup_info["base_path"] = 
base_path @@ -175,7 +194,7 @@ def get_cloud_ai_100_tokens(setup_info): mxfp6=False, custom_io_path=os.path.join(setup_info["base_path"], "custom_io_fp16.yaml"), aic_enable_depth_first=False, - device_group=[0], + device_group=setup_info["device_group"], ) from QEfficient.generation.cloud_infer import QAICInferenceSession
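# --- Illustrative sketch (not part of this patch) ---------------------------
# Worked example of the is_qpc_size_gt_32gb() estimate the tests rely on:
# roughly 1 byte per parameter with MXFP6 and 2 bytes in FP16, with one device
# holding at most Constants.MAX_QPC_LIMIT (30) GB.  The ~46.7B parameter count
# for Mixtral-8x7B is an approximate figure, used only for illustration.
import math

GB, MAX_QPC_LIMIT = 2**30, 30
params = 46_700_000_000
for mxfp6 in (True, False):
    qpc_size = math.ceil(params * (1 if mxfp6 else 2) / GB)
    num_devices = math.ceil(qpc_size / MAX_QPC_LIMIT)
    print(mxfp6, qpc_size, num_devices)    # True  -> 44 GB, 2 devices
                                           # False -> 87 GB, 3 devices
# -----------------------------------------------------------------------------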