Add top_k, fix defaults, decode <|endoftext|>, stop at separators, generate prompts one at a time, print each token to stdout as it's generated, use different prompts, and like 20 other things. Committing this at the request of someone on Twitter.
shawwn committed Mar 6, 2023
1 parent e22bdb2 commit 40d99d3
Showing 2 changed files with 132 additions and 53 deletions.
116 changes: 80 additions & 36 deletions example.py
@@ -16,7 +16,7 @@
from llama import ModelArgs, Transformer, Tokenizer, LLaMA


def setup_model_parallel() -> Tuple[int, int]:
def setup_model_parallel(seed: int) -> Tuple[int, int]:
local_rank = int(os.environ.get("LOCAL_RANK", -1))
world_size = int(os.environ.get("WORLD_SIZE", -1))

@@ -25,7 +25,7 @@ def setup_model_parallel() -> Tuple[int, int]:
torch.cuda.set_device(local_rank)

# seed must be the same in all processes
torch.manual_seed(1)
torch.manual_seed(seed)
return local_rank, world_size


@@ -66,53 +66,97 @@ def load(
def main(
ckpt_dir: str,
tokenizer_path: str,
temperature: float = 0.8,
top_p: float = 0.95,
temperature: float = 0.7,
# top_p: float = 0.95,
top_p: float = 0.0,
top_k: int = 40,
repetition_penalty: float = (1 / 0.85),
max_seq_len: int = 512,
max_gen_len: int = 256,
max_batch_size: int = 32,
seed: int = 1,
count: int = 5,
):
local_rank, world_size = setup_model_parallel()
local_rank, world_size = setup_model_parallel(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")

print("\n")
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print(json.dumps(dict(
seed=seed,
temp=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
max_seq_len=max_seq_len,
max_gen_len=max_gen_len,
)))
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")


generator = load(
ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
)

prompts = [
# For these prompts, the expected answer is the natural continuation of the prompt
"I believe the meaning of life is",
"Simply put, the theory of relativity states that ",
"Building a website can be done in 10 simple steps:\n",
# Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
"""Tweet: "I hate it when my phone battery dies."
Sentiment: Negative
###
Tweet: "My day has been 👍"
Sentiment: Positive
###
Tweet: "This is the link to the article"
Sentiment: Neutral
###
Tweet: "This new music video was incredibile"
Sentiment:""",
"""Translate English to French:
sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>""",
]
results = generator.generate(
prompts, max_gen_len=256, temperature=temperature, top_p=top_p
)

for result in results:
print(result)
print("\n==================================\n")
# "I believe the meaning of life is",
# "Simply put, the theory of relativity states that",
# "Building a website can be done in a few simple steps:\n1.",
# "Here's how to build it in a few simple steps:\n1.",

"This is Captain Jean-Luc Picard",
"I am Lieutenant Commander Data",
"The Klingons are attacking",

# # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
# """Tweet: "I hate it when my phone battery dies."
# Sentiment: Negative
# ###
# Tweet: "My day has been 👍"
# Sentiment: Positive
# ###
# Tweet: "This is the link to the article"
# Sentiment: Neutral
# ###
# Tweet: "This new music video was incredibile"
# Sentiment:""",
# """Translate English to French:
#
# sea otter => loutre de mer
#
# peppermint => menthe poivrée
#
# plush girafe => girafe peluche
#
# cheese =>""",
]
i = 0
while i < count or count <= 0:
i += 1
for prompt in prompts:
print(f"\n============== sample {i} =================\n")
width = 0
def callback(text):
nonlocal width
text = text.replace('\n', '\n\n')
chars = []
for i, c in enumerate(text):
if c == ' ' and width >= 60:
chars.append('\n')
width = 0
else:
width += 1
chars.append(c)
if c == '\n':
width = 0
text = ''.join(chars)
print(text, end='', flush=True)
text, = generator.generate(
[prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, token_callback=callback,
)


if __name__ == "__main__":
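Below is a minimal, self-contained sketch of the word-wrapping behavior the new per-token callback in example.py implements, driven by a made-up list of text pieces instead of a live model. The make_callback and fake_stream names are illustrative and not part of the commit.

def make_callback(wrap_width=60):
    # Mirrors the closure in example.py: track the current line width and
    # break at a space once the line has reached wrap_width characters.
    width = 0
    def callback(text):
        nonlocal width
        text = text.replace('\n', '\n\n')
        chars = []
        for c in text:
            if c == ' ' and width >= wrap_width:
                chars.append('\n')
                width = 0
            else:
                width += 1
                chars.append(c)
            if c == '\n':
                width = 0
        print(''.join(chars), end='', flush=True)
    return callback

# Stand-in for the stream of decoded pieces that generator.generate()
# would pass to token_callback as tokens are produced.
fake_stream = ["This ", "is ", "Captain ", "Jean-Luc ", "Picard "] * 12
cb = make_callback()
for piece in fake_stream:
    cb(piece)
print()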
69 changes: 52 additions & 17 deletions llama/generation.py
@@ -18,8 +18,11 @@ def generate(
self,
prompts: List[str],
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
temperature: float = 0.7,
top_k: int = 40,
top_p: float = 0.0, #0.95,
repetition_penalty: float = (1.0 / 0.85),
token_callback=None,
) -> List[str]:
bsz = len(prompts)
params = self.model.params
@@ -38,40 +38,72 @@
input_text_mask = tokens != self.tokenizer.pad_id
start_pos = min_prompt_size
prev_pos = 0
prev_text = ''
for cur_pos in range(start_pos, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
logits_new = logits.clone()
batch_size = len(tokens)
for i in range(batch_size):
for token in set(tokens[i].tolist()):
# if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
if logits[i, token] < 0:
logits_new[i, token] = logits[i, token] * repetition_penalty
else:
logits_new[i, token] = logits[i, token] / repetition_penalty
logits = logits_new

if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
next_token = sample(probs, top_p=top_p, top_k=top_k)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
if next_token == self.tokenizer.eos_id:
break
tokens[:, cur_pos] = next_token
if token_callback is not None:
assert len(prompts) == 1
text, = self.decode(tokens)
#assert text.startswith(prev_text)
if not text.startswith(prev_text):
# Some kind of bogus token generation; abort early.
break
next_word = text[len(prev_text):]
prev_text = text
token_callback(next_word)
prev_pos = cur_pos

return self.decode(tokens)

def decode(self, tokens):
decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len
t = t[: len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
try:
t = t[: t.index(self.tokenizer.eos_id)]
except ValueError:
pass
t = [token for token in t if token != -1]
# # cut to max gen len
# t = t[: len(prompt_tokens[i]) + max_gen_len]
while self.tokenizer.eos_id in t:
pos = t.index(self.tokenizer.eos_id)
t[pos:pos+1] = self.tokenizer.encode('\n<|endoftext|>\n', bos=False, eos=False)
decoded.append(self.tokenizer.decode(t))
return decoded


def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
def sample(probs, top_p=0.0, top_k=40):
if top_k > 0:
probs_sort, probs_idx = torch.topk(probs, top_k)
else:
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
if top_p > 0.0:
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > top_p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
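For reference, here is a standalone toy run of the new sampling path on synthetic logits. The vocabulary size, seed, and previously generated token ids are invented for illustration, the repetition-penalty step mirrors the CTRL-style loop added to generate() above, and the sample() definition is copied from the diff so the snippet runs on its own.

import torch

def sample(probs, top_p=0.0, top_k=40):
    # Copied from the new sampler above: keep the top_k candidates,
    # optionally apply nucleus (top_p) filtering, then draw one token.
    if top_k > 0:
        probs_sort, probs_idx = torch.topk(probs, top_k)
    else:
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    if top_p > 0.0:
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

torch.manual_seed(0)
vocab_size = 10                      # toy vocabulary, for illustration only
logits = torch.randn(1, vocab_size)  # pretend model output for one sequence
already_generated = [3, 7]           # hypothetical token ids emitted so far

# CTRL-style repetition penalty (https://arxiv.org/abs/1909.05858): push
# previously seen tokens down, dividing positive logits and multiplying
# negative ones, as in the loop added to generate().
penalty = 1 / 0.85
for tok in set(already_generated):
    if logits[0, tok] < 0:
        logits[0, tok] = logits[0, tok] * penalty
    else:
        logits[0, tok] = logits[0, tok] / penalty

# top_k shrunk to fit the toy vocabulary; the commit's defaults are 0.7, 0.0, 40.
temperature, top_p, top_k = 0.7, 0.0, 4
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample(probs, top_p=top_p, top_k=top_k)
print(next_token)  # shape (1, 1) tensor holding the sampled token id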
