This repository contains the research preview of LongLLaMA, a large language model capable of handling long contexts of 256k tokens or even more.
LongLLaMA is built upon the foundation of OpenLLaMA and fine-tuned using the Focused Transformer (FoT) method.
LongLLaMA Code is built upon the foundation of Code Llama.
We release a smaller 3B base variant (not instruction-tuned) of the LongLLaMA model under a permissive license (Apache 2.0), together with inference code supporting longer contexts, on Hugging Face. Our model weights can serve as a drop-in replacement for LLaMA in existing implementations (for short contexts of up to 2048 tokens). Additionally, we provide evaluation results and comparisons against the original OpenLLaMA models.
In addition to this, we release code for instruction tuning (PyTorch) and FoT continued pretraining (JAX).
Focused Transformer: Contrastive Training for Context Scaling (FoT) presents a simple method for endowing language models with the ability to handle context consisting possibly of millions of tokens while training on significantly shorter input. FoT permits a subset of attention layers to access a memory cache of (key, value) pairs to extend the context length. The distinctive aspect of FoT is its training procedure, drawing from contrastive learning. Specifically, we deliberately expose the memory attention layers to both relevant and irrelevant keys (like negative samples from unrelated documents). This strategy incentivizes the model to differentiate keys connected with semantically diverse values, thereby enhancing their structure. This, in turn, makes it possible to extrapolate the effective context length much beyond what is seen in training.
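To make the training-time exposure to relevant and irrelevant keys concrete, here is a minimal NumPy sketch of one memory-layer attention head. It assumes, as a simplification, that a document's memory during training is the previous window of every document in the batch: the document's own keys are the relevant entries and the other documents' keys act as negatives. The helper names (`crossbatch_memory`, `memory_layer_attention`) are hypothetical, positional encodings and causal masking are left out, and no training loss is shown.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row maximum for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def crossbatch_memory(prev_k: np.ndarray, prev_v: np.ndarray, doc: int):
    """Memory for document `doc` during training (hypothetical, simplified).

    prev_k, prev_v: (B, C, d) keys/values from the previous window of each
    of the B documents in the batch. The result holds the document's own
    keys first (relevant) followed by the other documents' keys
    (irrelevant, acting as negatives).
    """
    order = [doc] + [i for i in range(prev_k.shape[0]) if i != doc]
    mem_k = np.concatenate([prev_k[i] for i in order], axis=0)
    mem_v = np.concatenate([prev_v[i] for i in order], axis=0)
    return mem_k, mem_v

def memory_layer_attention(q, local_k, local_v, mem_k, mem_v):
    """One head of a memory layer: a single softmax over memory + local keys.

    q, local_k, local_v: (T, d) for the current window; mem_k, mem_v: (M, d).
    Causal masking of the local keys is omitted to keep the sketch short.
    """
    keys = np.concatenate([mem_k, local_k], axis=0)    # (M + T, d)
    values = np.concatenate([mem_v, local_v], axis=0)  # (M + T, d)
    scores = q @ keys.T / np.sqrt(q.shape[-1])         # (T, M + T)
    return softmax(scores) @ values                    # (T, d)

rng = np.random.default_rng(0)
B, C, T, d = 3, 4, 4, 8
prev_k, prev_v = rng.normal(size=(B, C, d)), rng.normal(size=(B, C, d))
mem_k, mem_v = crossbatch_memory(prev_k, prev_v, doc=1)
q, lk, lv = (rng.normal(size=(T, d)) for _ in range(3))
out = memory_layer_attention(q, lk, lv, mem_k, mem_v)
print(mem_k.shape, out.shape)  # (12, 8) (4, 8)
```

Because relevant and irrelevant keys compete in the same softmax, the layer only benefits from memory if its keys let it tell the two apart, which is the incentive the paragraph above describes.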
LongLLaMA is an OpenLLaMA model fine-tuned with the FoT method, with three layers used for context extension. Crucially, LongLLaMA is able to extrapolate far beyond the context length seen in training.
LongLLaMA Code is a Code Llama model finetuned with the FoT method.
| | LongLLaMA-3B | LongLLaMA-3Bv1.1 | LongLLaMA-Code 7B |
|---|---|---|---|
| Source model | OpenLLaMA-3B | OpenLLaMA-3Bv2 | CodeLLaMA-7b-hf |
| Source model tokens | 1T | 1T | 2T + 0.5T |
| Fine-tuning tokens | 10B | 5B | 35B |
| Memory layers | 6, 12, 18 | 6, 12, 18 | 8, 16, 24 |
In the fot_continued_pretraining subfolder, we provide the code that can be used to tune LLaMA models with FoT.
This code is written in JAX & Flax and based on EasyLM.
In the instruction_fine_tuning subfolder, we provide the code that was used to create LongLLaMA-Instruct-3Bv1.1, an instruction-tuned version of LongLLaMA-3Bv1.1. We used OpenOrca (instructions) and zetavg/ShareGPT-Processed (chat) datasets for tuning.
This code uses PyTorch and the Hugging Face Trainer.
In the src subfolder, we provide inference code for FoT models.
The code is written in PyTorch and based on the Hugging Face implementation of LLaMA.
The code should support the standard Hugging Face API. For more details, see the Usage section.
```sh
pip install --upgrade pip
pip install transformers==4.33.2 sentencepiece accelerate
```
```python
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b_v1_1")
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b_v1_1",
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
```
LongLLaMA uses the Hugging Face interface; a long input given to the model is split into context windows, which are loaded into the memory cache.
```python
prompt = "My name is Julien and I like to"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model(input_ids=input_ids)
```
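During the model call, one can provide the parameter `last_context_length`, which specifies the number of tokens left in the last context window. The same parameter can be passed to `generate`: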
```python
generation_output = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    num_beams=1,
    last_context_length=1792,
    do_sample=True,
    temperature=1.0,
)
print(tokenizer.decode(generation_output[0]))
```
LongLLaMA has several other parameters:
- `mem_layers` specifies layers endowed with memory (it should be either an empty list or a list of all memory layers specified in the description of the checkpoint).
- `mem_dtype` allows changing the data type of the memory cache.
- `mem_attention_grouping` can trade off speed for reduced memory usage. When it equals `(4, 2048)`, the memory layers will process at most $4 \cdot 2048$ queries at once ($4$ heads and $2048$ queries for each head).
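The following NumPy sketch illustrates the speed/memory trade-off behind `mem_attention_grouping`, under the assumption that grouping simply splits the attention computation into blocks of heads and queries. It is not the repository's PyTorch code, and `grouped_attention` is a hypothetical name. Smaller groups mean more loop iterations (slower) but a smaller temporary score array; the result is the same as unchunked attention.

```python
import numpy as np

def grouped_attention(q, k, v, grouping=(4, 2048)):
    """Multi-head attention computed chunk by chunk (illustrative sketch).

    q: (H, T, d); k, v: (H, K, d). At most grouping[0] heads and
    grouping[1] queries per head are scored at once, so the largest
    temporary score array has shape (grouping[0], grouping[1], K)
    instead of (H, T, K).
    """
    head_group, query_group = grouping
    H, T, d = q.shape
    out = np.empty((H, T, v.shape[-1]), dtype=np.result_type(q, v))
    for h0 in range(0, H, head_group):
        hs = slice(h0, h0 + head_group)
        for t0 in range(0, T, query_group):
            ts = slice(t0, t0 + query_group)
            # Scores for this block of heads and queries against all keys.
            scores = q[hs, ts] @ np.swapaxes(k[hs], -1, -2) / np.sqrt(d)
            scores -= scores.max(axis=-1, keepdims=True)
            weights = np.exp(scores)
            weights /= weights.sum(axis=-1, keepdims=True)
            out[hs, ts] = weights @ v[hs]
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(8, 10, 16)) for _ in range(3))
full = grouped_attention(q, k, v, grouping=(8, 10))    # one chunk
chunked = grouped_attention(q, k, v, grouping=(4, 3))  # 2 x 4 chunks
print(np.allclose(full, chunked))  # True
```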
```python
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b_v1_1")
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b_v1_1",
    torch_dtype=torch.float32,
    mem_layers=[],                     # empty list: no layers use memory
    mem_dtype="bfloat16",              # data type of the memory cache
    trust_remote_code=True,
    mem_attention_grouping=(4, 2048),  # at most 4 heads x 2048 queries at once
)
```
LongLLaMA checkpoints can also be used as a drop-in replacement for LLaMA checkpoints in the Hugging Face implementation of LLaMA, but in this case they will be limited to the original context length of 2048 tokens.
```python
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b_v1_1")
model = LlamaForCausalLM.from_pretrained("syzymon/long_llama_3b_v1_1", torch_dtype=torch.float32)
```
Long inputs are automatically split into context windows; the last window contains the number of tokens specified by `last_context_length`. The model processes the windows one by one, extending the memory cache after each. If `use_cache` is `True`, the last window is not loaded into the memory cache but into the local (generation) cache.
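The windowing described above can be sketched in plain Python. This is an illustration under assumptions: the helper `split_into_windows` and its `window_size` argument are hypothetical (the window length used by the released models is not given in this section), and the earlier windows are assumed to be filled front to back, so only the second-to-last window can be shorter.

```python
from typing import List, Sequence

def split_into_windows(tokens: Sequence[int], window_size: int,
                       last_context_length: int) -> List[List[int]]:
    """Split `tokens` into context windows (hypothetical helper).

    The last window holds the final `last_context_length` tokens; the
    remaining prefix is cut front to back into windows of at most
    `window_size` tokens.
    """
    if window_size <= 0 or last_context_length <= 0:
        raise ValueError("window sizes must be positive")
    tokens = list(tokens)
    split = max(0, len(tokens) - last_context_length)
    prefix, last = tokens[:split], tokens[split:]
    windows = [prefix[i:i + window_size] for i in range(0, len(prefix), window_size)]
    if last:
        windows.append(last)
    return windows

print(split_into_windows(list(range(10)), window_size=4, last_context_length=3))
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

In terms of the text above, with `use_cache=True` every window except the last would extend the memory cache, while the last one would go to the local (generation) cache.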
The memory cache stores (key, value) pairs for the memory layers specified by `mem_layers`. In addition to this, it stores attention masks.
If `use_cache=True` (which is the case in generation), LongLLaMA will use two caches: the memory cache for the specified layers and the local (generation) cache for all layers. When the local cache exceeds
For simplicity, context extension is realized in this repo with a memory cache and full attention. This simple mechanism could be replaced with a kNN search over an external database, using systems like Faiss, which could potentially enable further context length scaling. We leave this for future work.
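The difference between full attention over the memory cache and a kNN-style lookup can be illustrated with a small NumPy sketch in which each query attends only to its top-k memory keys. Brute-force scoring stands in for an approximate nearest-neighbor index such as Faiss, and the function names are hypothetical; this is not code from the repo. With k equal to the memory size, the result coincides with full attention.

```python
import numpy as np

def full_memory_attention(q, mem_k, mem_v):
    """Full attention over every (key, value) pair in the memory."""
    scores = q @ mem_k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ mem_v

def topk_memory_attention(q, mem_k, mem_v, k=4):
    """Attend only to the k highest-scoring memory keys of each query.

    q: (T, d); mem_k, mem_v: (M, d). A kNN index would replace the
    brute-force scoring and selection below.
    """
    k = min(k, mem_k.shape[0])
    scores = q @ mem_k.T / np.sqrt(q.shape[-1])               # (T, M)
    idx = np.argpartition(-scores, k - 1, axis=-1)[:, :k]     # (T, k) best keys
    top = np.take_along_axis(scores, idx, axis=-1)            # (T, k)
    weights = np.exp(top - top.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("tk,tkd->td", weights, mem_v[idx])       # (T, d)

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 8))
mem_k, mem_v = rng.normal(size=(32, 8)), rng.normal(size=(32, 8))
print(np.allclose(topk_memory_attention(q, mem_k, mem_v, k=32),
                  full_memory_attention(q, mem_k, mem_v)))  # True
```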
We present some illustrative examples of LongLLaMA results. Refer to our paper Focused Transformer: Contrastive Training for Context Scaling for more details.
We achieve good performance on the passkey retrieval task from Landmark Attention: Random-Access Infinite Context Length for Transformers. The code for generating the prompt and running the model is located in examples/passkey.py.
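For readers unfamiliar with the task, here is a minimal sketch of a passkey-retrieval prompt: a number is hidden inside long repetitive filler, and the model is asked to recall it. The filler sentences, the wording, and the helpers `make_passkey_prompt` and `check_answer` are assumptions modeled loosely on the Landmark Attention setup; they are not taken from examples/passkey.py.

```python
import re

# Assumed wording; the repo's script may use different strings.
TASK = ("There is an important info hidden inside a lot of irrelevant text. "
        "Find it and memorize it. I will quiz you about the important information there.")
FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."

def make_passkey_prompt(n_filler: int, passkey: int, position: float = 0.5) -> str:
    """Hide `passkey` at relative `position` among `n_filler` filler blocks."""
    if not 0.0 <= position <= 1.0:
        raise ValueError("position must be in [0, 1]")
    before = int(n_filler * position)
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key."
    parts = [TASK] + [FILLER] * before + [needle] + [FILLER] * (n_filler - before)
    parts.append("What is the pass key? The pass key is")
    return " ".join(parts)

def check_answer(generated: str, passkey: int) -> bool:
    """Count the answer as correct if the first number generated is the passkey."""
    match = re.search(r"\d+", generated)
    return match is not None and int(match.group()) == passkey

prompt = make_passkey_prompt(n_filler=10, passkey=12345, position=0.3)
print(check_answer(" 12345. Remember it.", 12345))  # True
```

Increasing `n_filler` lengthens the prompt, which is how the task probes retrieval at growing context lengths.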
Our LongLLaMA 3B model also shows improvements when using long context on two downstream tasks, TREC question classification and WebQS question answering.
| Context/Dataset | TREC | WebQS |
|---|---|---|
| | 67.0 | 21.2 |
| | 71.6 | 21.4 |
| | 72.9 | 22.2 |
| | 73.3 | 22.4 |
LongLLaMA retains performance on tasks that do not require long context.
In particular, LongLLaMA-Code 7B improves reasoning (GSM8K) and knowledge (MMLU) due to code fine-tuning.
We provide a comparison with OpenLLaMA on lm-evaluation-harness in the zero-shot setting.
Task/Metric | OpenLLaMA-3B | LongLLaMA-3B |
---|---|---|
anli_r1/acc | 0.33 | 0.32 |
anli_r2/acc | 0.32 | 0.33 |
anli_r3/acc | 0.35 | 0.35 |
arc_challenge/acc | 0.34 | 0.34 |
arc_challenge/acc_norm | 0.37 | 0.37 |
arc_easy/acc | 0.69 | 0.68 |
arc_easy/acc_norm | 0.65 | 0.63 |
boolq/acc | 0.68 | 0.68 |
hellaswag/acc | 0.49 | 0.48 |
hellaswag/acc_norm | 0.67 | 0.65 |
openbookqa/acc | 0.27 | 0.28 |
openbookqa/acc_norm | 0.40 | 0.38 |
piqa/acc | 0.75 | 0.73 |
piqa/acc_norm | 0.76 | 0.75 |
record/em | 0.88 | 0.87 |
record/f1 | 0.89 | 0.87 |
rte/acc | 0.58 | 0.60 |
truthfulqa_mc/mc1 | 0.22 | 0.24 |
truthfulqa_mc/mc2 | 0.35 | 0.38 |
wic/acc | 0.48 | 0.50 |
winogrande/acc | 0.62 | 0.60 |
Avg score | 0.53 | 0.53 |
Starting with the v1.1 models, we use the EleutherAI implementation of lm-evaluation-harness with a slight modification that adds the `<bos>` token at the beginning of the input sequence. The results are provided in the table below.
Task/Metric | LongLLaMA-3B | OpenLLaMA-3Bv2 | LongLLaMA-3Bv1.1 | LongLLaMA-Instruct-3Bv1.1 |
---|---|---|---|---|
anli_r1/acc | 0.32 | 0.33 | 0.31 | 0.33 |
anli_r2/acc | 0.33 | 0.35 | 0.33 | 0.35 |
anli_r3/acc | 0.35 | 0.38 | 0.35 | 0.38 |
arc_challenge/acc | 0.34 | 0.33 | 0.32 | 0.36 |
arc_challenge/acc_norm | 0.37 | 0.36 | 0.36 | 0.37 |
arc_easy/acc | 0.67 | 0.68 | 0.68 | 0.7 |
arc_easy/acc_norm | 0.63 | 0.63 | 0.63 | 0.63 |
boolq/acc | 0.68 | 0.67 | 0.66 | 0.77 |
hellaswag/acc | 0.48 | 0.53 | 0.52 | 0.52 |
hellaswag/acc_norm | 0.65 | 0.7 | 0.69 | 0.68 |
openbookqa/acc | 0.28 | 0.28 | 0.28 | 0.28 |
openbookqa/acc_norm | 0.38 | 0.39 | 0.37 | 0.41 |
piqa/acc | 0.73 | 0.77 | 0.77 | 0.78 |
piqa/acc_norm | 0.75 | 0.78 | 0.77 | 0.77 |
record/em | 0.87 | 0.87 | 0.86 | 0.85 |
record/f1 | 0.88 | 0.88 | 0.87 | 0.86 |
rte/acc | 0.6 | 0.53 | 0.62 | 0.7 |
truthfulqa_mc/mc1 | 0.24 | 0.22 | 0.21 | 0.25 |
truthfulqa_mc/mc2 | 0.38 | 0.35 | 0.35 | 0.4 |
wic/acc | 0.5 | 0.5 | 0.5 | 0.54 |
winogrande/acc | 0.6 | 0.66 | 0.63 | 0.65 |
Avg score | 0.53 | 0.53 | 0.53 | 0.55 |
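The `<bos>` modification used for the v1.1 evaluation amounts to making sure every tokenized input starts with the beginning-of-sequence token. The helper below is a hypothetical sketch of that effect on token ids, assuming the LLaMA tokenizer's BOS id of 1; it is not the actual patch to lm-evaluation-harness.

```python
from typing import List

LLAMA_BOS_ID = 1  # assumption: BOS token id of the LLaMA SentencePiece tokenizer

def ensure_bos(input_ids: List[int], bos_id: int = LLAMA_BOS_ID) -> List[int]:
    """Prepend the <bos> token unless the sequence already starts with it."""
    if input_ids and input_ids[0] == bos_id:
        return list(input_ids)
    return [bos_id] + list(input_ids)

print(ensure_bos([450, 4996]))     # [1, 450, 4996]
print(ensure_bos([1, 450, 4996]))  # [1, 450, 4996]
```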
We also provide results on HumanEval. We cut the generated text after any of `"\ndef "`, `"\nclass "`, or `"\nif __name__"`.
| | OpenLLaMA-3Bv2 | LongLLaMA-3Bv1.1 | LongLLaMA-Instruct-3Bv1.1 |
|---|---|---|---|
| pass@1 | 0.09 | 0.12 | 0.12 |
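The HumanEval post-processing and metric can be sketched as follows. Two assumptions are made: the cut is applied at the earliest stop sequence and the stop sequence itself is dropped, and pass@k is computed with the standard unbiased estimator introduced with HumanEval. The number of samples per problem behind the reported pass@1 is not stated here, and the function names are hypothetical.

```python
from math import comb

# Stop sequences from the text; cutting at the earliest one is an assumption.
STOP_SEQUENCES = ("\ndef ", "\nclass ", "\nif __name__")

def truncate_completion(text: str, stops=STOP_SEQUENCES) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence."""
    positions = [i for i in (text.find(s) for s in stops) if i != -1]
    return text[:min(positions)] if positions else text

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one problem: n samples, c of them correct."""
    if not (0 <= c <= n and 1 <= k <= n):
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset contains a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

def mean_pass_at_k(results, k: int = 1) -> float:
    """Average pass@k over problems; `results` is a list of (n, c) pairs."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)

sample = "    return a + b\n\ndef unrelated():\n    pass\n"
print(repr(truncate_completion(sample)))  # '    return a + b\n'
# With a single sample per problem, pass@1 is just the fraction solved.
print(mean_pass_at_k([(1, 1), (1, 0), (1, 0), (1, 0)]))  # 0.25
```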
To cite this work, please use:
```bibtex
@misc{tworkowski2023focused,
  title={Focused Transformer: Contrastive Training for Context Scaling},
  author={Szymon Tworkowski and Konrad Staniszewski and Mikołaj Pacek and Yuhuai Wu and Henryk Michalewski and Piotr Miłoś},
  year={2023},
  eprint={2307.03170},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
The source code and the base LongLLaMA 3B model checkpoints are licensed under the Apache License, Version 2.0.
The instruction/chat-tuned models are for research purposes only.
For LongLLaMA-Code 7B, see the codellama/CodeLlama-7b-hf license.
LongLLaMA-Code 7B Instruct is LongLLaMA-Code 7B tuned on the TIGER-Lab/MathInstruct, OpenOrca, and ShareGPT-Processed datasets.
Some of the examples use external code (see headers of files for copyright notices and licenses).
We gratefully acknowledge the TPU Research Cloud program, which was instrumental to our research by providing significant computational resources. We are also grateful to Xinyang Geng and Hao Liu for releasing OpenLLaMA checkpoints and the EasyLM library.
Special thanks to Keiran Paster for providing immensely valuable suggestions about the pre-training data for LongLLaMA-Code.
We would like to thank Xiaosong He for suggestions on how to improve the explanations of cross-batch code.