Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Adding ability to choose validation DatasetReader with "predict" (#3033)
Browse files Browse the repository at this point in the history
* Fixing name of default ATIS predictor

* Enabling using the validation DatasetReader with the 'predict' command

* Adding test case without any DatasetReader

* Adding a 'dataset_reader_to_load' option to 'Predictor.from_path'

* Removing unused import

* Changing validation to be the default dataset reader; Adding flag to override dataset reader choice

* Fixing documentation
  • Loading branch information
danieldeutsch authored and matt-gardner committed Jul 24, 2019
1 parent 30c4271 commit 417a757
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 10 deletions.
28 changes: 24 additions & 4 deletions allennlp/commands/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
usage: allennlp predict [-h] [--output-file OUTPUT_FILE]
[--weights-file WEIGHTS_FILE]
[--batch-size BATCH_SIZE] [--silent]
[--cuda-device CUDA_DEVICE] [--use-dataset-reader]
[--cuda-device CUDA_DEVICE]
[--use-dataset-reader]
[--dataset-reader-choice {train,validation}]
[-o OVERRIDES] [--predictor PREDICTOR]
[--include-package INCLUDE_PACKAGE]
archive_file input_file
Expand All @@ -31,7 +33,14 @@
--cuda-device CUDA_DEVICE
id of GPU to use (if any)
--use-dataset-reader Whether to use the dataset reader of the original
model to load Instances
model to load Instances. The validation dataset
reader will be used if it exists, otherwise it will
fall back to the train dataset reader. This
behavior can be overridden with the
--dataset-reader-choice flag.
--dataset-reader-choice {train,validation}
Indicates which model dataset reader to use if the
--use-dataset-reader flag is set.
-o OVERRIDES, --overrides OVERRIDES
a JSON structure used to override the experiment
configuration
Expand Down Expand Up @@ -76,7 +85,17 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar

subparser.add_argument('--use-dataset-reader',
action='store_true',
help='Whether to use the dataset reader of the original model to load Instances')
help='Whether to use the dataset reader of the original model to load Instances. '
'The validation dataset reader will be used if it exists, otherwise it will '
'fall back to the train dataset reader. This behavior can be overridden '
'with the --dataset-reader-choice flag.')

subparser.add_argument('--dataset-reader-choice',
type=str,
choices=['train', 'validation'],
default='validation',
help='Indicates which model dataset reader to use if the --use-dataset-reader '
'flag is set.')

subparser.add_argument('-o', '--overrides',
type=str,
Expand All @@ -98,7 +117,8 @@ def _get_predictor(args: argparse.Namespace) -> Predictor:
cuda_device=args.cuda_device,
overrides=args.overrides)

return Predictor.from_archive(archive, args.predictor)
return Predictor.from_archive(archive, args.predictor,
dataset_reader_to_load=args.dataset_reader_choice)


class _PredictManager:
Expand Down
23 changes: 17 additions & 6 deletions allennlp/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# a mapping from model `type` to the default Predictor for that type
DEFAULT_PREDICTORS = {
'atis_parser' : 'atis_parser',
'atis_parser' : 'atis-parser',
'basic_classifier': 'text_classifier',
'biaffine_parser': 'biaffine-dependency-parser',
'bidaf': 'machine-comprehension',
Expand Down Expand Up @@ -228,7 +228,8 @@ def _batch_json_to_instances(self, json_dicts: List[JsonDict]) -> List[Instance]
return instances

@classmethod
def from_path(cls, archive_path: str, predictor_name: str = None, cuda_device: int = -1) -> 'Predictor':
def from_path(cls, archive_path: str, predictor_name: str = None, cuda_device: int = -1,
dataset_reader_to_load: str = "validation") -> 'Predictor':
"""
Instantiate a :class:`Predictor` from an archive path.
Expand All @@ -245,19 +246,26 @@ def from_path(cls, archive_path: str, predictor_name: str = None, cuda_device: i
cuda_device: ``int``, optional (default=-1)
If `cuda_device` is >= 0, the model will be loaded onto the
corresponding GPU. Otherwise it will be loaded onto the CPU.
dataset_reader_to_load: ``str``, optional (default="validation")
Which dataset reader to load from the archive, either "train" or
"validation".
Returns
-------
A Predictor instance.
"""
return Predictor.from_archive(load_archive(archive_path, cuda_device=cuda_device), predictor_name)
return Predictor.from_archive(load_archive(archive_path, cuda_device=cuda_device), predictor_name,
dataset_reader_to_load=dataset_reader_to_load)

@classmethod
def from_archive(cls, archive: Archive, predictor_name: str = None) -> 'Predictor':
def from_archive(cls, archive: Archive, predictor_name: str = None,
dataset_reader_to_load: str = "validation") -> 'Predictor':
"""
Instantiate a :class:`Predictor` from an :class:`~allennlp.models.archival.Archive`;
that is, from the result of training a model. Optionally specify which `Predictor`
subclass; otherwise, the default one for the model will be used.
subclass; otherwise, the default one for the model will be used. Optionally specify
which :class:`DatasetReader` should be loaded; otherwise, the validation one will be used
if it exists followed by the training dataset reader.
"""
# Duplicate the config so that the config inside the archive doesn't get consumed
config = archive.config.duplicate()
Expand All @@ -269,7 +277,10 @@ def from_archive(cls, archive: Archive, predictor_name: str = None) -> 'Predicto
f"Please specify a predictor explicitly.")
predictor_name = DEFAULT_PREDICTORS[model_type]

dataset_reader_params = config["dataset_reader"]
if dataset_reader_to_load == "validation" and "validation_dataset_reader" in config:
dataset_reader_params = config["validation_dataset_reader"]
else:
dataset_reader_params = config["dataset_reader"]
dataset_reader = DatasetReader.from_params(dataset_reader_params)

