Make TransformerMC work with the transformer toolkit

* TransformerMC upgraded to using the transformer toolkit
* Changes multiple choice to use the new TransformerTextField
* Changelog
* Treat text as pairs, even if it cuts a sentence in half.
* Keep token type IDs
* Make type ids work when they are absent
* Dropout
* Old and new implementation now live side-by-side
* Make mypy happy
* Adds config for the transformer toolkit version of MC
* More boring random seed
* Remove duplicate code
* Formatting
* More formatting
* Removing leftover debug code
* Adds PIQA for Tango
* Formatting
* Renaming the dataset to DatasetDict
* We no longer need to say "produce_results"
* After the dataset rename, we have to bump the version
* Give the piqa config a debug switch
* New syntax
* Changelog
* skip testing tango configs for now

Co-authored-by: epwalsh <[email protected]>
Showing 5 changed files with 155 additions and 0 deletions.
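The "Treat text as pairs" bullet above is the core tokenization decision: each PIQA alternative is encoded as a (goal, solution) pair, so the tokenizer emits token type IDs for the two segments and truncation can cut into either half. A minimal sketch of that behavior with a stock Hugging Face tokenizer (the model name and example texts are illustrative, not from this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Two candidate solutions for one goal, encoded as (goal, solution) pairs.
# Illustrative strings; PIQA itself provides "goal", "sol1", and "sol2" fields.
pairs = [
    ("Make the room brighter.", "Open the curtains."),
    ("Make the room brighter.", "Paint the windows black."),
]
encoded = tokenizer.batch_encode_plus(
    pairs,
    add_special_tokens=True,
    truncation=True,  # "longest_first" by default, so a segment may be cut mid-sentence
    max_length=512,
    return_token_type_ids=True,
    return_attention_mask=False,
)
print(encoded["token_type_ids"][0])  # 0s over the goal, 1s over the solution
```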
@@ -0,0 +1 @@
```python
from allennlp_models.mc.tango.piqa import PiqaInstances
```
@@ -0,0 +1,82 @@
```python
from typing import Dict

import datasets
import torch
from allennlp.common import cached_transformers
from allennlp.data import Field, Instance, Vocabulary
from allennlp.data.fields import ListField, TransformerTextField, IndexField
from allennlp.tango.dataset import DatasetDict
from allennlp.tango.step import Step


@Step.register("piqa_instances")
class PiqaInstances(Step):
    DETERMINISTIC = True
    VERSION = "003"
    CACHEABLE = True
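    # Tango uses these attributes to decide whether a step's output can be cached
    # and reused; VERSION is bumped by hand whenever the output changes (this
    # commit bumps it after renaming the dataset class to DatasetDict).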

    def run(
        self,
        tokenizer_name: str,
        max_length: int = 512,
    ) -> DatasetDict:
        tokenizer = cached_transformers.get_tokenizer(tokenizer_name)
        assert tokenizer.pad_token_type_id == 0
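        # Token type IDs get padded with zeros downstream (TransformerTextField's
        # default), so this presumably guards against tokenizers where the pad
        # token's type ID is anything other than 0.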

        dataset = {
            split_name: [
                {
                    "correct_alternative": instance["label"],
                    "alternatives": [
                        (instance["goal"], instance["sol1"]),
                        (instance["goal"], instance["sol2"]),
                    ],
                }
                for instance in instances
            ]
            for split_name, instances in datasets.load_dataset("piqa").items()
        }

        # This thing is so complicated because we want to call `batch_encode_plus` with all
        # the strings at once.
        tokenized = {
            split_name: tokenizer.batch_encode_plus(
                [alternative for instance in instances for alternative in instance["alternatives"]],
                add_special_tokens=True,
                truncation=True,
                max_length=max_length,
                return_token_type_ids=True,
                return_attention_mask=False,
            )
            for split_name, instances in dataset.items()
        }
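        # batch_encode_plus flattened both alternatives of every instance into one
        # list, so instance i's two encodings live at indices 2*i and 2*i + 1 below.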

        result = {}
        for split_name, instances in dataset.items():
            tokenized_alts = tokenized[split_name]
            results_per_split = []
            for i, instance in enumerate(instances):
                alts = ListField(
                    [
                        TransformerTextField(
                            torch.tensor(tokenized_alts["input_ids"][alt_index], dtype=torch.int32),
                            torch.tensor(
                                tokenized_alts["token_type_ids"][alt_index], dtype=torch.int32
                            ),
                        )
                        for alt_index in [2 * i, 2 * i + 1]
                    ]
                )
                fields: Dict[str, Field] = {"alternatives": alts}
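                # The test split ships without gold labels (its "label" field is
                # negative in the Hugging Face release, as far as I can tell), so
                # only attach an answer index when one actually exists.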
if instance["correct_alternative"] >= 0: | ||
fields["correct_alternative"] = IndexField( | ||
instance["correct_alternative"], alts | ||
) | ||
results_per_split.append(Instance(fields)) | ||
result[split_name] = results_per_split | ||
|
||
# make vocab | ||
vocab = Vocabulary.empty() | ||
vocab.add_transformer_vocab(tokenizer, "tokens") | ||
|
||
return DatasetDict(result, vocab) |
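Each alternative above becomes a `TransformerTextField` built directly from tensors, with input IDs first and token type IDs second. A stripped-down sketch of the same construction (the ID values are made up for illustration):

```python
import torch
from allennlp.data.fields import TransformerTextField

# Made-up BERT-style IDs: [CLS] goal tokens [SEP] solution tokens [SEP].
# No attention mask is passed, matching return_attention_mask=False above;
# presumably padding and masking are handled when batches are assembled.
field = TransformerTextField(
    torch.tensor([101, 7592, 102, 2088, 102], dtype=torch.int32),
    torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32),
)
```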
@@ -0,0 +1,64 @@
```jsonnet
local transformer_model = "bert-base-cased";

local debug = true;

{
    "steps": {
        "dataset": {
            "type": "piqa_instances",
            "tokenizer_name": transformer_model
        },
        "trained_model": {
            "type": "training",
            "dataset": {"ref": "dataset"},
            "training_split": "train",
            [if !debug then "data_loader"]: {
                "batch_size": 32
            },
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
            "model": {
                "type": "transformer_mc_tt",
                "transformer_model": transformer_model,
            },
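            // "transformer_mc_tt" is the transformer-toolkit rewrite of the
            // multiple-choice model; per the commit message, it lives side by
            // side with the old implementation rather than replacing it.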
"optimizer": { | ||
"type": "huggingface_adamw", | ||
"weight_decay": 0.01, | ||
"parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]], | ||
"lr": 1e-5, | ||
"eps": 1e-8, | ||
"correct_bias": true | ||
}, | ||
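            // Standard BERT fine-tuning recipe: biases and layer-norm weights are
            // exempted from weight decay via the regexes in parameter_groups.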
"learning_rate_scheduler": { | ||
"type": "linear_with_warmup", | ||
"warmup_steps": 100 | ||
}, | ||
"num_epochs": if debug then 3 else 20, | ||
"patience": 3, | ||
"validation_metric": "+acc", | ||
}, | ||
"evaluation": { | ||
"type": "evaluation", | ||
"dataset": { | ||
"type": "ref", | ||
"ref": "dataset" | ||
}, | ||
"model": { | ||
"type": "ref", | ||
"ref": "trained_model" | ||
}, | ||
[if debug then "data_loader"]: { | ||
"type": "max_batches", | ||
"max_batches_per_epoch": 7, | ||
"inner": { | ||
"batch_size": 5 | ||
} | ||
}, | ||
} | ||
} | ||
} |
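The last commit bullet says CI skips testing tango configs for now. As a quick local check that the config at least evaluates, AllenNLP's `Params` loader understands Jsonnet; a small sketch (the file path is an assumption about where this config lives):

```python
from allennlp.common.params import Params

# Hypothetical path; point this at the config's actual location in the repo.
params = Params.from_file("training_config/tango/piqa.jsonnet")
print(list(params["steps"].keys()))  # expected: dataset, trained_model, evaluation
```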