From 63e5c3475f598689d6117812dd30ca175fefec95 Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Sat, 13 Jan 2024 11:06:31 +0100
Subject: [PATCH] fix

---
 exllamav2/model.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 4dc8a34a..a4346efa 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -163,19 +163,20 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
 
         transformerSublayers = 2
         layer_arrangement = [list(range(*interval)) for interval in config.repeats]
         layer_arrangement = [item for sublist in layer_arrangement for item in sublist]
-        LayeredModules = self.modules
-        self.modules = LayeredModules[:embedTokenLayers]
+        LayeredModules = self.modules[:embedTokenLayers]
         for idx in layer_arrangement:
-            self.modules += LayeredModules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
-        self.modules += LayeredModules[-2:]
+            LayeredModules += self.modules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
+        LayeredModules += self.modules[-2:]
 
         self.head_layer_idx = len(self.modules) -1
         self.last_kv_layer_idx = len(self.modules) -4
 
-        for i, m in enumerate(self.modules):
+        for i, m in enumerate(LayeredModules):
             print(i, m.key)
 
+        self.layeredModules = LayeredModules
+
 
 
     def set_device_map(self, allocation, embed_cpu = True):
@@ -631,9 +632,8 @@ def process_module(module, x, last_state):
 
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None
 
-        if hasattr(self, 'layers_list'):
-            for i, idx in enumerate(self.layers_list):
-                module = self.modules[idx]
+        if hasattr(self, 'layeredModules'):
+            for idx, module in enumerate(self.layeredModules):
                 x, last_state = process_module(module, x, last_state)
                 if preprocess_only and idx == self.last_kv_layer_idx: x = None
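
Note (illustration only, not part of the patch): the fix builds the repeated layer
arrangement in a separate LayeredModules list and stores it as self.layeredModules,
instead of overwriting self.modules as before. The sketch below is a minimal,
standalone walk-through of the same slicing logic, using made-up module names and a
toy four-layer model; embedTokenLayers = 1 is an assumption, while
transformerSublayers = 2 comes from the patch.

    embedTokenLayers = 1        # assumed: a single embedding module at the front
    transformerSublayers = 2    # as in the patch: attention + MLP per decoder layer

    # Toy stand-in for self.modules: embedding, 4 decoder layers (2 sublayers each),
    # then norm and head.
    modules = (["embed"]
               + [f"layer{i}.{part}" for i in range(4) for part in ("attn", "mlp")]
               + ["norm", "head"])

    # config.repeats-style intervals: layers 0-3 once, then layers 2-3 again.
    repeats = [(0, 4), (2, 4)]
    layer_arrangement = [i for interval in repeats for i in range(*interval)]

    # Same slicing as the patched __init__: keep the original module list intact
    # and accumulate the rearranged layers in a separate list.
    layered = modules[:embedTokenLayers]
    for idx in layer_arrangement:
        start = idx * transformerSublayers + embedTokenLayers
        layered += modules[start : start + transformerSublayers]
    layered += modules[-2:]

    print(layer_arrangement)   # [0, 1, 2, 3, 2, 3]
    print(layered)             # embed, layers 0-3, layers 2-3 again, norm, head

The patched forward pass then checks hasattr(self, 'layeredModules') and iterates
that list in place of self.modules, so repeated entries are references to the same
underlying module objects rather than copies.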