From 31649f5776522ae60661499c595f73e1c019d72c Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Tue, 10 Aug 2021 16:25:02 -0700
Subject: [PATCH] PIQA in Tango (#294)

* Make TransformerMC work with the transformer toolkit
* TransformerMC upgraded to using the transformer toolkit
* Changes multiple choice to use the new TransformerTextField
* Changelog
* Treat text as pairs, even if it cuts a sentence in half.
* Keep token type IDs
* Make type ids work when they are absent
* Dropout
* Old and new implementation now live side-by-side
* Make mypy happy
* Adds config for the transformer toolkit version of MC
* More boring random seed
* Remove duplicate code
* Formatting
* More formatting
* Removing leftover debug code
* Adds PIQA for Tango
* Formatting
* Renaming the dataset to DatasetDict
* We no longer need to say "produce_results"
* After the dataset rename, we have to bump the version
* Give the piqa config a debug switch
* New syntax
* Changelog
* skip testing tango configs for now

Co-authored-by: epwalsh
---
 CHANGELOG.md                         |  1 +
 allennlp_models/mc/tango/__init__.py |  1 +
 allennlp_models/mc/tango/piqa.py     | 82 ++++++++++++++++++++++++++++
 tests/training_config_test.py        |  7 +++
 training_config/tango/piqa.jsonnet   | 64 ++++++++++++++++++++++
 5 files changed, 155 insertions(+)
 create mode 100644 allennlp_models/mc/tango/__init__.py
 create mode 100644 allennlp_models/mc/tango/piqa.py
 create mode 100644 training_config/tango/piqa.jsonnet

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a8e4612b8..c48b4fcb5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added some additional `__init__()` parameters to the `T5` model in `allennlp_models.generation` for customizing beam search and other options.
 - Added a configuration file for fine-tuning `t5-11b` on CCN-DM (requires at least 8 GPUs).
+- Added a configuration to train on the PIQA dataset with AllenNLP Tango.
 
 ### Fixed
 
diff --git a/allennlp_models/mc/tango/__init__.py b/allennlp_models/mc/tango/__init__.py
new file mode 100644
index 000000000..df4e85fa4
--- /dev/null
+++ b/allennlp_models/mc/tango/__init__.py
@@ -0,0 +1 @@
+from allennlp_models.mc.tango.piqa import PiqaInstances
diff --git a/allennlp_models/mc/tango/piqa.py b/allennlp_models/mc/tango/piqa.py
new file mode 100644
index 000000000..43011866c
--- /dev/null
+++ b/allennlp_models/mc/tango/piqa.py
@@ -0,0 +1,82 @@
+from typing import Dict
+
+import datasets
+import torch
+from allennlp.common import cached_transformers
+from allennlp.data import Field, Instance, Vocabulary
+from allennlp.data.fields import ListField, TransformerTextField, IndexField
+from allennlp.tango.dataset import DatasetDict
+from allennlp.tango.step import Step
+
+
+@Step.register("piqa_instances")
+class PiqaInstances(Step):
+    DETERMINISTIC = True
+    VERSION = "003"
+    CACHEABLE = True
+
+    def run(
+        self,
+        tokenizer_name: str,
+        max_length: int = 512,
+    ) -> DatasetDict:
+        tokenizer = cached_transformers.get_tokenizer(tokenizer_name)
+        assert tokenizer.pad_token_type_id == 0
+
+        dataset = {
+            split_name: [
+                {
+                    "correct_alternative": instance["label"],
+                    "alternatives": [
+                        (instance["goal"], instance["sol1"]),
+                        (instance["goal"], instance["sol2"]),
+                    ],
+                }
+                for instance in instances
+            ]
+            for split_name, instances in datasets.load_dataset("piqa").items()
+        }
+
+        # This thing is so complicated because we want to call `batch_encode_plus` with all
+        # the strings at once.
+        tokenized = {
+            split_name: tokenizer.batch_encode_plus(
+                [alternative for instance in instances for alternative in instance["alternatives"]],
+                add_special_tokens=True,
+                truncation=True,
+                max_length=max_length,
+                return_token_type_ids=True,
+                return_attention_mask=False,
+            )
+            for split_name, instances in dataset.items()
+        }
+
+        result = {}
+        for split_name, instances in dataset.items():
+            tokenized_alts = tokenized[split_name]
+            results_per_split = []
+            for i, instance in enumerate(instances):
+                alts = ListField(
+                    [
+                        TransformerTextField(
+                            torch.tensor(tokenized_alts["input_ids"][alt_index], dtype=torch.int32),
+                            torch.tensor(
+                                tokenized_alts["token_type_ids"][alt_index], dtype=torch.int32
+                            ),
+                        )
+                        for alt_index in [2 * i, 2 * i + 1]
+                    ]
+                )
+                fields: Dict[str, Field] = {"alternatives": alts}
+                if instance["correct_alternative"] >= 0:
+                    fields["correct_alternative"] = IndexField(
+                        instance["correct_alternative"], alts
+                    )
+                results_per_split.append(Instance(fields))
+            result[split_name] = results_per_split
+
+        # make vocab
+        vocab = Vocabulary.empty()
+        vocab.add_transformer_vocab(tokenizer, "tokens")
+
+        return DatasetDict(result, vocab)
diff --git a/tests/training_config_test.py b/tests/training_config_test.py
index 338a0026b..c73fa472a 100644
--- a/tests/training_config_test.py
+++ b/tests/training_config_test.py
@@ -18,9 +18,16 @@
     "constituency_parser_transformer_elmo.jsonnet",
 }
 
+FOLDERS_TO_IGNORE = {
+    # TODO (epwalsh/dirkg): need to test tango configs differently.
+    "tango",
+}
+
 
 def find_configs():
     for item in os.walk("training_config/"):
+        if os.path.basename(item[0]) in FOLDERS_TO_IGNORE:
+            continue
         for pattern in ("*.json", "*.jsonnet"):
             for path in glob(os.path.join(item[0], pattern)):
                 if os.path.basename(path) == "common.jsonnet":
diff --git a/training_config/tango/piqa.jsonnet b/training_config/tango/piqa.jsonnet
new file mode 100644
index 000000000..bbc8fb7df
--- /dev/null
+++ b/training_config/tango/piqa.jsonnet
@@ -0,0 +1,64 @@
+local transformer_model = "bert-base-cased";
+
+local debug = true;
+
+{
+    "steps": {
+        "dataset": {
+            "type": "piqa_instances",
+            "tokenizer_name": transformer_model
+        },
+        "trained_model": {
+            "type": "training",
+            "dataset": {"ref": "dataset"},
+            "training_split": "train",
+            [if !debug then "data_loader"]: {
+                "batch_size": 32
+            },
+            [if debug then "data_loader"]: {
+                "type": "max_batches",
+                "max_batches_per_epoch": 7,
+                "inner": {
+                    "batch_size": 5
+                }
+            },
+            "model": {
+                "type": "transformer_mc_tt",
+                "transformer_model": transformer_model,
+            },
+            "optimizer": {
+                "type": "huggingface_adamw",
+                "weight_decay": 0.01,
+                "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
+                "lr": 1e-5,
+                "eps": 1e-8,
+                "correct_bias": true
+            },
+            "learning_rate_scheduler": {
+                "type": "linear_with_warmup",
+                "warmup_steps": 100
+            },
+            "num_epochs": if debug then 3 else 20,
+            "patience": 3,
+            "validation_metric": "+acc",
+        },
+        "evaluation": {
+            "type": "evaluation",
+            "dataset": {
+                "type": "ref",
+                "ref": "dataset"
+            },
+            "model": {
+                "type": "ref",
+                "ref": "trained_model"
+            },
+            [if debug then "data_loader"]: {
+                "type": "max_batches",
+                "max_batches_per_epoch": 7,
+                "inner": {
+                    "batch_size": 5
+                }
+            },
+        }
+    }
+}
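
For reference, a minimal standalone sketch of the tokenization that the new `piqa_instances` step performs per split: every (goal, solution) pair is fed to `batch_encode_plus` in a single call, so the two alternatives for question i land at indices 2 * i and 2 * i + 1. The sketch is not part of the patch; it uses the plain `transformers` tokenizer API rather than `cached_transformers.get_tokenizer`, and the example pairs are invented for illustration.

    # Rough, illustrative sketch of what PiqaInstances.run() does for one split.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    # Two (goal, solution) alternatives per question, flattened into one list.
    alternatives = [
        ("To clean a whiteboard,", "wipe it with a dry eraser."),
        ("To clean a whiteboard,", "wipe it with wet cement."),
    ]

    tokenized = tokenizer.batch_encode_plus(
        alternatives,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
        return_attention_mask=False,
    )

    # The alternatives for question i sit at positions 2 * i and 2 * i + 1,
    # which is how the step builds its two TransformerTextFields per instance.
    first_question = (tokenized["input_ids"][0], tokenized["input_ids"][1])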