Commit: make it run
lucidrains committed Apr 7, 2023
1 parent 2a87eba commit 316a360
Showing 3 changed files with 120 additions and 14 deletions.
14 changes: 9 additions & 5 deletions README.md
@@ -71,20 +71,24 @@ toolformer = Toolformer(
     model_seq_len = 256,
     teach_tool_prompt = prompt,
     tool_id = 'Calendar',
-    tool = Calendar
+    tool = Calendar,
+    finetune = True
 )
 
 # invoking this will
 # (1) prompt the model with your inputs (data), inserted into [input] tag
 # (2) with the sampled outputs, filter out the ones that made proper API calls
 # (3) execute the API calls with the `tool` given
 # (4) filter with the specialized filter function (which can be used independently as shown in the next section)
+# (5) fine-tune on the filtered results
 
-filtered_results = toolformer(data)
+toolformer(data)
 
-# then finetune with token ids at
-# -> filtered_results.filtered_tokens_without_api_response
-# (5) complete this with toolformer.finetune(filtered_results) - and return all statistics
+# then, once you see the 'finished finetuning' message
+
+response = toolformer.sample_model_with_api_calls("How many days until the next new years?")
+
+# hopefully you see it invoke the calendar and utilize the response of the api call...
 
 ```
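For reference, the `Calendar` passed as `tool` above is defined earlier in the README as a plain Python function; a sketch along the lines of the Toolformer paper's calendar tool (this exact body is illustrative, not quoted from the repo):

```python
import datetime
from calendar import day_name, month_name

def Calendar():
    # return a natural-language date string the model can copy into its answer
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'
```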

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'toolformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.22',
+  version = '0.0.24',
   license='MIT',
   description = 'Toolformer - Pytorch',
   author = 'Phil Wang',
118 changes: 110 additions & 8 deletions toolformer_pytorch/toolformer_pytorch.py
@@ -238,7 +238,6 @@ def sample(
     select_api_start_id_top_k = 10,
 ):
     device = next(model.parameters()).device
-    positions = positions.clone()
     max_seq_len = seq_len + 1
 
     # validate
@@ -258,7 +257,11 @@
 
     # sampling positions - different sequences have different cursors
 
-    positions = default(positions, torch.zeros((batch_size,), device = device, dtype = torch.long))
+    if exists(positions):
+        positions = positions.clone()
+    else:
+        positions = torch.zeros((batch_size,), device = device, dtype = torch.long)
+
     assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), 'all positions must be less than the initial prime length as well as the total sequence length + 1 (plus one for noop if one sequence finished sampling before the other)'
 
     # eval model
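Together with the deletion in the previous hunk, this is the actual "make it run" fix: the old code called `positions.clone()` before `default(...)` could supply a fallback, so calling `sample` without `positions` crashed on `None`. A minimal sketch of the failure and the guard (tensor shapes illustrative):

```python
import torch

positions = None  # caller did not pass sampling positions

# old order: clone first, then default -> AttributeError on None
# positions = positions.clone()  # 'NoneType' object has no attribute 'clone'

# fixed order: clone only when a tensor was actually provided
if positions is not None:
    positions = positions.clone()
else:
    positions = torch.zeros((2,), dtype = torch.long)  # fallback, batch_size = 2
```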
@@ -516,7 +519,7 @@ def loss_fn(weight, probs):
     selected_indices = indices[selected_mask]
 
     ret = FilteredResults(
-        selected_mask.sum().item()
+        selected_mask.sum().item(),
         (~selected_mask).sum().item(),
         selected_indices,
         selected_mask,
@@ -563,6 +566,22 @@ def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
     collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
     return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)
 
+class FinetuneDataset(Dataset):
+    def __init__(
+        self,
+        tokens: torch.Tensor
+    ):
+        self.tokens = tokens
+
+    def __len__(self):
+        return len(self.tokens)
+
+    def __getitem__(self, idx):
+        return self.tokens[idx]
+
+def FinetuneDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
+    return DataLoader(ds, *args, collate_fn = partial(pad_sequence, padding_value = padding_value), **kwargs)
+
 # classes
 
 @beartype
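The new `FinetuneDataset` simply indexes into a tensor (or list) of token sequences, and `FinetuneDataloader` pads each batch to its longest member via `pad_sequence` (presumably a `batch_first = True` wrapper in this repo, since torch's default is time-first). A standalone sketch of the same pattern:

```python
import torch
from functools import partial
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# hypothetical variable-length token sequences
seqs = [torch.tensor([5, 2, 9, 4]), torch.tensor([7, 1])]

class TokenDataset(Dataset):
    def __init__(self, tokens):
        self.tokens = tokens
    def __len__(self):
        return len(self.tokens)
    def __getitem__(self, idx):
        return self.tokens[idx]

collate = partial(pad_sequence, batch_first = True, padding_value = 0)
dl = DataLoader(TokenDataset(seqs), batch_size = 2, collate_fn = collate)

print(next(iter(dl)))  # tensor([[5, 2, 9, 4], [7, 1, 0, 0]])
```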
@@ -585,8 +604,16 @@ def __init__(
         model_seq_len = 2048,
         tokenizer_encode: Callable = tokenizer.encode,
         tokenizer_decode: Callable = tokenizer.decode,
+        post_prompt_callback: Callable = identity,
         prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
-        exclude_filters: dict[str, Callable[[str], bool]] = dict()
+        exclude_filters: dict[str, Callable[[str], bool]] = dict(),
+        finetune = False,
+        finetune_lr = 1e-4,
+        finetune_wd = 1e-2,
+        finetune_betas = (0.9, 0.99),
+        finetune_eps = 1e-8,
+        finetune_epochs = 3,
+        finetune_batch_size = 16
     ):
         super().__init__()
         self.model = model
@@ -596,6 +623,8 @@ def __init__(
         self.prompt_batch_size = prompt_batch_size
         self.prompt_input_tag = prompt_input_tag
 
+        self.post_prompt_callback = post_prompt_callback # for easy mocking
+
         self.tokenizer_encode = tokenizer_encode
         self.tokenizer_decode = tokenizer_decode
         self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()
@@ -631,6 +660,22 @@ def __init__(
         self.teach_tool_prompt = teach_tool_prompt
         self.exclude_filters = exclude_filters
 
+        self.should_finetune = finetune
+
+        if not finetune:
+            return
+
+        self.finetune_batch_size = finetune_batch_size
+        self.finetune_epochs = finetune_epochs
+
+        self.optimizer = get_optimizer(
+            model.parameters(),
+            lr = finetune_lr,
+            wd = finetune_wd,
+            betas = finetune_betas,
+            eps = finetune_eps
+        )
+
     def generate_data_with_api_calls(
         self,
         data: List[str],
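`get_optimizer` is a helper from this repository whose body isn't shown in the diff; judging by the `lr` / `wd` / `betas` / `eps` arguments, it presumably constructs an Adam-family optimizer with weight decay. A rough stand-in under that assumption (not the repo's actual implementation):

```python
from torch.optim import Adam, AdamW

def get_optimizer(params, lr = 1e-4, wd = 1e-2, betas = (0.9, 0.99), eps = 1e-8):
    # assumption: decoupled weight decay via AdamW when wd > 0, plain Adam otherwise
    if wd > 0:
        return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
    return Adam(params, lr = lr, betas = betas, eps = eps)
```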
@@ -706,22 +751,46 @@ def filter_and_keep_only_first_api_call(
 
         return included, excluded
 
+    @torch.no_grad()
     def sample_model_with_api_calls(
         self,
-        prime: torch.Tensor,
+        prime: Union[torch.Tensor, str],
         occurrence = 1,
         **kwargs
     ):
+        self.model.eval()
+
+        prime_is_str = isinstance(prime, str)
+
+        if prime_is_str:
+            prime = self.tokenizer_encode(prime)
+            prime = torch.tensor(prime).long()
+            prime = rearrange(prime, 'n -> 1 n')
+
+        assert prime.shape[0] == 1, 'only one at a time for now'
+
+        invoke_tools_ = partial(invoke_tools, self.registry)
+
+        def call_apis(t: torch.Tensor):
+            t = self.tokenizer_decode(t[0])
+            t = invoke_tools_(t)
+            t = self.tokenizer_encode_to_tensor(t)
+            return rearrange(t, 'n -> 1 n')
+
         output = sample_with_api_call(
             model = self.model,
             prime = prime,
             seq_len = self.model_seq_len,
-            call_apis = partial(invoke_tools, self.registry),
+            call_apis = call_apis,
             api_end_token_id = self.api_stop_id,
             occurrence = occurrence,
             **kwargs
         )
 
-        return output
+        if not prime_is_str:
+            return output
+
+        return self.tokenizer_decode(output[0])
 
     def make_api_calls(
         self,
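With this change, `sample_model_with_api_calls` accepts either a token tensor or a raw string, and the return type mirrors the input: strings come back decoded, tensors come back as token ids. A hedged usage sketch (assumes a constructed `toolformer` instance):

```python
import torch

# string in -> decoded string out
response = toolformer.sample_model_with_api_calls("How many days until the next new years?")

# tensor in -> token ids out (batch of exactly one, per the assert above)
prime = torch.randint(0, 100, (1, 8))  # illustrative token ids
token_ids = toolformer.sample_model_with_api_calls(prime)
```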
@@ -764,17 +833,50 @@ def filter_by_api_responses(
 
         return filtered_results
 
+    def finetune(
+        self,
+        filtered_results: Union[FilteredResults, torch.Tensor]
+    ):
+        self.model.train()
+
+        if isinstance(filtered_results, FilteredResults):
+            filtered_results = filtered_results.filtered_tokens_without_api_response
+
+        dataset = FinetuneDataset(tokens = filtered_results)
+        dl = FinetuneDataloader(dataset, batch_size = self.finetune_batch_size, shuffle = True)
+
+        for epoch in tqdm(range(self.finetune_epochs), desc = 'finetune epochs'):
+            for batch in dl:
+                inp, labels = batch[:, :-1], batch[:, 1:]
+
+                logits = self.model(inp)
+                logits = rearrange(logits, 'b n c -> b c n')
+
+                loss = F.cross_entropy(logits, labels, ignore_index = self.pad_id)
+                loss.backward()
+
+                print(f'loss: {loss.item()}')
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+        print(f'finished finetuning on {len(dataset)} filtered samples')
+
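The loop above is plain next-token prediction on the filtered sequences: inputs are `batch[:, :-1]`, labels the one-step shift `batch[:, 1:]`, logits rearranged to `(batch, classes, seq)` as `F.cross_entropy` expects, with pad positions ignored. A self-contained sketch of that shift-and-loss step (the embedding "model" is purely illustrative):

```python
import torch
import torch.nn.functional as F

pad_id = 0
batch = torch.tensor([[5, 2, 9, 4], [7, 1, 0, 0]])  # padded token ids
inp, labels = batch[:, :-1], batch[:, 1:]           # predict token t+1 from tokens <= t

vocab = 16
model = torch.nn.Embedding(vocab, vocab)            # stand-in model: (b, n) -> (b, n, vocab)

logits = model(inp)                                 # (b, n, c)
logits = logits.transpose(1, 2)                     # (b, c, n), the layout cross_entropy expects
loss = F.cross_entropy(logits, labels, ignore_index = pad_id)
loss.backward()
```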

     def forward(
         self,
         data: List[str]
     ):
         data_with_api_calls = self.generate_data_with_api_calls(data)
 
+        data_with_api_calls = self.post_prompt_callback(data_with_api_calls)
+
         filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)
 
         assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
 
         data_with_responses = self.make_api_calls(filtered_data_with_api_calls)
         filtered_results = self.filter_by_api_responses(filtered_data, filtered_data_with_api_calls, data_with_responses)
 
-        return filtered_results
+        if not self.should_finetune:
+            return filtered_results
+
+        self.finetune(filtered_results)
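Taken end to end: when `finetune = True`, `forward` runs steps (1) through (5) and finishes inside `self.finetune(...)`, so the call returns `None` and the fine-tuned model is inspected by sampling afterwards, exactly as the updated README shows. A recap sketch (assumes `model`, `prompt`, `data`, and the `Calendar` tool are defined as in the README):

```python
toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)

toolformer(data)  # generate, filter, execute tools, fine-tune; returns None when finetune = True

response = toolformer.sample_model_with_api_calls("How many days until the next new years?")
print(response)
```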
