
Commit

fix
dnhkng committed Jan 13, 2024
1 parent a4e75d4 commit 63e5c34
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions exllamav2/model.py
@@ -163,19 +163,20 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
         transformerSublayers = 2
         layer_arrangement = [list(range(*interval)) for interval in config.repeats]
         layer_arrangement = [item for sublist in layer_arrangement for item in sublist]
-        LayeredModules = self.modules


-        self.modules = LayeredModules[:embedTokenLayers]
+        LayeredModules = self.modules[:embedTokenLayers]
         for idx in layer_arrangement:
-            self.modules += LayeredModules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
-        self.modules += LayeredModules[-2:]
+            LayeredModules += self.modules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
+        LayeredModules += self.modules[-2:]
         self.head_layer_idx = len(self.modules) -1
         self.last_kv_layer_idx = len(self.modules) -4

-        for i, m in enumerate(self.modules):
+        for i, m in enumerate(LayeredModules):
             print(i, m.key)

+        self.layeredModules = LayeredModules
+

     def set_device_map(self, allocation, embed_cpu = True):
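Note on the first hunk: the fix stops rewriting self.modules in place and instead collects the repeated-layer arrangement in a separate LayeredModules list, stored as self.layeredModules for the forward pass, so self.modules keeps describing the modules actually loaded. A minimal standalone sketch of that bookkeeping, with plain strings standing in for ExLlamaV2 modules and made-up values for config.repeats, embedTokenLayers and transformerSublayers (illustrative only, not the committed code):

# Sketch only: strings stand in for ExLlamaV2 modules; the values below are
# invented for illustration and do not come from the commit.
embed_token_layers = 1       # token embedding at the front of the module list
transformer_sublayers = 2    # each decoder layer contributes two modules (attn, mlp)

modules = ["embed"]
for layer in range(3):
    modules += [f"attn.{layer}", f"mlp.{layer}"]
modules += ["norm", "head"]  # final norm and LM head at the end

repeats = [(0, 2), (1, 3)]   # half-open layer intervals to stack, as with range()
layer_arrangement = [i for interval in repeats for i in range(*interval)]  # [0, 1, 1, 2]

layered_modules = modules[:embed_token_layers]
for idx in layer_arrangement:
    start = idx * transformer_sublayers + embed_token_layers
    layered_modules += modules[start : start + transformer_sublayers]
layered_modules += modules[-2:]

print(layered_modules)
# ['embed', 'attn.0', 'mlp.0', 'attn.1', 'mlp.1', 'attn.1', 'mlp.1',
#  'attn.2', 'mlp.2', 'norm', 'head']

Repeated entries are only duplicated references into the same list of loaded modules, so the underlying weights are shared; what changes is the order in which modules are executed.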

@@ -631,9 +632,8 @@ def process_module(module, x, last_state):
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None

-        if hasattr(self, 'layers_list'):
-            for i, idx in enumerate(self.layers_list):
-                module = self.modules[idx]
+        if hasattr(self, 'layeredModules'):
+            for idx, module in enumerate(self.layeredModules):
                 x, last_state = process_module(module, x, last_state)
                 if preprocess_only and idx == self.last_kv_layer_idx:
                     x = None
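Note on the second hunk: with the arrangement stored on the model, the forward pass only has to pick which list to walk, self.layeredModules when the model was built with config.repeats, the plain self.modules otherwise. A rough self-contained sketch of that dispatch (process_module is a stub here and the call signatures are simplified; only the attribute names layeredModules, modules and last_kv_layer_idx are taken from the diff):

from types import SimpleNamespace

def process_module(module, x, last_state):
    # Stub: a real ExLlamaV2 module would transform the hidden state here.
    return x, module

def run_forward(model, x, preprocess_only=False):
    # Equivalent of the hasattr check in the diff: walk the repeated-layer
    # list when it exists, otherwise fall back to the ordinary module list.
    modules = getattr(model, "layeredModules", model.modules)
    last_state = None
    for idx, module in enumerate(modules):
        x, last_state = process_module(module, x, last_state)
        if preprocess_only and idx == model.last_kv_layer_idx:
            x = None  # in preprocess-only mode the hidden state is no longer needed here
    return x, last_state

# Toy usage: this model has no layeredModules, so run_forward uses model.modules.
model = SimpleNamespace(modules=["embed", "layer.0", "head"], last_kv_layer_idx=1)
print(run_forward(model, x=0.0, preprocess_only=True))  # (None, 'head')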
