This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

PIQA in Tango (#294)
* Make TransformerMC work with the transformer toolkit

* TransformerMC upgraded to using the transformer toolkit

* Changes multiple choice to use the new TransformerTextField

* Changelog

* Treat text as pairs, even if it cuts a sentence in half.

* Keep token type IDs

* Make type ids work when they are absent

* Dropout

* Old and new implementation now live side-by-side

* Make mypy happy

* Adds config for the transformer toolkit version of MC

* More boring random seed

* Remove duplicate code

* Formatting

* More formatting

* Removing leftover debug code

* Adds PIQA for Tango

* Formatting

* Renaming the dataset to DatasetDict

* We no longer need to say "produce_results"

* After the dataset rename, we have to bump the version

* Give the piqa config a debug switch

* New syntax

* Changelog

* Skip testing tango configs for now

Co-authored-by: epwalsh <[email protected]>
dirkgr and epwalsh committed Aug 10, 2021
1 parent 4eb7c27 commit 31649f5
Showing 5 changed files with 155 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added some additional `__init__()` parameters to the `T5` model in `allennlp_models.generation` for customizing
beam search and other options.
- Added a configuration file for fine-tuning `t5-11b` on CNN-DM (requires at least 8 GPUs).
- Added a configuration to train on the PIQA dataset with AllenNLP Tango.

### Fixed

1 change: 1 addition & 0 deletions allennlp_models/mc/tango/__init__.py
@@ -0,0 +1 @@
from allennlp_models.mc.tango.piqa import PiqaInstances
82 changes: 82 additions & 0 deletions allennlp_models/mc/tango/piqa.py
@@ -0,0 +1,82 @@
from typing import Dict

import datasets
import torch
from allennlp.common import cached_transformers
from allennlp.data import Field, Instance, Vocabulary
from allennlp.data.fields import ListField, TransformerTextField, IndexField
from allennlp.tango.dataset import DatasetDict
from allennlp.tango.step import Step


@Step.register("piqa_instances")
class PiqaInstances(Step):
    DETERMINISTIC = True
    VERSION = "003"
    CACHEABLE = True
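    # Comment added for exposition, not part of the commit: these class attributes
    # are Tango step metadata. DETERMINISTIC and CACHEABLE let results be cached and
    # reused across runs, and VERSION is part of the cache key, so bumping it (as one
    # of the commits above did after the DatasetDict rename) invalidates stale caches.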

    def run(
        self,
        tokenizer_name: str,
        max_length: int = 512,
    ) -> DatasetDict:
        tokenizer = cached_transformers.get_tokenizer(tokenizer_name)
        assert tokenizer.pad_token_type_id == 0
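        # Comment added for exposition, not part of the commit: the assertion above
        # appears to guard the padding scheme, since token type ids are later padded
        # with zeros and the tokenizer's pad token type id must agree with that.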

        dataset = {
            split_name: [
                {
                    "correct_alternative": instance["label"],
                    "alternatives": [
                        (instance["goal"], instance["sol1"]),
                        (instance["goal"], instance["sol2"]),
                    ],
                }
                for instance in instances
            ]
            for split_name, instances in datasets.load_dataset("piqa").items()
        }

        # This thing is so complicated because we want to call `batch_encode_plus` with all
        # the strings at once.
        tokenized = {
            split_name: tokenizer.batch_encode_plus(
                [alternative for instance in instances for alternative in instance["alternatives"]],
                add_special_tokens=True,
                truncation=True,
                max_length=max_length,
                return_token_type_ids=True,
                return_attention_mask=False,
            )
            for split_name, instances in dataset.items()
        }

        result = {}
        for split_name, instances in dataset.items():
            tokenized_alts = tokenized[split_name]
            results_per_split = []
            for i, instance in enumerate(instances):
                alts = ListField(
                    [
                        TransformerTextField(
                            torch.tensor(tokenized_alts["input_ids"][alt_index], dtype=torch.int32),
                            torch.tensor(
                                tokenized_alts["token_type_ids"][alt_index], dtype=torch.int32
                            ),
                        )
                        for alt_index in [2 * i, 2 * i + 1]
                    ]
                )
                fields: Dict[str, Field] = {"alternatives": alts}
                if instance["correct_alternative"] >= 0:
                    fields["correct_alternative"] = IndexField(
                        instance["correct_alternative"], alts
                    )
                results_per_split.append(Instance(fields))
            result[split_name] = results_per_split

        # Make a vocabulary that carries the transformer tokenizer's vocab under the
        # "tokens" namespace.
        vocab = Vocabulary.empty()
        vocab.add_transformer_vocab(tokenizer, "tokens")

        return DatasetDict(result, vocab)
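
Aside (not part of the commit): the pairing-and-flattening trick above is easy to try in isolation. A minimal sketch, assuming only the `datasets` and `transformers` libraries are installed and PIQA can be downloaded:

# Standalone sketch of the same scheme: two (goal, solution) text pairs per
# question, flattened so one batch_encode_plus call tokenizes everything at once.
import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
validation = datasets.load_dataset("piqa")["validation"]

pairs = [(ex["goal"], sol) for ex in validation for sol in (ex["sol1"], ex["sol2"])]
encoded = tokenizer.batch_encode_plus(
    pairs,
    add_special_tokens=True,
    truncation=True,
    max_length=512,
    return_token_type_ids=True,
)

# Question i's two alternatives sit at flat indices 2*i and 2*i+1, which is
# exactly the indexing PiqaInstances uses when it rebuilds per-instance fields.
i = 0
print(tokenizer.decode(encoded["input_ids"][2 * i]))
print(tokenizer.decode(encoded["input_ids"][2 * i + 1]))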
7 changes: 7 additions & 0 deletions tests/training_config_test.py
@@ -18,9 +18,16 @@
"constituency_parser_transformer_elmo.jsonnet",
}

FOLDERS_TO_IGNORE = {
# TODO (epwalsh/dirkg): need to test tango configs differently.
"tango",
}


def find_configs():
for item in os.walk("training_config/"):
if os.path.basename(item[0]) in FOLDERS_TO_IGNORE:
continue
for pattern in ("*.json", "*.jsonnet"):
for path in glob(os.path.join(item[0], pattern)):
if os.path.basename(path) == "common.jsonnet":
64 changes: 64 additions & 0 deletions training_config/tango/piqa.jsonnet
@@ -0,0 +1,64 @@
local transformer_model = "bert-base-cased";

local debug = true;
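// Comment added for exposition, not part of the commit: `debug` drives the
// `[if debug then ...]` computed fields below. When true, the data loaders are
// capped at 7 batches of 5 instances and training runs for only 3 epochs, so the
// whole pipeline can be smoke-tested quickly; set it to false for a full run.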

{
    "steps": {
        "dataset": {
            "type": "piqa_instances",
            "tokenizer_name": transformer_model
        },
        "trained_model": {
            "type": "training",
            "dataset": {"ref": "dataset"},
            "training_split": "train",
            [if !debug then "data_loader"]: {
                "batch_size": 32
            },
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
            "model": {
                "type": "transformer_mc_tt",
                "transformer_model": transformer_model,
            },
            "optimizer": {
                "type": "huggingface_adamw",
                "weight_decay": 0.01,
                "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
                "lr": 1e-5,
                "eps": 1e-8,
                "correct_bias": true
            },
            "learning_rate_scheduler": {
                "type": "linear_with_warmup",
                "warmup_steps": 100
            },
            "num_epochs": if debug then 3 else 20,
            "patience": 3,
            "validation_metric": "+acc",
        },
        "evaluation": {
            "type": "evaluation",
            "dataset": {
                "type": "ref",
                "ref": "dataset"
            },
            "model": {
                "type": "ref",
                "ref": "trained_model"
            },
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
        }
    }
}
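
Aside (not part of the commit): a quick way to check that the config parses and to list its steps is AllenNLP's `Params`, which evaluates Jsonnet:

# Sanity-check the Jsonnet and see which steps it defines.
from allennlp.common import Params

params = Params.from_file("training_config/tango/piqa.jsonnet")
print(list(params["steps"].keys()))  # expected: ['dataset', 'trained_model', 'evaluation']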
