diff --git a/download.sh b/download.sh
index db520dcfe..4936df6d9 100644
--- a/download.sh
+++ b/download.sh
@@ -1,9 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
-PRESIGNED_URL="" # replace with presigned url from email
-MODEL_SIZE="7B,13B,30B,65B" # edit this list with the model sizes you wish to download
-TARGET_FOLDER="" # where all files should end up
+PRESIGNED_URL="https://agi.gpt4.org/llama/LLaMA/*"
+
+MODEL_SIZE="7B,13B" #,30B,65B" # edit this list with the model sizes you wish to download
+TARGET_FOLDER="./" # where all files should end up
 
 declare -A N_SHARD_DICT
 
@@ -18,16 +19,14 @@ wget ${PRESIGNED_URL/'*'/"tokenizer_checklist.chk"} -O ${TARGET_FOLDER}"/tokeniz
 
 (cd ${TARGET_FOLDER} && md5sum -c tokenizer_checklist.chk)
 
-for i in ${MODEL_SIZE//,/ }
-do
-    echo "Downloading ${i}"
-    mkdir -p ${TARGET_FOLDER}"/${i}"
-    for s in $(seq -f "0%g" 0 ${N_SHARD_DICT[$i]})
-    do
-        wget ${PRESIGNED_URL/'*'/"${i}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${i}/consolidated.${s}.pth"
-    done
-    wget ${PRESIGNED_URL/'*'/"${i}/params.json"} -O ${TARGET_FOLDER}"/${i}/params.json"
-    wget ${PRESIGNED_URL/'*'/"${i}/checklist.chk"} -O ${TARGET_FOLDER}"/${i}/checklist.chk"
-    echo "Checking checksums"
-    (cd ${TARGET_FOLDER}"/${i}" && md5sum -c checklist.chk)
-done
\ No newline at end of file
+for i in ${MODEL_SIZE//,/ }; do
+    echo "Downloading ${i}"
+    mkdir -p ${TARGET_FOLDER}"/${i}"
+    for s in $(seq -f "0%g" 0 ${N_SHARD_DICT[$i]}); do
+        wget ${PRESIGNED_URL/'*'/"${i}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${i}/consolidated.${s}.pth"
+    done
+    wget ${PRESIGNED_URL/'*'/"${i}/params.json"} -O ${TARGET_FOLDER}"/${i}/params.json"
+    wget ${PRESIGNED_URL/'*'/"${i}/checklist.chk"} -O ${TARGET_FOLDER}"/${i}/checklist.chk"
+    echo "Checking checksums"
+    (cd ${TARGET_FOLDER}"/${i}" && md5sum -c checklist.chk)
+done
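The rewritten download loop is behaviorally the same as before: for each selected model size it substitutes the '*' in PRESIGNED_URL with the relative path of every shard file, params.json, and checklist.chk, then verifies the md5 checksums. Below is a minimal Python sketch of that URL construction; the helper name and the shard counts are illustrative assumptions, not values taken from the script.

# Sketch only: mirrors the ${PRESIGNED_URL/'*'/...} substitution in download.sh.
# N_SHARDS holds the highest shard index per model size, as N_SHARD_DICT does in the
# script; the values here are assumptions for illustration, not read from the diff.
N_SHARDS = {"7B": 0, "13B": 1}

def shard_urls(presigned_url, model_size):
    # seq -f "0%g" 0 N yields 00, 01, ..., 0N, i.e. N + 1 shard files
    paths = [f"{model_size}/consolidated.0{s}.pth" for s in range(N_SHARDS[model_size] + 1)]
    paths += [f"{model_size}/params.json", f"{model_size}/checklist.chk"]
    return [presigned_url.replace("*", p) for p in paths]

print(shard_urls("https://agi.gpt4.org/llama/LLaMA/*", "13B"))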
diff --git a/example.py b/example.py
index fba9a54a5..60ac5a464 100755
--- a/example.py
+++ b/example.py
@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from typing import Tuple
@@ -11,25 +11,27 @@ from pathlib import Path
 
-from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+from fairscale.nn.model_parallel.initialize import initialize_model_parallel
 
 from llama import ModelArgs, Transformer, Tokenizer, LLaMA
 
 
-def setup_model_parallel() -> Tuple[int, int]:
+def setup_model_parallel(seed: int) -> Tuple[int, int]:
     local_rank = int(os.environ.get("LOCAL_RANK", -1))
     world_size = int(os.environ.get("WORLD_SIZE", -1))
 
-    torch.distributed.init_process_group("nccl")
+    torch.distributed.init_process_group("nccl")
     initialize_model_parallel(world_size)
     torch.cuda.set_device(local_rank)
 
     # seed must be the same in all processes
-    torch.manual_seed(1)
+    torch.manual_seed(seed)
     return local_rank, world_size
 
 
+
 def load(
+
     ckpt_dir: str,
     tokenizer_path: str,
     local_rank: int,
@@ -41,7 +43,7 @@ def load(
     checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
     assert world_size == len(
         checkpoints
-    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
+    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
     ckpt_path = checkpoints[local_rank]
     print("Loading")
     checkpoint = torch.load(ckpt_path, map_location="cpu")
@@ -49,10 +51,10 @@ def load(
         params = json.loads(f.read())
 
     model_args: ModelArgs = ModelArgs(
-        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
     )
     tokenizer = Tokenizer(model_path=tokenizer_path)
-    model_args.vocab_size = tokenizer.n_words
+    model_args.vocab_size = tokenizer.n_words
     torch.set_default_tensor_type(torch.cuda.HalfTensor)
     model = Transformer(model_args)
     torch.set_default_tensor_type(torch.FloatTensor)
@@ -66,54 +68,103 @@ def load(
 def main(
     ckpt_dir: str,
     tokenizer_path: str,
-    temperature: float = 0.8,
-    top_p: float = 0.95,
+    temperature: float = 0.7,
+    # top_p: float = 0.95,
+    top_p: float = 0.0,
+    top_k: int = 40,
+    repetition_penalty: float = (1 / 0.85),
     max_seq_len: int = 512,
+    max_gen_len: int = 256,
     max_batch_size: int = 32,
-):
-    local_rank, world_size = setup_model_parallel()
-    if local_rank > 0:
-        sys.stdout = open(os.devnull, "w")
+    seed: int = 1,
+    count: int = 5,
+):
+    local_rank, world_size = setup_model_parallel(seed)
+    if local_rank > 0:
+        sys.stdout = open(os.devnull, "w")
 
-    generator = load(
-        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
-    )
-    prompts = [
+    generator = load(
+        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
+    )
+
+    prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt
-        "I believe the meaning of life is",
-        "Simply put, the theory of relativity states that ",
-        "Building a website can be done in 10 simple steps:\n",
-        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
-        """Tweet: "I hate it when my phone battery dies."
-Sentiment: Negative
-###
-Tweet: "My day has been 👍"
-Sentiment: Positive
-###
-Tweet: "This is the link to the article"
-Sentiment: Neutral
-###
-Tweet: "This new music video was incredibile"
-Sentiment:""",
-        """Translate English to French:
-
-sea otter => loutre de mer
-
-peppermint => menthe poivrée
-
-plush girafe => girafe peluche
-
-cheese =>""",
+
+        # "I believe the meaning of life is",
+        # "Simply put, the theory of relativity states that",
+        # "Building a website can be done in a few simple steps:\n1.",
+        # "Here's how to build it in a few simple steps:\n1.",
+
+        "This is Captain Jean-Luc Picard",
+        "I am Lieutenant Commander Data",
+        "The Klingons are attacking",
+
+# # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
+# """Tweet: "I hate it when my phone battery dies."
+# Sentiment: Negative
+# ###
+# Tweet: "My day has been 👍"
+# Sentiment: Positive
+# ###
+# Tweet: "This is the link to the article"
+
+# Sentiment: Neutral
+# ###
+# Tweet: "This new music video was incredibile"
+# Sentiment:""",
+
+# """Translate English to French:
+#
+# sea otter => loutre de mer
+#
+# peppermint => menthe poivrée
+#
+# plush girafe => girafe peluche
+#
+# cheese =>""",
+
     ]
-    results = generator.generate(
-        prompts, max_gen_len=256, temperature=temperature, top_p=top_p
-    )
-    for result in results:
-        print(result)
-        print("\n==================================\n")
+
+    i = 0
+    while i < count or count <= 0:
+        i += 1
+        for prompt in prompts:
+            print(f"\n============== sample {i} ==============\n")
+            width = 0
+            def callback(text):
+                nonlocal width
+                text = text.replace('\n', '\n\n')
+                chars = []
+                for i, c in enumerate(text):
+                    if c == ' ' and width >= 60:
+                        chars.append('\n')
+                        width = 0
+                    else:
+                        width += 1
+                        chars.append(c)
+                    if c == '\n':
+                        width = 0
+                text = ''.join(chars)
+                print(text, end='', flush=True)
+            text, = generator.generate(
+                [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, token_callback=callback,
+            )
+            print("\n")
+            print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
+            print(json.dumps(dict(
+                seed=seed,
+                temp=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                max_seq_len=max_seq_len,
+                max_gen_len=max_gen_len,
+            )))
+            print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
 
 
-if __name__ == "__main__":
-    fire.Fire(main)
+if __name__ == "__main__":
+    fire.Fire(main)
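In the new main(), generation streams through a token_callback that soft-wraps output at roughly 60 columns, and the outer while loop repeats every prompt count times (a non-positive count keeps looping, per `while i < count or count <= 0`). The sketch below factors the wrapping logic out of main() for clarity; the factory name make_wrap_callback is mine, not part of the patch.

# Sketch: same soft-wrapping behavior as the nested callback in example.py's main().
def make_wrap_callback(max_width=60):
    width = 0
    def callback(text):
        nonlocal width
        text = text.replace('\n', '\n\n')  # blank line between paragraphs
        chars = []
        for c in text:
            if c == ' ' and width >= max_width:
                chars.append('\n')  # break at a space once the line is long enough
                width = 0
            else:
                width += 1
                chars.append(c)
            if c == '\n':
                width = 0
        print(''.join(chars), end='', flush=True)
    return callback

As the llama/generation.py change below shows, generate() asserts len(prompts) == 1 whenever a callback is supplied, so the streaming path handles one prompt at a time.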
diff --git a/llama/generation.py b/llama/generation.py
index 3abd3edb1..ca3dfe32f 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -18,8 +18,11 @@ def generate(
         self,
         prompts: List[str],
         max_gen_len: int,
-        temperature: float = 0.8,
-        top_p: float = 0.95,
+        temperature: float = 0.7,
+        top_k: int = 40,
+        top_p: float = 0.0,  # 0.95,
+        repetition_penalty: float = (1.0 / 0.85),
+        token_callback=None,
     ) -> List[str]:
         bsz = len(prompts)
         params = self.model.params
@@ -27,7 +30,9 @@ def generate(
 
         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
 
+        min_prompt_size = min([len(t) for t in prompt_tokens])
+        max_prompt_size = max([len(t) for t in prompt_tokens])
 
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
@@ -38,11 +43,26 @@ def generate(
         input_text_mask = tokens != self.tokenizer.pad_id
         start_pos = min_prompt_size
         prev_pos = 0
+        prev_text = ''
         for cur_pos in range(start_pos, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+
+            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                logits_new = logits.clone()
+                batch_size = len(tokens)
+                for i in range(batch_size):
+                    for token in set(tokens[i].tolist()):
+                        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+                        if logits[i, token] < 0:
+                            logits_new[i, token] = logits[i, token] * repetition_penalty
+                        else:
+                            logits_new[i, token] = logits[i, token] / repetition_penalty
+                logits = logits_new
+
             if temperature > 0:
                 probs = torch.softmax(logits / temperature, dim=-1)
-                next_token = sample_top_p(probs, top_p)
+                next_token = sample(probs, top_p=top_p, top_k=top_k)
             else:
                 next_token = torch.argmax(logits, dim=-1)
             next_token = next_token.reshape(-1)
@@ -50,28 +70,46 @@ def generate(
             next_token = torch.where(
                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
             )
+            if next_token == self.tokenizer.eos_id:
+                break
             tokens[:, cur_pos] = next_token
+            if token_callback is not None:
+                assert len(prompts) == 1
+                text, = self.decode(tokens)
+                # assert text.startswith(prev_text)
+                if not text.startswith(prev_text):
+                    # Some kind of bogus token generation; abort early.
+                    break
+                next_word = text[len(prev_text):]
+                prev_text = text
+                token_callback(next_word)
             prev_pos = cur_pos
 
+        return self.decode(tokens)
+
+    def decode(self, tokens):
         decoded = []
         for i, t in enumerate(tokens.tolist()):
-            # cut to max gen len
-            t = t[: len(prompt_tokens[i]) + max_gen_len]
-            # cut to eos tok if any
-            try:
-                t = t[: t.index(self.tokenizer.eos_id)]
-            except ValueError:
-                pass
+            t = [token for token in t if token != -1]
+            # # cut to max gen len
+            # t = t[: len(prompt_tokens[i]) + max_gen_len]
+            while self.tokenizer.eos_id in t:
+                pos = t.index(self.tokenizer.eos_id)
+                t[pos:pos+1] = self.tokenizer.encode('\n<|endoftext|>\n', bos=False, eos=False)
             decoded.append(self.tokenizer.decode(t))
         return decoded
 
 
-def sample_top_p(probs, p):
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort[mask] = 0.0
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+def sample(probs, top_p=0.0, top_k=40):
+    if top_k > 0:
+        probs_sort, probs_idx = torch.topk(probs, top_k)
+    else:
+        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    if top_p > 0.0:
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        mask = probs_sum - probs_sort > top_p
+        probs_sort[mask] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
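Two remarks on the sampling changes. With the new defaults (top_p=0.0, top_k=40), sample() performs plain top-k sampling; setting top_k=0 and a positive top_p recovers the original nucleus sampling. The CTRL-style repetition penalty above walks the batch and the already-seen token ids in a Python double loop; the same rule can also be written with gather/scatter. The sketch below is my own vectorized restatement, not part of the patch, and it assumes the generated ids are valid non-negative indices of dtype long.

import torch

def apply_repetition_penalty(logits, generated, penalty):
    # logits: (batch, vocab) next-token scores; generated: (batch, seq) token ids seen so far.
    # Assumes `generated` holds no negative (pad) ids and has dtype torch.long.
    if penalty == 1.0:
        return logits
    score = torch.gather(logits, 1, generated)
    # CTRL rule: make already-seen tokens less likely by multiplying negative
    # scores by the penalty and dividing positive ones.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return logits.scatter(1, generated, score)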