diff --git a/CHANGELOG.md b/CHANGELOG.md index ba073366b..e8cab1d8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Changed + +- Seperate start/end token check in `Seq2SeqDatasetReader` for source and target tokenizers. + ## [v2.7.0](https://github.com/allenai/allennlp-models/releases/tag/v2.7.0) - 2021-09-01 ### Added diff --git a/allennlp_models/generation/dataset_readers/seq2seq.py b/allennlp_models/generation/dataset_readers/seq2seq.py index 8e00188b8..4b0330ae8 100644 --- a/allennlp_models/generation/dataset_readers/seq2seq.py +++ b/allennlp_models/generation/dataset_readers/seq2seq.py @@ -103,22 +103,14 @@ def __init__( or target_add_start_token or target_add_end_token ): - # Check that the tokenizer correctly appends the start and end tokens to - # the sequence without splitting them. - tokens = self._source_tokenizer.tokenize(start_symbol + " " + end_symbol) - err_msg = ( - f"Bad start or end symbol ('{start_symbol}', '{end_symbol}') " - f"for tokenizer {self._source_tokenizer}" - ) - try: - start_token, end_token = tokens[0], tokens[-1] - except IndexError: - raise ValueError(err_msg) - if start_token.text != start_symbol or end_token.text != end_symbol: - raise ValueError(err_msg) - - self._start_token = start_token - self._end_token = end_token + if source_add_start_token or source_add_end_token: + self._check_start_end_tokens(start_symbol, end_symbol, self._source_tokenizer) + if ( + target_add_start_token or target_add_end_token + ) and self._target_tokenizer != self._source_tokenizer: + self._check_start_end_tokens(start_symbol, end_symbol, self._target_tokenizer) + self._start_token = Token(start_symbol) + self._end_token = Token(end_symbol) self._delimiter = delimiter self._source_max_tokens = source_max_tokens @@ -190,3 +182,22 @@ def apply_token_indexers(self, instance: Instance) -> None: instance.fields["source_tokens"]._token_indexers = self._source_token_indexers # type: ignore if "target_tokens" in instance.fields: instance.fields["target_tokens"]._token_indexers = self._target_token_indexers # type: ignore + + def _check_start_end_tokens( + self, start_symbol: str, end_symbol: str, tokenizer: Tokenizer + ) -> None: + """Check that `tokenizer` correctly appends `start_symbol` and `end_symbol` to the + sequence without splitting them. Raises a `ValueError` if this is not the case. + """ + + tokens = tokenizer.tokenize(start_symbol + " " + end_symbol) + err_msg = ( + f"Bad start or end symbol ('{start_symbol}', '{end_symbol}') " + f"for tokenizer {self._source_tokenizer}" + ) + try: + start_token, end_token = tokens[0], tokens[-1] + except IndexError: + raise ValueError(err_msg) + if start_token.text != start_symbol or end_token.text != end_symbol: + raise ValueError(err_msg)