-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for new MOE model mistralai/Mixtral-8x7B-v0.1 (#8)
* Add support for new MOE model mistralai/Mixtral-8x7B-v0.1 Signed-off-by: akuruvil <[email protected]> * Update cache utils Signed-off-by: akuruvil <[email protected]> * Updated modeling files Signed-off-by: akuruvil <[email protected]> * Updated modeling files Signed-off-by: akuruvil <[email protected]> * RMSNorm moved to common utils file Signed-off-by: akuruvil <[email protected]> * Restructuring code Signed-off-by: akuruvil <[email protected]> * modeling file changes Signed-off-by: akuruvil <[email protected]> * Updating test files Signed-off-by: akuruvil <[email protected]> * Added logger warning Signed-off-by: akuruvil <[email protected]> * Updated utils Signed-off-by: akuruvil <[email protected]> * Update test_modeling_mixtral.py Updated model card to mistralai/Mixtral-8x7B-Instruct-v0.1. Signed-off-by: quic-amitraj <[email protected]> --------- Signed-off-by: akuruvil <[email protected]> Signed-off-by: quic-amitraj <[email protected]> Co-authored-by: quic-amitraj <[email protected]>
- Loading branch information
1 parent
c218129
commit 69ef228
Showing
10 changed files
with
1,008 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- | ||
|
||
|
||
from typing import Any, Dict, List, Optional, Tuple | ||
import torch | ||
from transformers.cache_utils import ( | ||
DynamicCache | ||
) | ||
|
||
|
||
class QEffDynamicCache(DynamicCache): | ||
""" | ||
A cache that grows dynamically as more tokens are generated. This is the default for generative models. | ||
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is | ||
`[batch_size, num_heads, seq_len, head_dim]`. | ||
""" | ||
|
||
def update( | ||
self, | ||
key_states: torch.Tensor, | ||
value_states: torch.Tensor, | ||
layer_idx: int, | ||
cache_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. | ||
Parameters: | ||
key_states (`torch.Tensor`): | ||
The new key states to cache. | ||
value_states (`torch.Tensor`): | ||
The new value states to cache. | ||
layer_idx (`int`): | ||
The index of the layer to cache the states for. | ||
cache_kwargs (`Dict[str, Any]`, `optional`): | ||
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. | ||
Return: | ||
A tuple containing the updated key and value states. | ||
""" | ||
# Update the number of seen tokens | ||
if layer_idx == 0: | ||
self.seen_tokens += key_states.shape[-2] | ||
|
||
# Update the cache | ||
if len(self.key_cache) <= layer_idx: | ||
self.key_cache.append(key_states) | ||
self.value_cache.append(value_states) | ||
else: | ||
kv_indices = torch.arange(key_states.shape[2]) + cache_kwargs['cache_index'] | ||
self.key_cache[layer_idx][:,:, kv_indices] = key_states | ||
self.value_cache[layer_idx][:,:, kv_indices] = value_states | ||
|
||
return self.key_cache[layer_idx], self.value_cache[layer_idx] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- |
Oops, something went wrong.