diff --git a/FAQ.md b/FAQ.md
new file mode 100644
index 000000000..87ac67e13
--- /dev/null
+++ b/FAQ.md
@@ -0,0 +1,51 @@
+# FAQ
+## 1. The download.sh script doesn't work on default bash in macOS
+
+Please see the answers in these issues:
+ - https://github.com/facebookresearch/llama/issues/41#issuecomment-1451290160
+ - https://github.com/facebookresearch/llama/issues/53#issue-1606582963
+
+
+## 2. Generations are bad!
+
+Keep in mind these models are not finetuned for question answering. As such, they should be prompted so that the expected answer is the natural continuation of the prompt.
+
+Here are a few examples of prompts (from [issue#69](https://github.com/facebookresearch/llama/issues/69)) geared towards finetuned models, and how to modify them to get the expected results:
+ - Do not prompt with "What is the meaning of life? Be concise and do not repeat yourself." but with "I believe the meaning of life is"
+ - Do not prompt with "Explain the theory of relativity." but with "Simply put, the theory of relativity states that"
+ - Do not prompt with "Ten easy steps to build a website..." but with "Building a website can be done in 10 simple steps:\n"
+
+To be able to directly prompt the models with questions / instructions, you can either:
+ - Prompt them with few-shot examples so that the model understands the task you have in mind (see the sketch below).
+ - Finetune the models on datasets of instructions to make them more robust to input prompts.
+
+We've updated `example.py` with more sample prompts. Overall, always keep in mind that models are very sensitive to prompts (particularly when they have not been finetuned).
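+
+Here is a minimal sketch of the few-shot approach, assuming a `generator` already loaded as in `example.py` (the translation pairs are condensed from the sample prompts there):
+
+```python
+# Show the model the task format with a few examples, then let it continue the pattern.
+few_shot_prompt = """Translate English to French:
+sea otter => loutre de mer
+peppermint => menthe poivrée
+cheese =>"""
+
+# generator.generate is the same call used in example.py; a short max_gen_len suffices here.
+results = generator.generate([few_shot_prompt], max_gen_len=8, temperature=0.8, top_p=0.95)
+print(results[0])
+```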
+
+## 3. CUDA Out of memory errors
+
+The `example.py` file pre-allocates a cache according to these settings:
+```python
+model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params)
+```
+
+Accounting for 14GB of memory for the model weights (7B model), this leaves 16GB available on a 30GB GPU for the decoding cache, which stores 2 * 2 * n_layers * max_batch_size * max_seq_len * n_heads * head_dim bytes.
+
+With the previous default parameters (max_seq_len=1024, max_batch_size=32), this cache took about 17GB (2 * 2 * 32 * 32 * 1024 * 32 * 128 bytes) for the 7B model.
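+
+Here is a small sketch of that arithmetic (the 7B values n_layers=32, n_heads=32 and head_dim=128 are the ones used in the product above):
+
+```python
+def kv_cache_bytes(n_layers: int, max_batch_size: int, max_seq_len: int, n_heads: int, head_dim: int) -> int:
+    # 2 tensors (keys and values) * 2 bytes per fp16 value, for every layer / sequence / position / head / head dimension
+    return 2 * 2 * n_layers * max_batch_size * max_seq_len * n_heads * head_dim
+
+# Old default max_seq_len=1024 for the 7B model: ~17GB
+print(kv_cache_bytes(n_layers=32, max_batch_size=32, max_seq_len=1024, n_heads=32, head_dim=128) / 1e9)
+# New default max_seq_len=512 halves that to ~8.6GB
+print(kv_cache_bytes(n_layers=32, max_batch_size=32, max_seq_len=512, n_heads=32, head_dim=128) / 1e9)
+```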
+
+We've added command-line options to `example.py` and changed the default `max_seq_len` to 512, which should allow decoding on 30GB GPUs.
+
+Feel free to lower these settings according to your hardware.
+
+## 4. Other languages
+The model was trained primarily on English, but also on a few other languages that use the Latin or Cyrillic alphabets.
+
+For instance, LLaMA was trained on Wikipedia in the following 20 languages: bg, ca, cs, da, de, en, es, fr, hr, hu, it, nl, pl, pt, ro, ru, sl, sr, sv, uk.
+
+LLaMA's tokenizer splits unseen characters into UTF-8 bytes. As a result, it might also be able to process other languages, such as Chinese or Japanese, even though they use different characters.
+
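+A quick way to check this is to tokenize a string and count tokens: characters the tokenizer has not seen as whole pieces come out as several byte-level tokens. This is only a sketch, assuming the `Tokenizer` class exposed by this repository and a local `tokenizer.model` (the path below is a placeholder):
+
+```python
+from llama import Tokenizer
+
+tokenizer = Tokenizer(model_path="tokenizer.model")  # placeholder path to the downloaded tokenizer
+
+text = "祝你一天过得愉快"
+tokens = tokenizer.encode(text, bos=False, eos=False)
+# Unseen characters are split into UTF-8 bytes, so the token count can exceed the character count.
+print(len(text), len(tokens))
+```
+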
+Although the fraction of these languages in the training data was negligible, LLaMA still shows some ability at Chinese-to-English translation:
+
+```
+Prompt = "J'aime le chocolat = I like chocolate\n祝你一天过得愉快 ="
+Output = "I wish you a nice day"
+```
\ No newline at end of file
diff --git a/README.md b/README.md
index 6627c0b7b..0e311916f 100755
--- a/README.md
+++ b/README.md
@@ -32,6 +32,11 @@ Different models require different MP values:
| 33B | 4 |
| 65B | 8 |
+### FAQ
+- [1. The download.sh script doesn't work on default bash in macOS](FAQ.md#1-the-downloadsh-script-doesnt-work-on-default-bash-in-macos)
+- [2. Generations are bad!](FAQ.md#2-generations-are-bad)
+- [3. CUDA Out of memory errors](FAQ.md#3-cuda-out-of-memory-errors)
+- [4. Other languages](FAQ.md#4-other-languages)
### Model Card
See [MODEL_CARD.md](MODEL_CARD.md)
diff --git a/example.py b/example.py
index ab056c6d1..fba9a54a5 100755
--- a/example.py
+++ b/example.py
@@ -29,11 +29,18 @@ def setup_model_parallel() -> Tuple[int, int]:
return local_rank, world_size
-def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -> LLaMA:
+def load(
+ ckpt_dir: str,
+ tokenizer_path: str,
+ local_rank: int,
+ world_size: int,
+ max_seq_len: int,
+ max_batch_size: int,
+) -> LLaMA:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
- assert (
- world_size == len(checkpoints)
+ assert world_size == len(
+ checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
ckpt_path = checkpoints[local_rank]
print("Loading")
@@ -41,7 +48,9 @@ def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
- model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=32, **params)
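+    # max_seq_len and max_batch_size determine the size of the pre-allocated decoding cache (see FAQ.md)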
+ model_args: ModelArgs = ModelArgs(
+ max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+ )
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
@@ -54,14 +63,52 @@ def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -
return generator
-def main(ckpt_dir: str, tokenizer_path: str, temperature: float = 0.8, top_p: float = 0.95):
+def main(
+ ckpt_dir: str,
+ tokenizer_path: str,
+ temperature: float = 0.8,
+ top_p: float = 0.95,
+ max_seq_len: int = 512,
+ max_batch_size: int = 32,
+):
local_rank, world_size = setup_model_parallel()
if local_rank > 0:
- sys.stdout = open(os.devnull, 'w')
-
- generator = load(ckpt_dir, tokenizer_path, local_rank, world_size)
- prompts = ["The capital of Germany is the city of", "Here is my sonnet in the style of Shakespeare about an artificial intelligence:"]
- results = generator.generate(prompts, max_gen_len=256, temperature=temperature, top_p=top_p)
+ sys.stdout = open(os.devnull, "w")
+
+ generator = load(
+ ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
+ )
+
+ prompts = [
+ # For these prompts, the expected answer is the natural continuation of the prompt
+ "I believe the meaning of life is",
+ "Simply put, the theory of relativity states that ",
+ "Building a website can be done in 10 simple steps:\n",
+ # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
+ """Tweet: "I hate it when my phone battery dies."
+Sentiment: Negative
+###
+Tweet: "My day has been 👍"
+Sentiment: Positive
+###
+Tweet: "This is the link to the article"
+Sentiment: Neutral
+###
+Tweet: "This new music video was incredibile"
+Sentiment:""",
+ """Translate English to French:
+
+sea otter => loutre de mer
+
+peppermint => menthe poivrée
+
+plush girafe => girafe peluche
+
+cheese =>""",
+ ]
+ results = generator.generate(
+ prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+ )
for result in results:
print(result)