model = archive.model
Expand Down
75 changes: 75 additions & 0 deletions allennlp/tests/commands/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def setUp(self):
self.bidaf_model_path = (self.FIXTURES_ROOT / "bidaf" /
"serialization" / "model.tar.gz")
self.bidaf_data_path = self.FIXTURES_ROOT / 'data' / 'squad.json'
self.atis_model_path = (self.FIXTURES_ROOT / "semantic_parsing" / "atis" /
"serialization" / "model.tar.gz")
self.atis_data_path = self.FIXTURES_ROOT / 'data' / 'atis' / 'sample.json'
self.tempdir = pathlib.Path(tempfile.mkdtemp())
self.infile = self.tempdir / "inputs.txt"
self.outfile = self.tempdir / "outputs.txt"
Expand Down Expand Up @@ -107,6 +110,78 @@ def test_using_dataset_reader_works_with_known_model(self):

shutil.rmtree(self.tempdir)

def test_uses_correct_dataset_reader(self):
    # pylint: disable=protected-access
    # The ATIS archive has both a training and validation ``DatasetReader``
    # with different values for ``keep_if_unparseable`` (``True`` for validation
    # and ``False`` for training). We create a new ``Predictor`` class that
    # outputs this value so we can test which ``DatasetReader`` was used.
    @Predictor.register('test-predictor')
    class _TestPredictor(Predictor):
        # pylint: disable=abstract-method
        def dump_line(self, outputs: JsonDict) -> str:
            # Expose which reader was loaded by echoing its ``_keep_if_unparseable`` flag.
            data = {'keep_if_unparseable': self._dataset_reader._keep_if_unparseable} # type: ignore
            return json.dumps(data) + '\n'

    # With only --use-dataset-reader set, the validation reader should be
    # chosen by default (keep_if_unparseable == True).
    sys.argv = ["run.py",  # executable
                "predict",  # command
                str(self.atis_model_path),
                str(self.atis_data_path),  # input_file
                "--output-file", str(self.outfile),
                "--silent",
                "--predictor", "test-predictor",
                "--use-dataset-reader"]
    main()
    assert os.path.exists(self.outfile)
    with open(self.outfile, 'r') as f:
        results = [json.loads(line) for line in f]
    assert results[0]['keep_if_unparseable'] is True

    # --use-dataset-reader with an explicit override to the train reader
    # (keep_if_unparseable == False).
    sys.argv = ["run.py",  # executable
                "predict",  # command
                str(self.atis_model_path),
                str(self.atis_data_path),  # input_file
                "--output-file", str(self.outfile),
                "--silent",
                "--predictor", "test-predictor",
                "--use-dataset-reader",
                "--dataset-reader-choice", "train"]
    main()
    assert os.path.exists(self.outfile)
    with open(self.outfile, 'r') as f:
        results = [json.loads(line) for line in f]
    assert results[0]['keep_if_unparseable'] is False

    # --use-dataset-reader with an explicit choice of the validation reader
    # (same outcome as the default, stated explicitly).
    sys.argv = ["run.py",  # executable
                "predict",  # command
                str(self.atis_model_path),
                str(self.atis_data_path),  # input_file
                "--output-file", str(self.outfile),
                "--silent",
                "--predictor", "test-predictor",
                "--use-dataset-reader",
                "--dataset-reader-choice", "validation"]
    main()
    assert os.path.exists(self.outfile)
    with open(self.outfile, 'r') as f:
        results = [json.loads(line) for line in f]
    assert results[0]['keep_if_unparseable'] is True

    # Without the --use-dataset-reader flag, prediction fails because the
    # JSON-loading logic is not implemented in the testing predictor.
    sys.argv = ["run.py",  # executable
                "predict",  # command
                str(self.atis_model_path),
                str(self.atis_data_path),  # input_file
                "--output-file", str(self.outfile),
                "--silent",
                "--predictor", "test-predictor"]
    with self.assertRaises(NotImplementedError):
        main()

def test_batch_prediction_works_with_known_model(self):
with open(self.infile, 'w') as f:
f.write("""{"passage": "the seahawks won the super bowl in 2016", """
Expand Down
16 changes: 16 additions & 0 deletions allennlp/tests/predictors/predictor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@ def test_from_archive_does_not_consume_params(self):
# If it consumes the params, this will raise an exception
Predictor.from_archive(archive, 'machine-comprehension')

def test_loads_correct_dataset_reader(self):
    # pylint: disable=protected-access
    # The ATIS archive bundles both a training and a validation ``DatasetReader``,
    # and they differ in the ``keep_if_unparseable`` argument (validation uses
    # ``True``, training uses ``False``), which lets us tell them apart.
    fixture = self.FIXTURES_ROOT / 'semantic_parsing' / 'atis' / 'serialization' / 'model.tar.gz'
    archive = load_archive(fixture)

    # Default behavior: the validation reader is loaded.
    default_predictor = Predictor.from_archive(archive, 'atis-parser')
    assert default_predictor._dataset_reader._keep_if_unparseable is True

    # Explicitly requesting the train reader.
    train_predictor = Predictor.from_archive(archive, 'atis-parser', dataset_reader_to_load='train')
    assert train_predictor._dataset_reader._keep_if_unparseable is False

    # Explicitly requesting the validation reader.
    validation_predictor = Predictor.from_archive(archive, 'atis-parser', dataset_reader_to_load='validation')
    assert validation_predictor._dataset_reader._keep_if_unparseable is True

def test_get_gradients(self):
inputs = {
"premise": "I always write unit tests",
Expand Down

0 comments on commit 417a757

Please sign in to comment.