Add support for new MOE model mistralai/Mixtral-8x7B-v0.1 (#8)
* Add support for new MOE model mistralai/Mixtral-8x7B-v0.1

Signed-off-by: akuruvil <[email protected]>

* Update cache utils

Signed-off-by: akuruvil <[email protected]>

* Updated modeling files

Signed-off-by: akuruvil <[email protected]>

* Updated modeling files

Signed-off-by: akuruvil <[email protected]>

* RMSNorm moved to common utils file

Signed-off-by: akuruvil <[email protected]>

* Restructuring code

Signed-off-by: akuruvil <[email protected]>

* modeling file changes

Signed-off-by: akuruvil <[email protected]>

* Updating test files

Signed-off-by: akuruvil <[email protected]>

* Added logger warning

Signed-off-by: akuruvil <[email protected]>

* Updated utils

Signed-off-by: akuruvil <[email protected]>

* Update test_modeling_mixtral.py

Updated model card to mistralai/Mixtral-8x7B-Instruct-v0.1.

Signed-off-by: quic-amitraj <[email protected]>

---------

Signed-off-by: akuruvil <[email protected]>
Signed-off-by: quic-amitraj <[email protected]>
Co-authored-by: quic-amitraj <[email protected]>
quic-akuruvil and quic-amitraj authored May 28, 2024
1 parent c218129 commit 69ef228
Showing 10 changed files with 1,008 additions and 9 deletions.
62 changes: 62 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -0,0 +1,62 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


from typing import Any, Dict, List, Optional, Tuple
import torch
from transformers.cache_utils import (
    DynamicCache,
)


class QEffDynamicCache(DynamicCache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
"""

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
kv_indices = torch.arange(key_states.shape[2]) + cache_kwargs['cache_index']
self.key_cache[layer_idx][:,:, kv_indices] = key_states
self.value_cache[layer_idx][:,:, kv_indices] = value_states

return self.key_cache[layer_idx], self.value_cache[layer_idx]
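
A minimal usage sketch of the retained-state update flow above. The shapes, context length, and `cache_index` value are illustrative assumptions, not taken from this diff; the prefill is assumed to be padded to the full context length so that later scatter writes stay inside the retained buffer, and the pinned transformers version is assumed to expose a writable `DynamicCache.seen_tokens`.

# Illustrative sketch only; shapes, lengths, and the cache_index value are assumptions.
import torch

cache = QEffDynamicCache()
bsz, n_heads, head_dim, ctx_len = 1, 8, 128, 64

# First call for layer 0: nothing is cached yet, so the padded prefill states are appended as-is.
k0 = torch.randn(bsz, n_heads, ctx_len, head_dim)
v0 = torch.randn(bsz, n_heads, ctx_len, head_dim)
keys, values = cache.update(k0, v0, layer_idx=0)
print(keys.shape)  # torch.Size([1, 8, 64, 128])

# Decode step: scatter the single new token into position 10 of the retained buffer.
k1 = torch.randn(bsz, n_heads, 1, head_dim)
v1 = torch.randn(bsz, n_heads, 1, head_dim)
keys, values = cache.update(k1, v1, layer_idx=0, cache_kwargs={"cache_index": torch.tensor(10)})
print(keys.shape)  # still torch.Size([1, 8, 64, 128]); position 10 now holds the new token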


92 changes: 92 additions & 0 deletions QEfficient/transformers/modeling_outputs.py
@@ -179,3 +179,95 @@ class QEffCausalLMOutputWithPast(ModelOutput):
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask_RetainedState: Optional[torch.BoolTensor] = None


@dataclass
class QEffMoeModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally, if
            `config.is_encoder_decoder=True`, in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
            Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the
            auxiliary loss for Mixture of Experts models.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask_RetainedState: Optional[torch.BoolTensor] = None

@dataclass
class QEffMoeCausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs with mixture of experts.
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Auxiliary loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
            Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the
            auxiliary loss for Mixture of Experts models.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask_RetainedState: Optional[torch.BoolTensor] = None
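
The `router_logits` documented above exist so that callers can compute the Mixture-of-Experts auxiliary (load-balancing) loss. As a rough, generic illustration of that relationship only (this helper is not part of the commit; the expert count, top-k value, and shapes are assumptions in the style of Mixtral's top-2 routing):

# Generic load-balancing auxiliary loss sketch; not code from this commit.
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, num_experts: int = 8, top_k: int = 2):
    # router_logits: tuple of per-layer tensors of shape [batch * seq_len, num_experts]
    logits = torch.cat(router_logits, dim=0)                        # [layers * tokens, num_experts]
    routing_weights = F.softmax(logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()  # [tokens, top_k, num_experts]
    tokens_per_expert = expert_mask.mean(dim=(0, 1))                # fraction of assignments per expert
    router_prob_per_expert = routing_weights.mean(dim=0)            # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)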
33 changes: 33 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -32,6 +32,16 @@
    MistralRMSNorm,
    MistralRotaryEmbedding,
)
from transformers.models.mixtral.modeling_mixtral import (
    MixtralAttention,
    MixtralForCausalLM,
    MixtralModel,
    MixtralDecoderLayer,
    MixtralSparseMoeBlock,
    MixtralBLockSparseTop2MLP,
    MixtralRotaryEmbedding,
    MixtralRMSNorm,
)
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel

from QEfficient.customop import CustomRMSNormAIC
@@ -47,6 +57,8 @@
    QEffBaseModelOutputWithPastAndCrossAttentions,
    QEffCausalLMOutputWithCrossAttentions,
    QEffCausalLMOutputWithPast,
    QEffMoeCausalLMOutputWithPast,
    QEffMoeModelOutputWithPast,
)
from .models.codegen.modeling_codegen import (
    QEffCodeGenAttention,
@@ -68,6 +80,15 @@
    QEffMistralModel,
    QEffMistralRotaryEmbedding,
)
from .models.mixtral_moe.modeling_mixtral import (
    QEffMixtralModel,
    QEffMixtralRotaryEmbedding,
    QEffMixtralAttention,
    QEffMixtralForCausalLM,
    QEffMixtralDecoderLayer,
    QEffMixtralSparseMoeBlock,
    QEffMixtralBLockSparseTop2MLP,
)
from .models.mpt.modeling_mpt import QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel

# Define a named tuple for ModelArchitectures
@@ -81,6 +102,7 @@
        CodeGenForCausalLM.__name__,
        LlamaForCausalLM.__name__,
        MistralForCausalLM.__name__,
        MixtralForCausalLM.__name__,
    ]
)

@@ -115,6 +137,15 @@
    MistralForCausalLM: QEffMistralForCausalLM,
    MistralRotaryEmbedding: QEffMistralRotaryEmbedding,
    MistralRMSNorm: CustomRMSNormAIC,
    # Mixtral model layers
    MixtralAttention: QEffMixtralAttention,
    MixtralModel: QEffMixtralModel,
    MixtralDecoderLayer: QEffMixtralDecoderLayer,
    MixtralForCausalLM: QEffMixtralForCausalLM,
    MixtralRotaryEmbedding: QEffMixtralRotaryEmbedding,
    MixtralRMSNorm: CustomRMSNormAIC,
    MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
    MixtralBLockSparseTop2MLP: QEffMixtralBLockSparseTop2MLP,
}


@@ -186,6 +217,8 @@ def transform(model: nn.Module, form_factor: str = "cloud") -> nn.Module:
    transformers.modeling_outputs.CausalLMOutputWithCrossAttentions = QEffCausalLMOutputWithCrossAttentions
    transformers.modeling_outputs.BaseModelOutputWithPast = QEffBaseModelOutputWithPast
    transformers.modeling_outputs.CausalLMOutputWithPast = QEffCausalLMOutputWithPast
    transformers.modeling_outputs.MoeCausalLMOutputWithPast = QEffMoeCausalLMOutputWithPast
    transformers.modeling_outputs.MoeModelOutputWithPast = QEffMoeModelOutputWithPast

    # Replace the modeling attn util classes and functions
    transformers.modeling_attn_mask_utils.AttentionMaskConverter = QEffAttentionMaskConverter
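
The mapping above pairs each Hugging Face Mixtral module class with its QEfficient counterpart. A minimal sketch of how such a class mapping can be applied, assuming the common pattern of swapping a module's `__class__` in place so existing weights are kept while the QEfficient `forward()` takes over; the helper name below is hypothetical, and the repository's actual `transform()` may differ in detail:

# Hypothetical illustration of applying a class-mapping dictionary like the one above.
import torch.nn as nn

def swap_module_classes(model: nn.Module, mapping: dict) -> nn.Module:
    for module in model.modules():
        # Reassign the class so the module keeps its parameters but picks up the
        # QEfficient implementation (e.g. MixtralAttention -> QEffMixtralAttention).
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model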
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/mixtral_moe/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------