Commit cdff7b2: Dynamic frankenmerges

zpin committed Jan 13, 2024
1 parent 2dc8db8

Showing 4 changed files with 63 additions and 0 deletions.
57 changes: 57 additions & 0 deletions modules/exllamav2.py
@@ -1,5 +1,6 @@
import traceback
from pathlib import Path
import itertools

import torch
from exllamav2 import (
@@ -10,6 +11,7 @@
    ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
from exllamav2.attn import ExLlamaV2Attention

from modules import shared
from modules.logging_colors import logger
@@ -29,6 +31,20 @@
    logger.warning('Failed to load flash-attention due to the following error:\n')
    traceback.print_exc()

class ExLlamaV2AttentionWrapper(ExLlamaV2Attention):
    def __init__(self, obj, new_idx):
        object.__setattr__(self, '_obj', obj)
        object.__setattr__(self, '_new_idx', new_idx)

    def __getattribute__(self, name):
        if name == 'layer_idx':
            return object.__getattribute__(self, '_new_idx')

        # Delegate all other attributes to the wrapped object
        try:
            return getattr(object.__getattribute__(self, '_obj'), name)
        except AttributeError:
            return object.__getattribute__(self, name)

class Exllamav2Model:
    def __init__(self):
@@ -66,6 +82,8 @@ def from_pretrained(self, path_to_model):
        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

        result = self()
        result.orig_modules = model.modules
        result._stepstride = '00'
        result.model = model
        result.cache = cache
        result.tokenizer = tokenizer
@@ -92,6 +110,45 @@ def get_logits(self, token_ids, **kwargs):
        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()

    def generate_with_streaming(self, prompt, state):
        step = int(state.get('franken_step', 0))
        stride = int(state.get('franken_stride', 0))
        if not stride:
            stride = step * 2
        stepstride = f'{step}{stride}'
        if self._stepstride != stepstride:
            self._stepstride = stepstride
            if step:
                layer_arrangement = []

                num_layers = int((len(self.orig_modules) - 3) / 2)
                for i in range(int((num_layers - stride) / step) + 1):
                    layer_arrangement.append(range(i * step, i * step + stride))
                layers = list(itertools.chain(*layer_arrangement))

                print(f'Layers {num_layers} -> {len(layers)}, arrangement: {layer_arrangement}')

                # modules arrangement: [embedding, [...layers], rms-norm, head]
                # where each layer is [attention, mlp]
                self.model.modules = self.orig_modules[:1]
                for i, idx in enumerate(layers):
                    self.model.modules.append(ExLlamaV2AttentionWrapper(self.orig_modules[idx*2 + 1], i))
                    self.model.modules.append(self.orig_modules[idx*2 + 2])
                self.model.modules += self.orig_modules[-2:]
            else:
                self.model.modules = self.orig_modules
            num_layers = int((len(self.model.modules) - 3) / 2)
            self.model.head_layer_idx = len(self.model.modules) - 1
            self.model.config.num_hidden_layers = num_layers
            self.model.last_kv_layer_idx = len(self.model.modules) - 4
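            # The cache layout depends on the module list built above, so the old
            # cache and generator are dropped and re-created for the new layout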
            cache_class = type(self.cache)
            del self.generator
            del self.cache
            print('Re-creating cache')
            self.model.cache_map = {}
            self.model.set_cache_map()
            self.cache = cache_class(self.model)
            self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)

        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = state['temperature']
        settings.top_k = state['top_k']
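For illustration, here is a minimal, self-contained sketch of the two mechanisms this commit adds: the interleaved layer arrangement and the layer_idx re-indexing through attribute delegation. The DummyAttention class and the example values (num_layers=32, step=8, stride=16) are hypothetical stand-ins, not part of the commit.

import itertools

# Hypothetical stand-in for exllamav2's attention module; only layer_idx matters here.
class DummyAttention:
    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

# Same delegation pattern as ExLlamaV2AttentionWrapper above (AttributeError
# fallback omitted): report a new layer_idx, forward everything else.
class AttentionWrapper:
    def __init__(self, obj, new_idx):
        object.__setattr__(self, '_obj', obj)
        object.__setattr__(self, '_new_idx', new_idx)

    def __getattribute__(self, name):
        if name == 'layer_idx':
            return object.__getattribute__(self, '_new_idx')
        return getattr(object.__getattribute__(self, '_obj'), name)

# The arrangement arithmetic from generate_with_streaming, with example values.
num_layers, step, stride = 32, 8, 16
layer_arrangement = [range(i * step, i * step + stride)
                     for i in range(int((num_layers - stride) / step) + 1)]
layers = list(itertools.chain(*layer_arrangement))   # [0..15, 8..23, 16..31]

originals = [DummyAttention(i) for i in range(num_layers)]
stacked = [AttentionWrapper(originals[idx], i) for i, idx in enumerate(layers)]

print(len(stacked))           # 48 layers built from the original 32
print(stacked[16].layer_idx)  # 16, even though it wraps original layer 8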
2 changes: 2 additions & 0 deletions modules/loaders.py
@@ -220,6 +220,8 @@ def transformers_samplers():
        'custom_token_bans',
        'skip_special_tokens',
        'auto_max_new_tokens',
        'franken_step',
        'franken_stride',
    },
    'ExLlamav2_HF': {
        'temperature',
2 changes: 2 additions & 0 deletions modules/ui.py
@@ -152,6 +152,8 @@ def list_interface_input_elements():
        'stream',
        'tfs',
        'top_a',
        'franken_step',
        'franken_stride',
    ]

    # Chat elements
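Together, these loaders.py and ui.py additions register 'franken_step' and 'franken_stride' as interface inputs for the ExLlamav2 loader, which is how the two new fields reach the state dict that generate_with_streaming reads.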
2 changes: 2 additions & 0 deletions modules/ui_parameters.py
@@ -41,6 +41,8 @@ def create_ui(default_preset):
                shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a')
                shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff')
                shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff')
                shared.gradio['franken_step'] = gr.Number(value=generate_params.get('franken_step', 0), label='Franken Step', info='Cut every X layers when stacking. 0 to disable.')
                shared.gradio['franken_stride'] = gr.Number(value=generate_params.get('franken_stride', 0), label='Franken Stride', info='Cut length when stacking. 0 to use 2 * Franken Step.')

            with gr.Column():
                shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.')
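As a worked example of the two fields: Franken Step = 8 and Franken Stride = 16 on a 32-layer model yields the arrangement 0-15, 8-23, 16-31, i.e. a 48-layer dynamic self-merge.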
