Repeat layers to create FrankenModels #275

Open · wants to merge 6 commits into base: master
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@ build/
__pycache__/
.idea
venv
dist
dist
*.so
85 changes: 56 additions & 29 deletions exllamav2/model.py
@@ -130,14 +130,14 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
self.modules.append(ExLlamaV2Embedding(self, "model.embed_tokens"))
self.modules_dict[self.modules[-1].key] = self.modules[-1]

for layer_idx in range(self.config.num_hidden_layers):
for layer_list in range(self.config.num_hidden_layers):

self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_idx}", layer_idx))
self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_list}", layer_list))
for m in self.modules[-1].submodules: self.modules_dict[m.key] = m
if self.config.architecture == "Mixtral":
self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_idx}", layer_idx))
self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_list}", layer_list))
else:
self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_idx}", layer_idx))
self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_list}", layer_list))
for m in self.modules[-1].submodules: self.modules_dict[m.key] = m


@@ -150,13 +150,32 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):

# Find last layer that affects k/v cache

layer_idx = len(self.modules)
layer_list = len(self.modules)
while True:
layer_idx -= 1
if isinstance(self.modules[layer_idx], ExLlamaV2Attention):
layer_list -= 1
if isinstance(self.modules[layer_list], ExLlamaV2Attention):
break

self.last_kv_layer_idx = layer_idx
self.last_kv_layer_idx = layer_list

if hasattr(config, 'repeats'):
embedTokenLayers = 1
transformerSublayers = 2
layer_arrangement = [list(range(*interval)) for interval in config.repeats]
layer_arrangement = [item for sublist in layer_arrangement for item in sublist]


LayeredModules = self.modules[:embedTokenLayers]
for idx in layer_arrangement:
LayeredModules += self.modules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
LayeredModules += self.modules[-2:]
self.head_layer_idx = len(self.modules) -1
self.last_kv_layer_idx = len(self.modules) -4

for i, m in enumerate(LayeredModules):
print(i, m.key)

self.layeredModules = LayeredModules


def set_device_map(self, allocation, embed_cpu = True):
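Note (not part of the diff): a minimal sketch of what the `config.repeats` block above computes, assuming a hypothetical repeats value of `[(0, 8), (4, 12)]` and the usual module layout of one embedding module followed by (attention, MLP) pairs and the final norm and head.

```python
# Sketch only -- values are hypothetical examples, not taken from the PR.
repeats = [(0, 8), (4, 12)]

# Each (start, stop) interval expands to range(start, stop); the lists are then flattened.
layer_arrangement = [i for interval in repeats for i in range(*interval)]
# -> [0, 1, ..., 7, 4, 5, ..., 11]; decoder layers 4-7 appear twice.

embed_token_layers = 1        # one ExLlamaV2Embedding module at the front
transformer_sublayers = 2     # each decoder layer contributes (attention, MLP)

def sublayer_slice(idx):
    # Position of decoder layer `idx` inside the flat self.modules list.
    start = idx * transformer_sublayers + embed_token_layers
    return slice(start, start + transformer_sublayers)

# LayeredModules = [embedding] + the (attention, MLP) pair of every index in
# layer_arrangement (duplicates included) + the final [norm, head] pair, which is
# what the loop above assembles before storing it as self.layeredModules.
```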
@@ -582,6 +601,23 @@ def _forward(self,
return_last_state = False,
position_offsets = None):

def process_module(module, x, last_state):
device = _torch_device(module.device_idx)

if idx == self.head_layer_idx:
if last_id_only and return_last_state:
x = x.narrow(-2, -1, 1)
last_state = x
elif last_id_only:
x = x.narrow(-2, -1, 1)
elif return_last_state:
last_state = x.narrow(-2, -1, 1)

x = safe_move_tensor(x, device)
x = module.forward(x, cache=cache, attn_params=attn_params, past_len=past_len, loras=loras)

return x, last_state

batch_size, seq_len = input_ids.shape
past_len = 0
if cache is not None:
@@ -596,27 +632,18 @@ def _forward(self,
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
last_state = None

for idx, module in enumerate(self.modules):

device = _torch_device(module.device_idx)

# Onward

if idx == self.head_layer_idx:
if last_id_only and return_last_state:
x = x.narrow(-2, -1, 1)
last_state = x
elif last_id_only:
x = x.narrow(-2, -1, 1)
elif return_last_state:
last_state = x.narrow(-2, -1, 1)

x = safe_move_tensor(x, device)
x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)

if preprocess_only and idx == self.last_kv_layer_idx:
x = None
break
if hasattr(self, 'layeredModules'):
for idx, module in enumerate(self.layeredModules):
x, last_state = process_module(module, x, last_state)
if preprocess_only and idx == self.last_kv_layer_idx:
x = None
break
else:
for idx, module in enumerate(self.modules):
x, last_state = process_module(module, x, last_state)
if preprocess_only and idx == self.last_kv_layer_idx:
x = None
break

# Advance cache

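Note (not part of the diff): the branch added to `_forward` only decides which module list to walk; a minimal equivalent sketch of that selection, reusing the names from the hunk above.

```python
# Sketch only: equivalent selection of the module list before the per-module loop.
modules = self.layeredModules if hasattr(self, "layeredModules") else self.modules

for idx, module in enumerate(modules):
    x, last_state = process_module(module, x, last_state)
    if preprocess_only and idx == self.last_kv_layer_idx:
        x = None
        break
```

When no `--repeats` value is given, `config` never gets a `repeats` attribute, `self.layeredModules` is never set, and the forward pass is unchanged.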
22 changes: 20 additions & 2 deletions exllamav2/model_init.py
@@ -1,5 +1,5 @@

import argparse, sys, os, glob
import argparse, sys, os, glob, ast

from exllamav2 import(
ExLlamaV2,
@@ -17,6 +17,7 @@ def add_args(parser):
parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")
parser.add_argument("-ept", "--experts_per_token", type = int, help = "Override MoE model's default number of experts per token")
parser.add_argument("--repeats", type=parse_tuple_list, help="List of tuples of the layers to repeat")


def print_options(args):
@@ -60,6 +61,22 @@ def check_args(args):
print(f" ## Error: Cannot find {filename} in {args.model_dir}")
sys.exit()

def parse_tuple_list(string):
try:
# Safely evaluate the string as a Python literal (list of tuples)
tuple_list = ast.literal_eval(string)

# Ensure all elements in the list are tuples
if not all(isinstance(item, tuple) for item in tuple_list):
raise ValueError("All elements must be tuples")

# Convert tuple elements to integers
int_tuple_list = [tuple(int(x) for x in item) for item in tuple_list]

return int_tuple_list
except:
raise argparse.ArgumentTypeError("Input must be a valid list of tuples with integer elements")


def init(args, quiet = False, allow_auto_split = False, skip_load = False):

@@ -76,7 +93,8 @@ def init(args, quiet = False, allow_auto_split = False, skip_load = False):
if args.rope_alpha: config.scale_alpha_value = args.rope_alpha
config.no_flash_attn = args.no_flash_attn
if args.experts_per_token: config.num_experts_per_token = args.experts_per_token

if args.repeats: config.repeats = args.repeats

# Set low-mem options

if args.low_mem: config.set_low_mem()
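Note (not part of the diff): how the new flag might be exercised; the script name, prompt flag, and model path below are placeholders.

```python
# Hypothetical invocation (placeholder paths; assumes the PR is applied):
#   python test_inference.py -m /models/example-exl2 -p "Hello" --repeats "[(0, 20), (10, 30)]"

from exllamav2.model_init import parse_tuple_list

repeats = parse_tuple_list("[(0, 20), (10, 30)]")
print(repeats)  # [(0, 20), (10, 30)]

# init() copies this onto config.repeats, so ExLlamaV2.__init__ stacks layers 0-19
# followed by layers 10-29, i.e. layers 10-19 are evaluated twice per forward pass.
```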