
Commit

Add FAQ.md // add command line options
timlacroix committed Mar 3, 2023
1 parent 76066b1 commit e6145a0
Showing 3 changed files with 113 additions and 10 deletions.
51 changes: 51 additions & 0 deletions FAQ.md
@@ -0,0 +1,51 @@
# FAQ
## <a name="1"></a>1. The download.sh script doesn't work on default bash in MacOS X

Please see the answers in these issues:
- https://github.com/facebookresearch/llama/issues/41#issuecomment-1451290160
- https://github.com/facebookresearch/llama/issues/53#issue-1606582963


## <a name="2"></a>2. Generations are bad!

Keep in mind these models are not finetuned for question answering. As such, they should be prompted so that the expected answer is the natural continuation of the prompt.

Here are a few examples of prompts (from [issue #69](https://github.com/facebookresearch/llama/issues/69)) geared towards finetuned models, and how to modify them to get the expected results:
- Do not prompt with "What is the meaning of life? Be concise and do not repeat yourself." but with "I believe the meaning of life is"
- Do not prompt with "Explain the theory of relativity." but with "Simply put, the theory of relativity states that"
- Do not prompt with "Ten easy steps to build a website..." but with "Building a website can be done in 10 simple steps:\n"

To prompt the models directly with questions or instructions, you can either:
- Prompt it with few-shot examples so that the model understands the task you have in mind.
- Finetune the models on datasets of instructions to make them more robust to input prompts.

We've updated `example.py` with more sample prompts. Overall, always keep in mind that models are very sensitive to prompts (particularly when they have not been finetuned).
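
For illustration, here is a minimal sketch of the few-shot approach (the question and its few-shot examples are made up; `generator.generate` is the call used in `example.py`):

```python
# Instead of asking the question directly, prefix it with a few solved examples
# so that the answer is the natural continuation of the prompt.
few_shot_prompt = """Q: What is the capital of France?
A: Paris
###
Q: What is the capital of Germany?
A:"""

# `generator` is the LLaMA object returned by load() in example.py.
results = generator.generate([few_shot_prompt], max_gen_len=16, temperature=0.8, top_p=0.95)
print(results[0])
```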

## <a name="3"></a>3. CUDA Out of memory errors

The `example.py` file pre-allocates a cache according to these settings:
```python
model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params)
```

Accounting for 14GB of memory for the model weights (7B model in fp16), this leaves about 16GB on a 30GB GPU for the decoding cache, which stores 2 * 2 * n_layers * max_batch_size * max_seq_len * n_heads * head_dim bytes (one factor of 2 for the key and value tensors, the other for the 2 bytes per fp16 element).

With the previous defaults (max_batch_size=32, max_seq_len=1024), this cache was about 17GB (2 * 2 * 32 * 32 * 1024 * 32 * 128 bytes) for the 7B model.
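
As a quick sanity check, here is a minimal sketch of this computation (the 7B dimensions n_layers=32, n_heads=32, head_dim=128 are taken from the arithmetic above):

```python
def kv_cache_bytes(n_layers, max_batch_size, max_seq_len, n_heads, head_dim, bytes_per_elem=2):
    # 2 tensors (keys and values), each stored in fp16 (2 bytes per element)
    return 2 * bytes_per_elem * n_layers * max_batch_size * max_seq_len * n_heads * head_dim

# 7B model: n_layers=32, n_heads=32, head_dim=128
print(kv_cache_bytes(32, 32, 1024, 32, 128) / 1e9)  # ~17.2 GB with the old max_seq_len=1024
print(kv_cache_bytes(32, 32, 512, 32, 128) / 1e9)   # ~8.6 GB with the new default of 512
```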

We've added command line options to `example.py` and changed the default `max_seq_len` to 512, which should allow decoding on 30GB GPUs.

Feel free to lower these settings according to your hardware.
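
For example, a possible single-GPU invocation could look like the sketch below (the paths are placeholders, and the flags assume that `example.py` exposes the parameters of `main()` as command line options):

```
torchrun --nproc_per_node 1 example.py \
    --ckpt_dir <path_to_7B_checkpoints> \
    --tokenizer_path <path_to_tokenizer.model> \
    --max_seq_len 512 --max_batch_size 1
```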

## <a name="4"></a>4. Other languages
The model was trained primarily on English, but also on a few other languages with Latin or Cyrillic alphabets.

For instance, LLaMA was trained on Wikipedia in the following 20 languages: bg, ca, cs, da, de, en, es, fr, hr, hu, it, nl, pl, pt, ro, ru, sl, sr, sv, uk.

LLaMA's tokenizer splits unseen characters into UTF-8 bytes; as a result, it might also be able to process other languages, such as Chinese or Japanese, even though they use different characters.
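
Here is a minimal sketch of this byte fallback, assuming the repository's `Tokenizer` class and a downloaded `tokenizer.model`:

```python
from llama import Tokenizer  # assumes the tokenizer class shipped in this repository

tokenizer = Tokenizer(model_path="tokenizer.model")  # path is a placeholder
ids = tokenizer.encode("祝你一天过得愉快", bos=True, eos=False)
# Characters absent from the vocabulary are split into several UTF-8 byte tokens,
# so the number of tokens can exceed the number of characters.
print(len(ids), ids)
```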

Although the fraction of these languages in the training data was negligible, LLaMA still showcases some ability in Chinese-English translation:

```
Prompt = "J'aime le chocolat = I like chocolate\n祝你一天过得愉快 ="
Output = "I wish you a nice day"
```
5 changes: 5 additions & 0 deletions README.md
@@ -32,6 +32,11 @@ Different models require different MP values:
| 33B | 4 |
| 65B | 8 |

### FAQ
- [1. The download.sh script doesn't work on default bash in MacOS X](FAQ.md#1)
- [2. Generations are bad!](FAQ.md#2)
- [3. CUDA Out of memory errors](FAQ.md#3)
- [4. Other languages](FAQ.md#4)

### Model Card
See [MODEL_CARD.md](MODEL_CARD.md)
67 changes: 57 additions & 10 deletions example.py
@@ -29,19 +29,28 @@ def setup_model_parallel() -> Tuple[int, int]:
    return local_rank, world_size


def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -> LLaMA:
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert (
        world_size == len(checkpoints)
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=32, **params)
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
@@ -54,14 +54,63 @@ def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -> LLaMA:
    return generator


def main(ckpt_dir: str, tokenizer_path: str, temperature: float = 0.8, top_p: float = 0.95):
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
):
    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        sys.stdout = open(os.devnull, 'w')

    generator = load(ckpt_dir, tokenizer_path, local_rank, world_size)
    prompts = ["The capital of Germany is the city of", "Here is my sonnet in the style of Shakespeare about an artificial intelligence:"]
    results = generator.generate(prompts, max_gen_len=256, temperature=temperature, top_p=top_p)
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        "Building a website can be done in 10 simple steps:\n",
        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
        """Tweet: "I hate it when my phone battery dies."
Sentiment: Negative
###
Tweet: "My day has been 👍"
Sentiment: Positive
###
Tweet: "This is the link to the article"
Sentiment: Neutral
###
Tweet: "This new music video was incredibile"
Sentiment:""",
        """Translate English to French:
sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>""",
    ]
    results = generator.generate(
        prompts, max_gen_len=256, temperature=temperature, top_p=top_p
    )

    for result in results:
        print(result)
