From b72c722cb9ac97e47c6a93726334fc781634698c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Apr 2022 12:27:38 +0000 Subject: [PATCH 01/75] WIP --- python/ray/data/dataset.py | 22 +- .../huggingface/huggingface_trainer.py | 200 ++++++++++++++++-- python/ray/ml/utils/torch_utils.py | 17 +- 3 files changed, 213 insertions(+), 26 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index c5c8a4bc50b1..33c41a9138a5 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -26,6 +26,7 @@ import ray.util.sgd import torch import tensorflow as tf + import torch.utils.data from ray.data.dataset_pipeline import DatasetPipeline from ray.data.grouped_dataset import GroupedDataset @@ -298,7 +299,8 @@ def transform(block: Block) -> Iterable[Block]: ): raise ValueError( "The map batches UDF returned the value " - f"{applied}, which is not allowed. " + f"{applied} of type {type(applied)}, " + "which is not allowed. " "The return type must be either list, " "pandas.DataFrame, or pyarrow.Table" ) @@ -2066,6 +2068,7 @@ def to_torch( prefetch_blocks: int = 0, drop_last: bool = False, unsqueeze_label_tensor: bool = True, + unsqueeze_feature_tensors: bool = True, ) -> "torch.utils.data.IterableDataset": """Return a Torch IterableDataset over this dataset. @@ -2139,6 +2142,10 @@ def to_torch( be left as is, that is (N, ). In general, regression loss functions expect an unsqueezed tensor, while classification loss functions expect a squeezed one. Defaults to True. + unsqueeze_feature_tensors (bool): If set to True, the features tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final features tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. Returns: A torch IterableDataset. @@ -2190,10 +2197,13 @@ def make_generator(): drop_last=drop_last, ): if label_column: - label_vals = batch.pop(label_column).values - label_tensor = torch.as_tensor(label_vals, dtype=label_column_dtype) - if unsqueeze_label_tensor: - label_tensor = label_tensor.view(-1, 1) + label_tensor = convert_pandas_to_torch_tensor( + batch, + [label_column], + label_column_dtype, + unsqueeze=unsqueeze_label_tensor, + ) + batch.pop(label_column) else: label_tensor = None @@ -2205,6 +2215,7 @@ def make_generator(): feature_column_dtypes[key] if isinstance(feature_column_dtypes, dict) else feature_column_dtypes, + unsqueeze=unsqueeze_feature_tensors, ) for key in feature_columns } @@ -2213,6 +2224,7 @@ def make_generator(): batch, columns=feature_columns, column_dtypes=feature_column_dtypes, + unsqueeze=unsqueeze_feature_tensors, ) yield (features_tensor, label_tensor) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 09dbb4e99cee..06b1bb0ee27b 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,9 +1,17 @@ -from typing import Any, Callable, Optional, Dict +from typing import Any, Callable, List, Optional, Dict, Type +import os -from transformers.trainer import Trainer -from torch.utils.data import Dataset as TorchDataset +import torch +import transformers.trainer +from transformers.training_args import TrainingArguments +from transformers.trainer_callback import TrainerCallback +from torch.utils.data import Dataset as TorchDataset, IterableDataset, DataLoader + +from ray import train +from ray.data.dataset import Dataset from ray.train.torch import TorchConfig +from ray.train.session import SessionMisuseError from ray.ml.trainer import GenDataset from ray.ml.train.integrations.torch import TorchTrainer from ray.ml.config import ScalingConfig, RunConfig @@ -13,6 +21,45 @@ from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY +class _HFIterableDatasetWithLen(IterableDataset): + def __init__(self, generator, length: int): + self.generator = generator + self._len = length + + def __iter__(self): + it = self.generator + for x in it: + yield {**x[0], "labels": x[1]} + + def __len__(self): + return self._len + + +class _TrainReportCallback(TrainerCallback): + def on_log(self, args, state, control, model=None, logs=None, **kwargs): + train.report(**{**logs, "step": state.global_step, "epoch": state.epoch}) + + +def _process_dataset_for_hf( + dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 +) -> IterableDataset: + torch_dataset = dataset.to_torch( + batch_size=batch_size, + feature_columns=feature_columns, + label_column="labels", + unsqueeze_label_tensor=False, + unsqueeze_feature_tensors=False, + ) + try: + count = dataset.count() + except ValueError: + # pipeline case + count = None + if count: + torch_dataset = _HFIterableDatasetWithLen(torch_dataset, count) + return torch_dataset + + @PublicAPI(stability="alpha") class HuggingFaceTrainer(TorchTrainer): """A Trainer for data parallel HuggingFace Transformers on PyTorch training. @@ -153,8 +200,9 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): def __init__( self, + *, trainer_init_per_worker: Callable[ - [TorchDataset, Optional[TorchDataset], Any], Trainer + [TorchDataset, Optional[TorchDataset], Any], transformers.trainer.Trainer ], trainer_init_config: Optional[Dict] = None, torch_config: Optional[TorchConfig] = None, @@ -175,23 +223,139 @@ def __init__( ) super().__init__( - self._create_train_func(trainer_init_per_worker), - trainer_init_config, - torch_config, - scaling_config, - run_config, - datasets, - preprocessor, - resume_from_checkpoint, + train_loop_per_worker=self._create_train_func(trainer_init_per_worker), + train_loop_config=trainer_init_config, + torch_config=torch_config, + scaling_config=scaling_config, + run_config=run_config, + datasets=datasets, + preprocessor=preprocessor, + resume_from_checkpoint=resume_from_checkpoint, ) - def _create_train_func(self, trainer_init_per_worker): + def _validate_train_loop_per_worker( + self, train_loop_per_worker: Callable, fn_name: str + ) -> None: + pass + + def _create_train_func( + self, + trainer_init_per_worker: Callable[ + [TorchDataset, Optional[TorchDataset], Any], transformers.trainer.Trainer + ], + ): def train_loop_per_worker(config): - # Set to None just to make CI pass & show - # the intended usage with trainer_init_per_worker - train_dataset = None - eval_dataset = None - trainer = trainer_init_per_worker(train_dataset, eval_dataset, **config) + os.environ["RANK"] = str(train.world_rank()) + os.environ["WORLD_SIZE"] = str(train.world_size()) + os.environ["LOCAL_RANK"] = str(train.local_rank()) + os.environ["WANDB_DISABLED"] = "true" + os.environ["DISABLE_MLFLOW_INTEGRATION"] = "true" + + train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) + eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) + train_columns = set(train_dataset.schema(fetch_if_missing=True).names) + if "labels" not in train_columns: + raise ValueError( + "'labels' column must be present in the training dataset!" + ) + train_columns.remove("labels") + if eval_dataset: + eval_columns = set(eval_dataset.schema(fetch_if_missing=True).names) + if "labels" not in eval_columns: + raise ValueError( + "'labels' column must be present in the evaluation dataset!" + ) + eval_columns.remove("labels") + + if not eval_columns.issuperset(train_columns): + raise ValueError( + "Evaluation dataset must have a superset of the columns in " + "the training dataset. " + f"Missing columns: {list(train_columns - eval_columns)}" + ) + + feature_columns = {column: [column] for column in train_columns} + + # we use batch size 1 here, as it will be converted to + # desired size inside transformers.Trainer. Possible optimization + # in the future + batch_size = 1 + train_torch_dataset = _process_dataset_for_hf( + train_dataset, feature_columns, batch_size=batch_size + ) + + if eval_dataset: + eval_torch_dataset = _process_dataset_for_hf( + eval_dataset, feature_columns, batch_size=batch_size + ) + else: + eval_torch_dataset = None + + trainer: transformers.trainer.Trainer = trainer_init_per_worker( + train_torch_dataset, eval_torch_dataset, **config + ) + + if not trainer.args.local_rank == train.local_rank(): + raise RuntimeError( + "local_rank set in TrainingArguments doesn't match " + "Ray Train local_rank " + f"({trainer.args.local_rank} != {train.local_rank()}. " + "Ensure you are not setting local_rank manually." + ) + + base_training_arguments_class: Type[ + TrainingArguments + ] = trainer.args.__class__ + + class RayTrainingArguments(base_training_arguments_class): + @property + def device(self) -> "torch.device": + try: + return train.torch.get_device() + except SessionMisuseError: + return super().device + + base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ + + class RayTrainer(base_trainer_class): + def get_train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def _wrap_model(self, model, training=True): + if not training: + return model + try: + kwargs = {} + # same logic as in transformers.Trainer + if self.args.ddp_find_unused_parameters is not None: + kwargs[ + "find_unused_parameters" + ] = self.args.ddp_find_unused_parameters + elif isinstance(model, transformers.trainer.PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs[ + "find_unused_parameters" + ] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + return train.torch.prepare_model(model, ddp_kwargs=kwargs) + except SessionMisuseError: + return super()._wrap_model(model, training) + + trainer.__class__ = RayTrainer + trainer.args.__class__ = RayTrainingArguments + trainer.add_callback(_TrainReportCallback) + torch.cuda.set_device(train.torch.get_device()) trainer.train() return train_loop_per_worker diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index c338ca06cd68..ad209967d708 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -1,6 +1,7 @@ from typing import Optional, Union, List, Dict import pandas as pd +import numpy as np import torch @@ -8,6 +9,7 @@ def convert_pandas_to_torch_tensor( data_batch: pd.DataFrame, columns: Optional[Union[List[str], List[List[str]]]] = None, column_dtypes: Optional[Union[torch.dtype, List[torch.dtype]]] = None, + unsqueeze: bool = True, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Converts a Pandas dataframe to a torch Tensor or list of torch Tensors. @@ -46,6 +48,12 @@ def convert_pandas_to_torch_tensor( columns = columns if columns else [] + def tensorize(vals, dtype): + if vals.dtype == np.object: + # TODO: clarify if this should be cat or stack + return torch.stack([tensorize(x, dtype) for x in vals]) + return torch.as_tensor(vals, dtype=dtype) + def get_tensor_for_columns(columns, dtype): feature_tensors = [] @@ -56,11 +64,14 @@ def get_tensor_for_columns(columns, dtype): for col in batch.columns: col_vals = batch[col].values - t = torch.as_tensor(col_vals, dtype=dtype) - t = t.view(-1, 1) + t = tensorize(col_vals, dtype=dtype) + if unsqueeze: + t = t.view(-1, 1) feature_tensors.append(t) - return torch.cat(feature_tensors, dim=1) + if len(feature_tensors) > 1: + return torch.cat(feature_tensors, dim=1) + return feature_tensors[0] if multi_input: if type(column_dtypes) not in [list, tuple]: From 9a0b41e7dc6cf5f5f1e5f126eec812ed7e512cc7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Apr 2022 13:04:39 +0000 Subject: [PATCH 02/75] WIP --- .../huggingface/huggingface_trainer.py | 80 +++++++++++-------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 06b1bb0ee27b..75dfff306637 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,5 +1,6 @@ from typing import Any, Callable, List, Optional, Dict, Type import os +from unittest.mock import patch import torch import transformers.trainer @@ -248,8 +249,6 @@ def train_loop_per_worker(config): os.environ["RANK"] = str(train.world_rank()) os.environ["WORLD_SIZE"] = str(train.world_size()) os.environ["LOCAL_RANK"] = str(train.local_rank()) - os.environ["WANDB_DISABLED"] = "true" - os.environ["DISABLE_MLFLOW_INTEGRATION"] = "true" train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) @@ -291,9 +290,15 @@ def train_loop_per_worker(config): else: eval_torch_dataset = None - trainer: transformers.trainer.Trainer = trainer_init_per_worker( - train_torch_dataset, eval_torch_dataset, **config - ) + # ensure no HF logging callbacks are added + # aside from doubling functionality with our callbacks, + # the Wandb callbacks causes training to freeze + with patch( + "transformers.trainer.get_reporting_integration_callbacks", lambda x: [] + ): + trainer: transformers.trainer.Trainer = trainer_init_per_worker( + train_torch_dataset, eval_torch_dataset, **config + ) if not trainer.args.local_rank == train.local_rank(): raise RuntimeError( @@ -319,43 +324,50 @@ def device(self) -> "torch.device": class RayTrainer(base_trainer_class): def get_train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + try: + train.world_rank() # check if we are in session + return DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + except SessionMisuseError: + super().get_train_dataloader() def _wrap_model(self, model, training=True): - if not training: - return model try: - kwargs = {} - # same logic as in transformers.Trainer - if self.args.ddp_find_unused_parameters is not None: - kwargs[ - "find_unused_parameters" - ] = self.args.ddp_find_unused_parameters - elif isinstance(model, transformers.trainer.PreTrainedModel): - # find_unused_parameters breaks checkpointing as per - # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - kwargs[ - "find_unused_parameters" - ] = not model.is_gradient_checkpointing - else: - kwargs["find_unused_parameters"] = True - - if self.args.ddp_bucket_cap_mb is not None: - kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb - return train.torch.prepare_model(model, ddp_kwargs=kwargs) + train.world_rank() # check if we are in session except SessionMisuseError: - return super()._wrap_model(model, training) + return super()._wrap_model(model, training=training) + + if not training: + return model + kwargs = {} + # same logic as in transformers.Trainer + if self.args.ddp_find_unused_parameters is not None: + kwargs[ + "find_unused_parameters" + ] = self.args.ddp_find_unused_parameters + elif isinstance(model, transformers.trainer.PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs[ + "find_unused_parameters" + ] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + return train.torch.prepare_model(model, ddp_kwargs=kwargs) trainer.__class__ = RayTrainer trainer.args.__class__ = RayTrainingArguments trainer.add_callback(_TrainReportCallback) - torch.cuda.set_device(train.torch.get_device()) + if trainer.args.device.type == "cuda": + torch.cuda.set_device(trainer.args.device) trainer.train() return train_loop_per_worker From 5ca35b082f21d7691c923f502e73a6915d0938e4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Apr 2022 13:21:42 +0000 Subject: [PATCH 03/75] WIP --- .../huggingface/huggingface_trainer.py | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 75dfff306637..77dd85726809 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -12,7 +12,7 @@ from ray import train from ray.data.dataset import Dataset from ray.train.torch import TorchConfig -from ray.train.session import SessionMisuseError +from ray.train.session import get_session from ray.ml.trainer import GenDataset from ray.ml.train.integrations.torch import TorchTrainer from ray.ml.config import ScalingConfig, RunConfig @@ -315,31 +315,26 @@ def train_loop_per_worker(config): class RayTrainingArguments(base_training_arguments_class): @property def device(self) -> "torch.device": - try: - return train.torch.get_device() - except SessionMisuseError: + if get_session() is None: return super().device + return train.torch.get_device() base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ class RayTrainer(base_trainer_class): def get_train_dataloader(self): - try: - train.world_rank() # check if we are in session - return DataLoader( - self.train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) - except SessionMisuseError: - super().get_train_dataloader() + if get_session() is None: + return super().get_train_dataloader() + return DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) def _wrap_model(self, model, training=True): - try: - train.world_rank() # check if we are in session - except SessionMisuseError: + if get_session() is None: return super()._wrap_model(model, training=training) if not training: From f8e153a49662cc33cf4d03916ef654ec2365df8c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Apr 2022 14:03:37 +0000 Subject: [PATCH 04/75] WIP --- .../huggingface/huggingface_trainer.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 77dd85726809..d2214ab21ee9 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,5 +1,6 @@ from typing import Any, Callable, List, Optional, Dict, Type import os +import inspect from unittest.mock import patch import torch @@ -218,10 +219,19 @@ def __init__( trainer_init_per_worker, "trainer_init_per_worker" ) - assert TRAIN_DATASET_KEY in datasets - assert all( + if TRAIN_DATASET_KEY not in datasets: + raise KeyError( + f"'{TRAIN_DATASET_KEY}' key must be preset in `datasets`. " + f"Got {list(self.datasets.keys())}" + ) + if not all( key in (TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY) for key in datasets - ) + ): + raise KeyError( + f"Only '{TRAIN_DATASET_KEY}' and '{EVALUATION_DATASET_KEY}' " + "keys can be preset in `datasets`. " + f"Got {list(self.datasets.keys())}" + ) super().__init__( train_loop_per_worker=self._create_train_func(trainer_init_per_worker), @@ -237,7 +247,12 @@ def __init__( def _validate_train_loop_per_worker( self, train_loop_per_worker: Callable, fn_name: str ) -> None: - pass + num_params = len(inspect.signature(train_loop_per_worker).parameters) + if num_params != 3: + raise ValueError( + f"{fn_name} should take in 3 arguments, " + f"but it accepts {num_params} arguments instead." + ) def _create_train_func( self, From 240e661c839b77146ed2b704be72e088cd75e741 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Apr 2022 14:24:30 +0000 Subject: [PATCH 05/75] Make datasets arg mandatory --- .../huggingface/huggingface_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index d2214ab21ee9..c26d9c01b28a 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -180,13 +180,6 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): and config as kwargs. The Torch Datasets are automatically created by converting the Ray Datasets internally before they are passed into the function. - trainer_init_config: Configurations to pass into - ``trainer_init_per_worker`` as kwargs. - torch_config: Configuration for setting up the PyTorch backend. If set to - None, use the default configuration. This replaces the ``backend_config`` - arg of ``DataParallelTrainer``. Same as in ``TorchTrainer``. - scaling_config: Configuration for how to scale data parallel training. - run_config: Configuration for the execution of the training run. datasets: Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset and (optionally) key "evaluation" to denote the evaluation @@ -195,6 +188,13 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): If a ``preprocessor`` is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the ``preprocessor`` if one is provided. + trainer_init_config: Configurations to pass into + ``trainer_init_per_worker`` as kwargs. + torch_config: Configuration for setting up the PyTorch backend. If set to + None, use the default configuration. This replaces the ``backend_config`` + arg of ``DataParallelTrainer``. Same as in ``TorchTrainer``. + scaling_config: Configuration for how to scale data parallel training. + run_config: Configuration for the execution of the training run. preprocessor: A ray.ml.preprocessor.Preprocessor to preprocess the provided datasets. resume_from_checkpoint: A checkpoint to resume training from. @@ -206,11 +206,11 @@ def __init__( trainer_init_per_worker: Callable[ [TorchDataset, Optional[TorchDataset], Any], transformers.trainer.Trainer ], + datasets: Dict[str, GenDataset], trainer_init_config: Optional[Dict] = None, torch_config: Optional[TorchConfig] = None, scaling_config: Optional[ScalingConfig] = None, run_config: Optional[RunConfig] = None, - datasets: Optional[Dict[str, GenDataset]] = None, preprocessor: Optional[Preprocessor] = None, resume_from_checkpoint: Optional[Checkpoint] = None, ): From 196905ae35fd0ac0a13ed5f0cb9bcef9a64c4aba Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Apr 2022 20:22:29 +0000 Subject: [PATCH 06/75] WIP --- language_modeling_from_scratch copy.ipynb | 579 ++++++++++++++++++ python/ray/ml/train/data_parallel_trainer.py | 3 + .../huggingface/huggingface_trainer.py | 55 +- 3 files changed, 632 insertions(+), 5 deletions(-) create mode 100644 language_modeling_from_scratch copy.ipynb diff --git a/language_modeling_from_scratch copy.ipynb b/language_modeling_from_scratch copy.ipynb new file mode 100644 index 000000000000..036939ab4000 --- /dev/null +++ b/language_modeling_from_scratch copy.ipynb @@ -0,0 +1,579 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "X4cRE8IbIrIV" + }, + "source": [ + "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "MOsHUjgdIrIW", + "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" + }, + "outputs": [], + "source": [ + "#! pip install datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", + "\n", + "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", + "\n", + "First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then execute the following cell and input your username and password:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then you need to install Git-LFS. Uncomment the following instructions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !apt install git-lfs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import transformers\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HFASsisvIrIb" + }, + "source": [ + "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/language-modeling)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a3KD3WXU3l-O" + }, + "source": [ + "# Train a language model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JAscNNUD3l-P" + }, + "source": [ + "In this notebook, we'll see how to train a [🤗 Transformers](https://github.com/huggingface/transformers) model on a language modeling task. We will cover two types of language modeling tasks which are:\n", + "\n", + "- Causal language modeling: the model has to predict the next token in the sentence (so the labels are the same as the inputs shifted to the right). To make sure the model does not cheat, it gets an attention mask that will prevent it to access the tokens after token i when trying to predict the token i+1 in the sentence.\n", + "\n", + "![Widget inference representing the causal language modeling task](images/causal_language_modeling.png)\n", + "\n", + "- Masked language modeling: the model has to predict some tokens that are masked in the input. It still has access to the whole sentence, so it can use the tokens before and after the tokens masked to predict their value.\n", + "\n", + "![Widget inference representing the masked language modeling task](images/masked_language_modeling.png)\n", + "\n", + "We will see how to easily load and preprocess the dataset for each one of those tasks, and how to use the `Trainer` API to train a model on it.\n", + "\n", + "This notebooks assumes you have trained a tokenizer on the corpus you are using, see the [How to train a tokenizer](https://github.com/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb) notebook ([open in colab](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb)).\n", + "\n", + "A script version of this notebook you can directly run on a distributed environment or on TPU is available in our [examples folder](https://github.com/huggingface/transformers/tree/master/examples)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1r_n9OWV3l-Q" + }, + "source": [ + "## Preparing the dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kswRMhPc3l-Q" + }, + "source": [ + "For each of those tasks, we will use the [Wikitext 2]() dataset as an example. You can load it very easily with the 🤗 Datasets library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n2ZRs1cL3l-R", + "outputId": "11151c56-be90-4d11-e7df-db85e745ca5c" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "# datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f1-9jepM3l-W" + }, + "source": [ + "You can replace the dataset above with any dataset hosted on [the hub](https://huggingface.co/datasets) or use your own files. Just uncomment the following cell and replace the paths with values that will lead to your files:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uxSaGa_l3l-W" + }, + "outputs": [], + "source": [ + "# datasets = load_dataset(\"text\", data_files={\"train\": path_to_train.txt, \"validation\": path_to_validation.txt}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jY1SwIrY3l-a" + }, + "source": [ + "You can also load datasets from a csv or a JSON file, see the [full documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files) for more information." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u3EtYfeHIrIz" + }, + "source": [ + "To access an actual element, you need to select a split first, then give an index:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WHUmphG3IrI3" + }, + "source": [ + "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ur5sNUcZ3l-g" + }, + "outputs": [], + "source": [ + "from datasets import ClassLabel\n", + "import random\n", + "import pandas as pd\n", + "from IPython.display import display, HTML\n", + "\n", + "def show_random_elements(dataset, num_examples=10):\n", + " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", + " picks = []\n", + " for _ in range(num_examples):\n", + " pick = random.randint(0, len(dataset)-1)\n", + " while pick in picks:\n", + " pick = random.randint(0, len(dataset)-1)\n", + " picks.append(pick)\n", + " \n", + " df = pd.DataFrame(dataset[picks])\n", + " for column, typ in dataset.features.items():\n", + " if isinstance(typ, ClassLabel):\n", + " df[column] = df[column].transform(lambda i: typ.names[i])\n", + " display(HTML(df.to_html()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CKerdF353l-o" + }, + "source": [ + "As we can see, some of the texts are a full paragraph of a Wikipedia article while others are just titles or empty lines." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JEA1ju653l-p" + }, + "source": [ + "## Causal Language modeling" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v5GTGKZS3l-q" + }, + "source": [ + "For causal language modeling (CLM) we are going to take all the texts in our dataset and concatenate them after they are tokenized. Then we will split them in examples of a certain sequence length. This way the model will receive chunks of contiguous text that may look like:\n", + "```\n", + "part of text 1\n", + "```\n", + "or \n", + "```\n", + "end of text 1 [BOS_TOKEN] beginning of text 2\n", + "```\n", + "depending on whether they span over several of the original texts in the dataset or not. The labels will be the same as the inputs, shifted to the left.\n", + "\n", + "We will use the [`gpt2`](https://huggingface.co/gpt2) architecture for this example. You can pick any of the checkpoints listed [here](https://huggingface.co/models?filter=causal-lm) instead. For the tokenizer, you can replace the checkpoint by the one you trained yourself." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-WGBCO343l-q" + }, + "outputs": [], + "source": [ + "model_checkpoint = \"gpt2\"\n", + "tokenizer_checkpoint = \"sgugger/gpt2-like-tokenizer\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5io6fY_d3l-u" + }, + "source": [ + "To tokenize all our texts with the same vocabulary that was used when training the model, we have to download a pretrained tokenizer. This is all done by the `AutoTokenizer` class:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iAYlS40Z3l-v" + }, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rpOiBrJ13l-y" + }, + "source": [ + "We can now call the tokenizer on all our texts. This is very simple, using the [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) method from the Datasets library. First we define a function that call the tokenizer on our texts:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M9xVAa3s3l-2" + }, + "source": [ + "Then we apply it to all the splits in our `datasets` object, using `batched=True` and 4 processes to speed up the preprocessing. We won't need the `text` column afterward, so we discard it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NVAO0H8u3l-3", + "outputId": "30d88b8a-e353-4e13-f709-8e5e06ef747b" + }, + "outputs": [], + "source": [ + "# tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8qik3J_C3l-7" + }, + "source": [ + "If we now look at an element of our datasets, we will see the text have been replaced by the `input_ids` the model will need:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "obvgcXda3l--" + }, + "source": [ + "Now for the harder part: we need to concatenate all our texts together then split the result in small chunks of a certain `block_size`. To do this, we will use the `map` method again, with the option `batched=True`. This option actually lets us change the number of examples in the datasets by returning a different number of examples than we got. This way, we can create our new samples from a batch of examples.\n", + "\n", + "First, we grab the maximum length our model was pretrained with. This might be a big too big to fit in your GPU RAM, so here we take a bit less at just 128." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DVHs5aCA3l-_" + }, + "outputs": [], + "source": [ + "# block_size = tokenizer.model_max_length\n", + "block_size = 128" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RpNfGiMw3l_A" + }, + "source": [ + "Then we write the preprocessing function that will group our texts:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LGJWXtNv3l_C" + }, + "source": [ + "First note that we duplicate the inputs for our labels. This is because the model of the 🤗 Transformers library apply the shifting to the right, so we don't need to do it manually.\n", + "\n", + "Also note that by default, the `map` method will send a batch of 1,000 examples to be treated by the preprocessing function. So here, we will drop the remainder to make the concatenated tokenized texts a multiple of `block_size` every 1,000 examples. You can adjust this behavior by passing a higher batch size (which will also be processed slower). You can also speed-up the preprocessing by using multiprocessing:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6n84V8Gc3l_G" + }, + "source": [ + "And we can check our datasets have changed: now the samples contain chunks of `block_size` contiguous tokens, potentially spanning over several of our original texts." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iEmeQ7Xm3l_H" + }, + "source": [ + "Now that the data has been cleaned, we're ready to instantiate our `Trainer`. First we create the model using the same config as our checkpoint, but initialized with random weights:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sPqQA3TT3l_I" + }, + "outputs": [], + "source": [ + "from transformers import AutoConfig, AutoModelForCausalLM" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VyPQTOF_3l_J" + }, + "source": [ + "And we will needsome `TrainingArguments`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jElf8LJ33l_K" + }, + "outputs": [], + "source": [ + "from transformers import Trainer, TrainingArguments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The last argument to setup everything so we can push the model to the [Hub](https://huggingface.co/models) regularly during training. Remove it if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally in a name that is different than the name of the repository it will be pushed, or if you want to push your model under an organization and not your name space, use the `hub_model_id` argument to set the repo name (it needs to be the full name, including your namespace: for instance `\"sgugger/gpt-finetuned-wikitext2\"` or `\"huggingface/gpt-finetuned-wikitext2\"`)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sZRbT9ui3l_N" + }, + "source": [ + "We pass along all of those to the `Trainer` class:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_dataset():\n", + " datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')\n", + " tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)\n", + " def tokenize_function(examples):\n", + " return tokenizer(examples[\"text\"])\n", + " tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=1, remove_columns=[\"text\"])\n", + " def group_texts(examples):\n", + " # Concatenate all texts.\n", + " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", + " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", + " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n", + " # customize this part to your needs.\n", + " total_length = (total_length // block_size) * block_size\n", + " # Split by chunks of max_len.\n", + " result = {\n", + " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", + " for k, t in concatenated_examples.items()\n", + " }\n", + " result[\"labels\"] = result[\"input_ids\"].copy()\n", + " return result\n", + " lm_datasets = tokenized_datasets.map(\n", + " group_texts,\n", + " batched=True,\n", + " batch_size=1000,\n", + " num_proc=1,\n", + " )\n", + " return lm_datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray.data\n", + "lm_dataset = get_dataset()\n", + "ray_train = ray.data.from_arrow(lm_dataset[\"train\"]._data.table)\n", + "ray_validation = ray.data.from_arrow(lm_dataset[\"validation\"]._data.table)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.ml.train.integrations.huggingface import HuggingFaceTrainer\n", + "\n", + "def train_function(train_dataset, eval_dataset = None, **config):\n", + " model_config = AutoConfig.from_pretrained(model_checkpoint)\n", + " model = AutoModelForCausalLM.from_config(model_config)\n", + " print(\"initializing training_args\")\n", + " training_args = TrainingArguments(\n", + " f\"{model_checkpoint}-wikitext2\",\n", + " evaluation_strategy = \"epoch\",\n", + " num_train_epochs=2,\n", + " learning_rate=2e-5,\n", + " weight_decay=0.01,\n", + " disable_tqdm=True,\n", + " save_strategy=\"epoch\",\n", + " )\n", + " print(\"initializing\")\n", + " trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " )\n", + " print(\"trainer initialized\")\n", + " return trainer\n", + "\n", + "trainer = HuggingFaceTrainer(\n", + " trainer_init_per_worker=train_function,\n", + " scaling_config={\"num_workers\": 2, \"use_gpu\": False},\n", + " datasets={\"train\": ray_train.limit(16), \"evaluation\": ray_validation.limit(8)},\n", + ")\n", + "trainer.fit()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Vvz34Td3l_O" + }, + "source": [ + "And we can train our model:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3APq-vUc3l_R" + }, + "source": [ + "Once the training is completed, we can evaluate our model and get its perplexity on the validation set like this:" + ] + } + ], + "metadata": { + "colab": { + "name": "Train a language model", + "provenance": [] + }, + "interpreter": { + "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f" + }, + "kernelspec": { + "display_name": "Python 3.8.10 ('venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 21fcb54552a5..535160a4e430 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -1,3 +1,4 @@ +import gc import inspect import logging from pathlib import Path @@ -325,6 +326,7 @@ def on_init(self, preprocessor: Preprocessor): super(_DataParallelCheckpointManager, self).on_init() def write_checkpoint(self, checkpoint: Dict): + self.latest_checkpoint = None self.add_tune_checkpoint_id(checkpoint) # Add the preprocessor to the checkpoint. @@ -334,6 +336,7 @@ def write_checkpoint(self, checkpoint: Dict): # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: checkpoint_obj.to_directory(path=checkpoint_dir) + gc.collect() @property def latest_checkpoint_dir(self) -> Optional[Path]: diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index c26d9c01b28a..72493eee8729 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,6 +1,7 @@ from typing import Any, Callable, List, Optional, Dict, Type import os import inspect +import gc from unittest.mock import patch import torch @@ -38,8 +39,37 @@ def __len__(self): class _TrainReportCallback(TrainerCallback): + def __init__(self) -> None: + self.delayed_report = None + super().__init__() + def on_log(self, args, state, control, model=None, logs=None, **kwargs): - train.report(**{**logs, "step": state.global_step, "epoch": state.epoch}) + print(f"on log {state.epoch}, should save {control.should_save}") + report = {**logs, "step": state.global_step, "epoch": state.epoch} + if control.should_save: + self.delayed_report = report + else: + train.report(**report) + + def on_save(self, args, state, control, **kwargs): + checkpoint_path = transformers.trainer.get_last_checkpoint(args.output_dir) + print( + f"on save {state.epoch}, checkpoint_path {checkpoint_path}, delayed_report {bool(self.delayed_report)}" + ) + if checkpoint_path: + print("creating checkpoint") + ml_checkpoint = Checkpoint.from_directory(str(checkpoint_path)) + print("saving checkpoint") + if train.world_rank() == 0: + train.save_checkpoint(**ml_checkpoint.to_dict()) + else: + train.save_checkpoint(**{"DUMMY": 0}) + print("checkpoint saved") + if self.delayed_report: + train.report(**self.delayed_report) + self.delayed_report = None + print("on save done") + gc.collect() def _process_dataset_for_hf( @@ -215,7 +245,7 @@ def __init__( resume_from_checkpoint: Optional[Checkpoint] = None, ): - self._validate_train_loop_per_worker( + self._validate_trainer_init_per_worker( trainer_init_per_worker, "trainer_init_per_worker" ) @@ -244,16 +274,21 @@ def __init__( resume_from_checkpoint=resume_from_checkpoint, ) - def _validate_train_loop_per_worker( - self, train_loop_per_worker: Callable, fn_name: str + def _validate_trainer_init_per_worker( + self, trainer_init_per_worker: Callable, fn_name: str ) -> None: - num_params = len(inspect.signature(train_loop_per_worker).parameters) + num_params = len(inspect.signature(trainer_init_per_worker).parameters) if num_params != 3: raise ValueError( f"{fn_name} should take in 3 arguments, " f"but it accepts {num_params} arguments instead." ) + def _validate_train_loop_per_worker( + self, train_loop_per_worker: Callable, fn_name: str + ) -> None: + pass + def _create_train_func( self, trainer_init_per_worker: Callable[ @@ -373,8 +408,18 @@ def _wrap_model(self, model, training=True): kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb return train.torch.prepare_model(model, ddp_kwargs=kwargs) + def _save(self, output_dir=None, state_dict=None): + # Workaround for RayTrainingArguments not being + # pickleable + self.args.__class__ = base_training_arguments_class + ret = super()._save(output_dir, state_dict) + self.args.__class__ = RayTrainingArguments + return ret + trainer.__class__ = RayTrainer trainer.args.__class__ = RayTrainingArguments + trainer.args.no_cuda = not torch.cuda.is_available() + trainer.args.save_on_each_node = True trainer.add_callback(_TrainReportCallback) if trainer.args.device.type == "cuda": torch.cuda.set_device(trainer.args.device) From f6e9daf1e696c6f980d432eb3eeb2911d7bb662c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 11 Apr 2022 17:01:10 +0000 Subject: [PATCH 07/75] Add docs --- doc/source/ray-air/getting-started.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/source/ray-air/getting-started.rst b/doc/source/ray-air/getting-started.rst index bcc3af3126a8..183c523e64ad 100644 --- a/doc/source/ray-air/getting-started.rst +++ b/doc/source/ray-air/getting-started.rst @@ -60,6 +60,10 @@ Trainer :members: :show-inheritance: +.. automodule:: ray.ml.train.integrations.huggingface + :members: + :show-inheritance: + .. autoclass:: ray.ml.train.data_parallel_trainer.DataParallelTrainer :members: :show-inheritance: From 55633ef07af1c265310802b7a55f592a7617927e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 11 Apr 2022 18:17:16 +0000 Subject: [PATCH 08/75] WIP --- python/ray/ml/train/data_parallel_trainer.py | 3 -- .../huggingface/huggingface_trainer.py | 29 ++++++++++++++----- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 535160a4e430..21fcb54552a5 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -1,4 +1,3 @@ -import gc import inspect import logging from pathlib import Path @@ -326,7 +325,6 @@ def on_init(self, preprocessor: Preprocessor): super(_DataParallelCheckpointManager, self).on_init() def write_checkpoint(self, checkpoint: Dict): - self.latest_checkpoint = None self.add_tune_checkpoint_id(checkpoint) # Add the preprocessor to the checkpoint. @@ -336,7 +334,6 @@ def write_checkpoint(self, checkpoint: Dict): # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: checkpoint_obj.to_directory(path=checkpoint_dir) - gc.collect() @property def latest_checkpoint_dir(self) -> Optional[Path]: diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 72493eee8729..6037f74cf332 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Dict, Type +from typing import Any, Callable, Generator, List, Optional, Dict, Type import os import inspect import gc @@ -25,11 +25,12 @@ class _HFIterableDatasetWithLen(IterableDataset): - def __init__(self, generator, length: int): + """Special Torch IterableDataset with preset length.""" + def __init__(self, generator: Generator, length: int): self.generator = generator self._len = length - def __iter__(self): + def __iter__(self) -> Dict[str, torch.Tensor]: it = self.generator for x in it: yield {**x[0], "labels": x[1]} @@ -39,7 +40,12 @@ def __len__(self): class _TrainReportCallback(TrainerCallback): + """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" def __init__(self) -> None: + # HF first logs metrics, and then checkpoints. With Ray AIR, we need the + # opposite. Therefore, if we detect that a checkpoint will be created, + # we delay the train.report call after the checkpoint is reported + # to Ray Train. self.delayed_report = None super().__init__() @@ -75,6 +81,7 @@ def on_save(self, args, state, control, **kwargs): def _process_dataset_for_hf( dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 ) -> IterableDataset: + """Converts a Ray Dataset into a HF-compatible Torch Dataset.""" torch_dataset = dataset.to_torch( batch_size=batch_size, feature_columns=feature_columns, @@ -116,7 +123,10 @@ class HuggingFaceTrainer(TorchTrainer): sampling, as the dataset will already be sharded. Hugging Face loggers will be automatically disabled, and the ``local_rank`` - argument in ``TrainingArguments`` will be automatically set. + argument in ``TrainingArguments`` will be automatically set. Please note + that if you want to use CPU training, you will need to set the ``no_cuda`` + argument in ``TrainingArguments`` manually - otherwise, an exception + may be thrown. Example: .. code-block:: python @@ -296,6 +306,7 @@ def _create_train_func( ], ): def train_loop_per_worker(config): + # Env vars necessary for HF to setup DDP os.environ["RANK"] = str(train.world_rank()) os.environ["WORLD_SIZE"] = str(train.world_size()) os.environ["LOCAL_RANK"] = str(train.local_rank()) @@ -323,6 +334,7 @@ def train_loop_per_worker(config): f"Missing columns: {list(train_columns - eval_columns)}" ) + # HF-supported format feature_columns = {column: [column] for column in train_columns} # we use batch size 1 here, as it will be converted to @@ -350,7 +362,7 @@ def train_loop_per_worker(config): train_torch_dataset, eval_torch_dataset, **config ) - if not trainer.args.local_rank == train.local_rank(): + if trainer.args.local_rank != train.local_rank(): raise RuntimeError( "local_rank set in TrainingArguments doesn't match " "Ray Train local_rank " @@ -408,11 +420,12 @@ def _wrap_model(self, model, training=True): kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb return train.torch.prepare_model(model, ddp_kwargs=kwargs) - def _save(self, output_dir=None, state_dict=None): + def _save(self, *args, **kwargs): # Workaround for RayTrainingArguments not being - # pickleable + # pickleable due to it being defined in a local + # scope self.args.__class__ = base_training_arguments_class - ret = super()._save(output_dir, state_dict) + ret = super()._save(*args, **kwargs) self.args.__class__ = RayTrainingArguments return ret From 3152272034ca00b26adbaeca2251ebd8bf483a5f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 12 Apr 2022 23:14:13 +0000 Subject: [PATCH 09/75] HuggingFaceTrainer --- python/ray/ml/train/data_parallel_trainer.py | 5 +- .../huggingface/huggingface_trainer.py | 277 +++++++++++++----- 2 files changed, 204 insertions(+), 78 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 21fcb54552a5..18cda0ca1e18 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -249,6 +249,9 @@ def _validate_train_loop_per_worker( f"but it accepts {num_params} arguments instead." ) + def _get_checkpoint_manager(self) -> TuneCheckpointManager: + return _DataParallelCheckpointManager() + def training_loop(self) -> None: scaling_config_dataclass = ScalingConfigDataClass(**self.scaling_config) @@ -271,7 +274,7 @@ def training_loop(self) -> None: max_retries=0, ) - checkpoint_manager = _DataParallelCheckpointManager() + checkpoint_manager = self._get_checkpoint_manager() checkpoint_manager.on_init(preprocessor=self.preprocessor) # Start the remote actors. diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 6037f74cf332..1fcdf4f9304a 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,8 +1,13 @@ -from typing import Any, Callable, Generator, List, Optional, Dict, Type +from pathlib import Path +import shutil +import tempfile +from typing import Any, Callable, Generator, Iterator, List, Optional, Dict, Type import os import inspect import gc from unittest.mock import patch +from ray.train.checkpoint import TuneCheckpointManager +from ray.train.constants import TUNE_CHECKPOINT_ID import torch import transformers.trainer @@ -12,6 +17,8 @@ from ray import train +from ray import tune +import ray.cloudpickle as cpickle from ray.data.dataset import Dataset from ray.train.torch import TorchConfig from ray.train.session import get_session @@ -20,83 +27,22 @@ from ray.ml.config import ScalingConfig, RunConfig from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint -from ray.util import PublicAPI -from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY +from ray.util import PublicAPI, get_node_ip_address +from ray.tune.utils.file_transfer import sync_dir_between_nodes +from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY, PREPROCESSOR_KEY +# This trainer uses a special checkpoint syncing logic. +# Because HF checkpoints are very large dirs (at least several GBs), +# we use directory checkpoints that are synced between nodes when +# required instead of serializing the checkpoints and sending +# bytes over nodes. This is a much more performant solution for +# large directory checkpoints. The current implementation +# is special for HuggingFaceTrainer, but can and should be +# made generic. +# TODO(ml-team): Make dir syncing checkpoint logic generic. -class _HFIterableDatasetWithLen(IterableDataset): - """Special Torch IterableDataset with preset length.""" - def __init__(self, generator: Generator, length: int): - self.generator = generator - self._len = length - - def __iter__(self) -> Dict[str, torch.Tensor]: - it = self.generator - for x in it: - yield {**x[0], "labels": x[1]} - - def __len__(self): - return self._len - - -class _TrainReportCallback(TrainerCallback): - """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" - def __init__(self) -> None: - # HF first logs metrics, and then checkpoints. With Ray AIR, we need the - # opposite. Therefore, if we detect that a checkpoint will be created, - # we delay the train.report call after the checkpoint is reported - # to Ray Train. - self.delayed_report = None - super().__init__() - - def on_log(self, args, state, control, model=None, logs=None, **kwargs): - print(f"on log {state.epoch}, should save {control.should_save}") - report = {**logs, "step": state.global_step, "epoch": state.epoch} - if control.should_save: - self.delayed_report = report - else: - train.report(**report) - - def on_save(self, args, state, control, **kwargs): - checkpoint_path = transformers.trainer.get_last_checkpoint(args.output_dir) - print( - f"on save {state.epoch}, checkpoint_path {checkpoint_path}, delayed_report {bool(self.delayed_report)}" - ) - if checkpoint_path: - print("creating checkpoint") - ml_checkpoint = Checkpoint.from_directory(str(checkpoint_path)) - print("saving checkpoint") - if train.world_rank() == 0: - train.save_checkpoint(**ml_checkpoint.to_dict()) - else: - train.save_checkpoint(**{"DUMMY": 0}) - print("checkpoint saved") - if self.delayed_report: - train.report(**self.delayed_report) - self.delayed_report = None - print("on save done") - gc.collect() - - -def _process_dataset_for_hf( - dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 -) -> IterableDataset: - """Converts a Ray Dataset into a HF-compatible Torch Dataset.""" - torch_dataset = dataset.to_torch( - batch_size=batch_size, - feature_columns=feature_columns, - label_column="labels", - unsqueeze_label_tensor=False, - unsqueeze_feature_tensors=False, - ) - try: - count = dataset.count() - except ValueError: - # pipeline case - count = None - if count: - torch_dataset = _HFIterableDatasetWithLen(torch_dataset, count) - return torch_dataset +NODE_IP_KEY = "node_ip" +CHECKPOINT_PATH_ON_NODE_KEY = "checkpoint_path_on_node" @PublicAPI(stability="alpha") @@ -284,6 +230,36 @@ def __init__( resume_from_checkpoint=resume_from_checkpoint, ) + def __new__(cls, *args, **kwargs): + """Store the init args as attributes so this can be merged with Tune hparams.""" + # This if will be entered in the driver-side Trainer. + # The Trainer inside the trainable will have a dict + # checkpoint created here. + # This is required to ensure that the dir syncing logic + # is used instead of serializing several gigabytes of data + # when a Checkpoint is sent to a Ray Actor. + if ( + "resume_from_checkpoint" in kwargs + and kwargs["resume_from_checkpoint"]._local_path + ): + checkpoint_path = kwargs["resume_from_checkpoint"].to_directory() + + # Load checkpoint from path. + checkpoint_path = Path(checkpoint_path).expanduser().absolute() + if not checkpoint_path.exists(): + raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") + with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: + tune_checkpoint_id = int(f.read()) + + kwargs["resume_from_checkpoint"] = Checkpoint.from_dict( + { + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + TUNE_CHECKPOINT_ID: tune_checkpoint_id, + } + ) + return super(HuggingFaceTrainer, cls).__new__(cls, *args, **kwargs) + def _validate_trainer_init_per_worker( self, trainer_init_per_worker: Callable, fn_name: str ) -> None: @@ -297,8 +273,13 @@ def _validate_trainer_init_per_worker( def _validate_train_loop_per_worker( self, train_loop_per_worker: Callable, fn_name: str ) -> None: + # Do not validate train_loop_per_worker. We validate + # trainer_init_per_worker instead. pass + def _get_checkpoint_manager(self) -> TuneCheckpointManager: + return _DataParallelSyncingCheckpointManager() + def _create_train_func( self, trainer_init_per_worker: Callable[ @@ -355,6 +336,7 @@ def train_loop_per_worker(config): # ensure no HF logging callbacks are added # aside from doubling functionality with our callbacks, # the Wandb callbacks causes training to freeze + # TODO(yard1): Automatically set `no_cuda` with patch( "transformers.trainer.get_reporting_integration_callbacks", lambda x: [] ): @@ -436,6 +418,147 @@ def _save(self, *args, **kwargs): trainer.add_callback(_TrainReportCallback) if trainer.args.device.type == "cuda": torch.cuda.set_device(trainer.args.device) - trainer.train() + + checkpoint = train.load_checkpoint() + checkpoint_path = None + if checkpoint: + source_ip = checkpoint[NODE_IP_KEY] + source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + if source_ip == target_ip: + checkpoint_path = source_path + else: + # TODO(yard1): Confirm if tempdir is the right approach here. + checkpoint_path = tempfile.mkdtemp() + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=checkpoint_path, + return_futures=False, + max_size_bytes=None, + ) + trainer.train(resume_from_checkpoint=checkpoint_path) return train_loop_per_worker + + +class _HFIterableDatasetWithLen(IterableDataset): + """Special Torch IterableDataset with preset length.""" + + def __init__(self, generator: Generator, length: int): + self.generator = generator + self._len = length + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + it = self.generator + for x in it: + # HF-specific format + yield {**x[0], "labels": x[1]} + + def __len__(self): + return self._len + + +class _TrainReportCallback(TrainerCallback): + """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" + + def __init__(self) -> None: + # HF first logs metrics, and then checkpoints. With Ray AIR, we need the + # opposite. Therefore, if we detect that a checkpoint will be created, + # we delay the train.report call after the checkpoint is reported + # to Ray Train. + self.delayed_report = None + # Avoid double reporting at the end. + # TODO(yard1): Train statistics are only reported at the end. Combine + # the second to last report and the last report somehow. We want + # steps/epochs to match the training iteration. + self.last_step = None + super().__init__() + + def on_log(self, args, state, control, model=None, logs=None, **kwargs): + if state.global_step == self.last_step: + return + self.last_step = state.global_step + report = {**logs, "step": state.global_step, "epoch": state.epoch} + if control.should_save: + self.delayed_report = report + else: + train.report(**report) + + def on_save(self, args, state, control, **kwargs): + checkpoint_path = Path( + transformers.trainer.get_last_checkpoint(args.output_dir) + ).absolute() + if checkpoint_path: + train.save_checkpoint( + **{ + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + } + ) + if self.delayed_report: + train.report(**self.delayed_report) + self.delayed_report = None + gc.collect() + + +def _process_dataset_for_hf( + dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 +) -> IterableDataset: + """Converts a Ray Dataset into a HF-compatible Torch Dataset.""" + torch_dataset = dataset.to_torch( + batch_size=batch_size, + feature_columns=feature_columns, + label_column="labels", + unsqueeze_label_tensor=False, + unsqueeze_feature_tensors=False, + ) + try: + count = dataset.count() + except ValueError: + # pipeline case + count = None + if count: + torch_dataset = _HFIterableDatasetWithLen(torch_dataset, count) + return torch_dataset + + +# TODO(team-ml): Refactor checkpoint management along with Tune. +class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): + """Same as _DataParallelCheckpointManager, but syncs the dir instead + of serializing it.""" + + def add_tune_checkpoint_id(self, path: str): + # Store the checkpoint_id in the file so that the Tune trial can be + # resumed after failure or cancellation. + with open(Path(path).joinpath(TUNE_CHECKPOINT_ID), "w") as f: + f.write(str(self._latest_checkpoint_id)) + + def on_init(self, preprocessor: Preprocessor): + self.preprocessor = preprocessor + super(_DataParallelSyncingCheckpointManager, self).on_init() + + def write_checkpoint(self, checkpoint: Dict): + # If inside a Tune Trainable, then checkpoint with Tune. + with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: + source_ip = checkpoint[NODE_IP_KEY] + source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=checkpoint_dir, + return_futures=False, + max_size_bytes=None, + ) + if source_ip == target_ip: + shutil.rmtree(source_path, ignore_errors=True) + with open(Path(checkpoint_dir).joinpath(PREPROCESSOR_KEY), "wb") as f: + cpickle.dump(self.preprocessor, f) + self.add_tune_checkpoint_id(checkpoint_dir) + + @property + def latest_checkpoint_dir(self) -> Optional[Path]: + raise NotImplementedError From 7b19a01ca3a64707feebaeb40bc35b2000feb9ee Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 12 Apr 2022 23:23:01 +0000 Subject: [PATCH 10/75] Add basic example --- python/ray/ml/examples/huggingface_example.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 python/ray/ml/examples/huggingface_example.py diff --git a/python/ray/ml/examples/huggingface_example.py b/python/ray/ml/examples/huggingface_example.py new file mode 100644 index 000000000000..9dbe22e33929 --- /dev/null +++ b/python/ray/ml/examples/huggingface_example.py @@ -0,0 +1,94 @@ +# Based on +# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb + +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoModelForCausalLM, + Trainer, + TrainingArguments, +) + +import ray +import ray.data +from ray.ml.train.integrations.huggingface import HuggingFaceTrainer + +model_checkpoint = "gpt2" +tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" + +# block_size = tokenizer.model_max_length +block_size = 128 + + +def get_dataset(): + datasets = load_dataset("wikitext", "wikitext-2-raw-v1") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + + def tokenize_function(examples): + return tokenizer(examples["text"]) + + tokenized_datasets = datasets.map( + tokenize_function, batched=True, num_proc=1, remove_columns=["text"] + ) + + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + batch_size=1000, + num_proc=1, + ) + return lm_datasets + + +lm_dataset = get_dataset() +ray_train = ray.data.from_arrow(lm_dataset["train"]._data.table) +ray_validation = ray.data.from_arrow(lm_dataset["validation"]._data.table) + + +def train_function(train_dataset, eval_dataset=None, **config): + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + print("Initializing TrainingArguments...") + training_args = TrainingArguments( + f"{model_checkpoint}-wikitext2", + evaluation_strategy="epoch", + num_train_epochs=2, + learning_rate=2e-5, + weight_decay=0.01, + disable_tqdm=True, + save_strategy="epoch", + ) + print("Initializing Trainer...") + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + print("Trainer initialized! Starting training...") + return trainer + + +trainer = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + scaling_config={"num_workers": 2, "use_gpu": False}, + datasets={"train": ray_train.limit(16), "evaluation": ray_validation.limit(8)}, +) +results = trainer.fit() +print(results.metrics) From 0a0223e5c7ae132ddbaafba8a26d91590cb53954 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 12 Apr 2022 23:24:23 +0000 Subject: [PATCH 11/75] Remove notebook --- language_modeling_from_scratch copy.ipynb | 579 ---------------------- 1 file changed, 579 deletions(-) delete mode 100644 language_modeling_from_scratch copy.ipynb diff --git a/language_modeling_from_scratch copy.ipynb b/language_modeling_from_scratch copy.ipynb deleted file mode 100644 index 036939ab4000..000000000000 --- a/language_modeling_from_scratch copy.ipynb +++ /dev/null @@ -1,579 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "X4cRE8IbIrIV" - }, - "source": [ - "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "MOsHUjgdIrIW", - "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" - }, - "outputs": [], - "source": [ - "#! pip install datasets transformers" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", - "\n", - "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", - "\n", - "First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then execute the following cell and input your username and password:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then you need to install Git-LFS. Uncomment the following instructions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# !apt install git-lfs" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import transformers\n", - "\n", - "print(transformers.__version__)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HFASsisvIrIb" - }, - "source": [ - "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/language-modeling)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a3KD3WXU3l-O" - }, - "source": [ - "# Train a language model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JAscNNUD3l-P" - }, - "source": [ - "In this notebook, we'll see how to train a [🤗 Transformers](https://github.com/huggingface/transformers) model on a language modeling task. We will cover two types of language modeling tasks which are:\n", - "\n", - "- Causal language modeling: the model has to predict the next token in the sentence (so the labels are the same as the inputs shifted to the right). To make sure the model does not cheat, it gets an attention mask that will prevent it to access the tokens after token i when trying to predict the token i+1 in the sentence.\n", - "\n", - "![Widget inference representing the causal language modeling task](images/causal_language_modeling.png)\n", - "\n", - "- Masked language modeling: the model has to predict some tokens that are masked in the input. It still has access to the whole sentence, so it can use the tokens before and after the tokens masked to predict their value.\n", - "\n", - "![Widget inference representing the masked language modeling task](images/masked_language_modeling.png)\n", - "\n", - "We will see how to easily load and preprocess the dataset for each one of those tasks, and how to use the `Trainer` API to train a model on it.\n", - "\n", - "This notebooks assumes you have trained a tokenizer on the corpus you are using, see the [How to train a tokenizer](https://github.com/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb) notebook ([open in colab](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb)).\n", - "\n", - "A script version of this notebook you can directly run on a distributed environment or on TPU is available in our [examples folder](https://github.com/huggingface/transformers/tree/master/examples)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1r_n9OWV3l-Q" - }, - "source": [ - "## Preparing the dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kswRMhPc3l-Q" - }, - "source": [ - "For each of those tasks, we will use the [Wikitext 2]() dataset as an example. You can load it very easily with the 🤗 Datasets library." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "n2ZRs1cL3l-R", - "outputId": "11151c56-be90-4d11-e7df-db85e745ca5c" - }, - "outputs": [], - "source": [ - "from datasets import load_dataset\n", - "# datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "f1-9jepM3l-W" - }, - "source": [ - "You can replace the dataset above with any dataset hosted on [the hub](https://huggingface.co/datasets) or use your own files. Just uncomment the following cell and replace the paths with values that will lead to your files:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uxSaGa_l3l-W" - }, - "outputs": [], - "source": [ - "# datasets = load_dataset(\"text\", data_files={\"train\": path_to_train.txt, \"validation\": path_to_validation.txt}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jY1SwIrY3l-a" - }, - "source": [ - "You can also load datasets from a csv or a JSON file, see the [full documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files) for more information." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u3EtYfeHIrIz" - }, - "source": [ - "To access an actual element, you need to select a split first, then give an index:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WHUmphG3IrI3" - }, - "source": [ - "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ur5sNUcZ3l-g" - }, - "outputs": [], - "source": [ - "from datasets import ClassLabel\n", - "import random\n", - "import pandas as pd\n", - "from IPython.display import display, HTML\n", - "\n", - "def show_random_elements(dataset, num_examples=10):\n", - " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", - " picks = []\n", - " for _ in range(num_examples):\n", - " pick = random.randint(0, len(dataset)-1)\n", - " while pick in picks:\n", - " pick = random.randint(0, len(dataset)-1)\n", - " picks.append(pick)\n", - " \n", - " df = pd.DataFrame(dataset[picks])\n", - " for column, typ in dataset.features.items():\n", - " if isinstance(typ, ClassLabel):\n", - " df[column] = df[column].transform(lambda i: typ.names[i])\n", - " display(HTML(df.to_html()))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CKerdF353l-o" - }, - "source": [ - "As we can see, some of the texts are a full paragraph of a Wikipedia article while others are just titles or empty lines." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JEA1ju653l-p" - }, - "source": [ - "## Causal Language modeling" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "v5GTGKZS3l-q" - }, - "source": [ - "For causal language modeling (CLM) we are going to take all the texts in our dataset and concatenate them after they are tokenized. Then we will split them in examples of a certain sequence length. This way the model will receive chunks of contiguous text that may look like:\n", - "```\n", - "part of text 1\n", - "```\n", - "or \n", - "```\n", - "end of text 1 [BOS_TOKEN] beginning of text 2\n", - "```\n", - "depending on whether they span over several of the original texts in the dataset or not. The labels will be the same as the inputs, shifted to the left.\n", - "\n", - "We will use the [`gpt2`](https://huggingface.co/gpt2) architecture for this example. You can pick any of the checkpoints listed [here](https://huggingface.co/models?filter=causal-lm) instead. For the tokenizer, you can replace the checkpoint by the one you trained yourself." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-WGBCO343l-q" - }, - "outputs": [], - "source": [ - "model_checkpoint = \"gpt2\"\n", - "tokenizer_checkpoint = \"sgugger/gpt2-like-tokenizer\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5io6fY_d3l-u" - }, - "source": [ - "To tokenize all our texts with the same vocabulary that was used when training the model, we have to download a pretrained tokenizer. This is all done by the `AutoTokenizer` class:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "iAYlS40Z3l-v" - }, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rpOiBrJ13l-y" - }, - "source": [ - "We can now call the tokenizer on all our texts. This is very simple, using the [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) method from the Datasets library. First we define a function that call the tokenizer on our texts:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M9xVAa3s3l-2" - }, - "source": [ - "Then we apply it to all the splits in our `datasets` object, using `batched=True` and 4 processes to speed up the preprocessing. We won't need the `text` column afterward, so we discard it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NVAO0H8u3l-3", - "outputId": "30d88b8a-e353-4e13-f709-8e5e06ef747b" - }, - "outputs": [], - "source": [ - "# tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8qik3J_C3l-7" - }, - "source": [ - "If we now look at an element of our datasets, we will see the text have been replaced by the `input_ids` the model will need:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "obvgcXda3l--" - }, - "source": [ - "Now for the harder part: we need to concatenate all our texts together then split the result in small chunks of a certain `block_size`. To do this, we will use the `map` method again, with the option `batched=True`. This option actually lets us change the number of examples in the datasets by returning a different number of examples than we got. This way, we can create our new samples from a batch of examples.\n", - "\n", - "First, we grab the maximum length our model was pretrained with. This might be a big too big to fit in your GPU RAM, so here we take a bit less at just 128." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DVHs5aCA3l-_" - }, - "outputs": [], - "source": [ - "# block_size = tokenizer.model_max_length\n", - "block_size = 128" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RpNfGiMw3l_A" - }, - "source": [ - "Then we write the preprocessing function that will group our texts:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LGJWXtNv3l_C" - }, - "source": [ - "First note that we duplicate the inputs for our labels. This is because the model of the 🤗 Transformers library apply the shifting to the right, so we don't need to do it manually.\n", - "\n", - "Also note that by default, the `map` method will send a batch of 1,000 examples to be treated by the preprocessing function. So here, we will drop the remainder to make the concatenated tokenized texts a multiple of `block_size` every 1,000 examples. You can adjust this behavior by passing a higher batch size (which will also be processed slower). You can also speed-up the preprocessing by using multiprocessing:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6n84V8Gc3l_G" - }, - "source": [ - "And we can check our datasets have changed: now the samples contain chunks of `block_size` contiguous tokens, potentially spanning over several of our original texts." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iEmeQ7Xm3l_H" - }, - "source": [ - "Now that the data has been cleaned, we're ready to instantiate our `Trainer`. First we create the model using the same config as our checkpoint, but initialized with random weights:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sPqQA3TT3l_I" - }, - "outputs": [], - "source": [ - "from transformers import AutoConfig, AutoModelForCausalLM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VyPQTOF_3l_J" - }, - "source": [ - "And we will needsome `TrainingArguments`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jElf8LJ33l_K" - }, - "outputs": [], - "source": [ - "from transformers import Trainer, TrainingArguments" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The last argument to setup everything so we can push the model to the [Hub](https://huggingface.co/models) regularly during training. Remove it if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally in a name that is different than the name of the repository it will be pushed, or if you want to push your model under an organization and not your name space, use the `hub_model_id` argument to set the repo name (it needs to be the full name, including your namespace: for instance `\"sgugger/gpt-finetuned-wikitext2\"` or `\"huggingface/gpt-finetuned-wikitext2\"`)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sZRbT9ui3l_N" - }, - "source": [ - "We pass along all of those to the `Trainer` class:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_dataset():\n", - " datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')\n", - " tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)\n", - " def tokenize_function(examples):\n", - " return tokenizer(examples[\"text\"])\n", - " tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=1, remove_columns=[\"text\"])\n", - " def group_texts(examples):\n", - " # Concatenate all texts.\n", - " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", - " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", - " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n", - " # customize this part to your needs.\n", - " total_length = (total_length // block_size) * block_size\n", - " # Split by chunks of max_len.\n", - " result = {\n", - " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", - " for k, t in concatenated_examples.items()\n", - " }\n", - " result[\"labels\"] = result[\"input_ids\"].copy()\n", - " return result\n", - " lm_datasets = tokenized_datasets.map(\n", - " group_texts,\n", - " batched=True,\n", - " batch_size=1000,\n", - " num_proc=1,\n", - " )\n", - " return lm_datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import ray.data\n", - "lm_dataset = get_dataset()\n", - "ray_train = ray.data.from_arrow(lm_dataset[\"train\"]._data.table)\n", - "ray_validation = ray.data.from_arrow(lm_dataset[\"validation\"]._data.table)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from ray.ml.train.integrations.huggingface import HuggingFaceTrainer\n", - "\n", - "def train_function(train_dataset, eval_dataset = None, **config):\n", - " model_config = AutoConfig.from_pretrained(model_checkpoint)\n", - " model = AutoModelForCausalLM.from_config(model_config)\n", - " print(\"initializing training_args\")\n", - " training_args = TrainingArguments(\n", - " f\"{model_checkpoint}-wikitext2\",\n", - " evaluation_strategy = \"epoch\",\n", - " num_train_epochs=2,\n", - " learning_rate=2e-5,\n", - " weight_decay=0.01,\n", - " disable_tqdm=True,\n", - " save_strategy=\"epoch\",\n", - " )\n", - " print(\"initializing\")\n", - " trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " eval_dataset=eval_dataset,\n", - " )\n", - " print(\"trainer initialized\")\n", - " return trainer\n", - "\n", - "trainer = HuggingFaceTrainer(\n", - " trainer_init_per_worker=train_function,\n", - " scaling_config={\"num_workers\": 2, \"use_gpu\": False},\n", - " datasets={\"train\": ray_train.limit(16), \"evaluation\": ray_validation.limit(8)},\n", - ")\n", - "trainer.fit()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6Vvz34Td3l_O" - }, - "source": [ - "And we can train our model:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3APq-vUc3l_R" - }, - "source": [ - "Once the training is completed, we can evaluate our model and get its perplexity on the validation set like this:" - ] - } - ], - "metadata": { - "colab": { - "name": "Train a language model", - "provenance": [] - }, - "interpreter": { - "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f" - }, - "kernelspec": { - "display_name": "Python 3.8.10 ('venv': venv)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} From ec83f909c610767f920306d9df74d5cf35ddce7c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 12 Apr 2022 23:26:00 +0000 Subject: [PATCH 12/75] Doc --- python/ray/ml/utils/torch_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index ad209967d708..e0f4647dcd0b 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -49,8 +49,10 @@ def convert_pandas_to_torch_tensor( columns = columns if columns else [] def tensorize(vals, dtype): + """This recursive function allows to convert pyarrow List dtypes + to multi-dimensional tensors.""" if vals.dtype == np.object: - # TODO: clarify if this should be cat or stack + # TODO(yard1): clarify if this should be cat or stack return torch.stack([tensorize(x, dtype) for x in vals]) return torch.as_tensor(vals, dtype=dtype) From 66cef4ae66006a4980e24ae1868817a6415c542c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 12 Apr 2022 23:30:50 +0000 Subject: [PATCH 13/75] Doc --- python/ray/ml/utils/torch_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index e0f4647dcd0b..da1a95845094 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -29,6 +29,8 @@ def convert_pandas_to_torch_tensor( column_dtype (Optional[Union[torch.dtype, List[torch.dtype]): The torch dtype to use for the tensor. If set to None, then automatically infer the dtype. + unsqueeze: Whether to unsqueeze (reshape to a 2d, 1 column tensor) + the columns or not. Returns: Either a torch tensor of size (N, len(columns)) where N is the From 347ba0b02ff9913224b8394c5c55e6ea6f3a9041 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 10:34:27 +0000 Subject: [PATCH 14/75] Better example --- .../ray/ml/examples/huggingface/__init__.py | 0 ...ngface_basic_language_modelling_example.py | 187 ++++++++++++++++++ python/ray/ml/examples/huggingface_example.py | 94 --------- 3 files changed, 187 insertions(+), 94 deletions(-) create mode 100644 python/ray/ml/examples/huggingface/__init__.py create mode 100644 python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py delete mode 100644 python/ray/ml/examples/huggingface_example.py diff --git a/python/ray/ml/examples/huggingface/__init__.py b/python/ray/ml/examples/huggingface/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py new file mode 100644 index 000000000000..46a326dd140c --- /dev/null +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py @@ -0,0 +1,187 @@ +# Based on +# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb + +import argparse +import tempfile +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoModelForCausalLM, + Trainer, + TrainingArguments, +) + +import ray +import ray.data +from ray.ml.train.integrations.huggingface import HuggingFaceTrainer + +def main( + model_checkpoint="gpt2", + tokenizer_checkpoint="sgugger/gpt2-like-tokenizer", + dataset_name="wikitext-2-raw-v1", + dataset_path="wikitext", + num_epochs=5, + num_workers=2, + use_gpu=False, + smoke_test=False, +): + # block_size = tokenizer.model_max_length + block_size = 128 + + # Run this as a remote function to avoid downloading on the driver + @ray.remote + def get_dataset(): + datasets = load_dataset(dataset_path, dataset_name) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + + def tokenize_function(examples): + return tokenizer(examples["text"]) + + tokenized_datasets = datasets.map( + tokenize_function, batched=True, num_proc=1, remove_columns=["text"] + ) + + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported + # it instead of this drop, you can + # customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + batch_size=1000, + num_proc=1, + ) + ray_train = ray.data.from_arrow(lm_datasets["train"]._data.table) + ray_validation = ray.data.from_arrow(lm_datasets["validation"]._data.table) + return ray_train, ray_validation + + ray_train, ray_validation = ray.get(get_dataset.remote()) + + def train_function(train_dataset, eval_dataset=None, **config): + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + print("Initializing TrainingArguments...") + # The checkpoints will be moved to Ray Tune results + # directory automatically + training_dir = tempfile.mkdtemp() + training_args = TrainingArguments( + training_dir, + evaluation_strategy="epoch", + num_train_epochs=num_epochs, + learning_rate=2e-5, + weight_decay=0.01, + disable_tqdm=True, + save_strategy="epoch", + ) + print("Initializing Trainer...") + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + print("Trainer initialized! Starting training...") + return trainer + + if smoke_test: + ray_train = ray_train.limit(16) + ray_validation = ray_validation.limit(8) + + trainer = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + scaling_config={"num_workers": num_workers, "use_gpu": use_gpu}, + datasets={"train": ray_train, "evaluation": ray_validation}, + ) + results = trainer.fit() + print(results.metrics) + + +if __name__ == "__main__": + # Training settings + parser = argparse.ArgumentParser( + description="Language modelling from scratch with HuggingFaceTrainer Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-checkpoint", + type=str, + default="gpt2", + help="Model checkpoint name to download from HF hub", + ) + parser.add_argument( + "--tokenizer-checkpoint", + type=str, + default="sgugger/gpt2-like-tokenizer", + help="Tokenizer checkpoint name to download from HF hub", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="wikitext-2-raw-v1", + help="Dataset name to download from HF hub", + ) + parser.add_argument( + "--dataset-path", + type=str, + default="wikitext", + help="Path on the head node to save the dataset to", + ) + parser.add_argument( + "--num-epochs", + type=int, + default=5, + help="number of epochs to train (default: 5)", + ) + parser.add_argument( + "--use-gpu", action="store_true", default=False, help="enables CUDA training" + ) + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="Number of Ray workers to use for training.", + ) + parser.add_argument( + "--smoke-test", + action="store_true", + default=False, + help="Limit dataset size to finish quickly for testing", + ) + parser.add_argument( + "--address", + required=False, + type=str, + default=None, + help="Address of Ray cluster.", + ) + + args = parser.parse_args() + + if args.address: + ray.init(args.address) + else: + ray.init() + + main( + model_checkpoint=args.model_checkpoint, + tokenizer_checkpoint=args.tokenizer_checkpoint, + dataset_name=args.dataset_name, + dataset_path=args.dataset_path, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + use_gpu=args.use_gpu, + smoke_test=args.smoke_test, + ) diff --git a/python/ray/ml/examples/huggingface_example.py b/python/ray/ml/examples/huggingface_example.py deleted file mode 100644 index 9dbe22e33929..000000000000 --- a/python/ray/ml/examples/huggingface_example.py +++ /dev/null @@ -1,94 +0,0 @@ -# Based on -# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb - -from datasets import load_dataset -from transformers import ( - AutoTokenizer, - AutoConfig, - AutoModelForCausalLM, - Trainer, - TrainingArguments, -) - -import ray -import ray.data -from ray.ml.train.integrations.huggingface import HuggingFaceTrainer - -model_checkpoint = "gpt2" -tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" - -# block_size = tokenizer.model_max_length -block_size = 128 - - -def get_dataset(): - datasets = load_dataset("wikitext", "wikitext-2-raw-v1") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) - - def tokenize_function(examples): - return tokenizer(examples["text"]) - - tokenized_datasets = datasets.map( - tokenize_function, batched=True, num_proc=1, remove_columns=["text"] - ) - - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - batch_size=1000, - num_proc=1, - ) - return lm_datasets - - -lm_dataset = get_dataset() -ray_train = ray.data.from_arrow(lm_dataset["train"]._data.table) -ray_validation = ray.data.from_arrow(lm_dataset["validation"]._data.table) - - -def train_function(train_dataset, eval_dataset=None, **config): - model_config = AutoConfig.from_pretrained(model_checkpoint) - model = AutoModelForCausalLM.from_config(model_config) - print("Initializing TrainingArguments...") - training_args = TrainingArguments( - f"{model_checkpoint}-wikitext2", - evaluation_strategy="epoch", - num_train_epochs=2, - learning_rate=2e-5, - weight_decay=0.01, - disable_tqdm=True, - save_strategy="epoch", - ) - print("Initializing Trainer...") - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - ) - print("Trainer initialized! Starting training...") - return trainer - - -trainer = HuggingFaceTrainer( - trainer_init_per_worker=train_function, - scaling_config={"num_workers": 2, "use_gpu": False}, - datasets={"train": ray_train.limit(16), "evaluation": ray_validation.limit(8)}, -) -results = trainer.fit() -print(results.metrics) From c58ec5fea1dfa8be061f89a625d26b23af8eb9f4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 10:38:46 +0000 Subject: [PATCH 15/75] Lint --- .../huggingface/huggingface_basic_language_modelling_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py index 46a326dd140c..a33c2ddead4e 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py @@ -16,6 +16,7 @@ import ray.data from ray.ml.train.integrations.huggingface import HuggingFaceTrainer + def main( model_checkpoint="gpt2", tokenizer_checkpoint="sgugger/gpt2-like-tokenizer", From 5d252d84e2831951ecb4e0f3274449fbd2b84388 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 10:39:57 +0000 Subject: [PATCH 16/75] Typo fix --- python/ray/ml/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index da1a95845094..66eb64fb68c1 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -92,7 +92,7 @@ def load_torch_model( saved_model: Union[torch.nn.Module, Dict], model_definition: Optional[torch.nn.Module] = None, ) -> torch.nn.Module: - """Loads a PyTorch model from the provided``saved_model``. + """Loads a PyTorch model from the provided ``saved_model``. If ``saved_model`` is a torch Module, then return it directly. If ``saved_model`` is a torch state dict, then load it in the ``model_definition`` and return the loaded From 0fc3a045916a477ef326a40b25834fdb939e7dc2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 18:27:32 +0000 Subject: [PATCH 17/75] _checkpoint_manager_cls --- python/ray/ml/train/data_parallel_trainer.py | 54 +++++------ .../huggingface/huggingface_trainer.py | 93 ++++++++++--------- 2 files changed, 77 insertions(+), 70 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 18cda0ca1e18..c199fb276e76 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -1,7 +1,7 @@ import inspect import logging from pathlib import Path -from typing import Dict, Callable, Optional, Union +from typing import Callable, Dict, Optional, Type, Union import ray from ray import tune @@ -19,6 +19,27 @@ logger = logging.getLogger(__name__) +# TODO(team-ml): Refactor checkpoint management along with Tune. +class _DataParallelCheckpointManager(TuneCheckpointManager): + def on_init(self, preprocessor: Preprocessor): + self.preprocessor = preprocessor + super(_DataParallelCheckpointManager, self).on_init() + + def write_checkpoint(self, checkpoint: Dict): + self.add_tune_checkpoint_id(checkpoint) + + # Add the preprocessor to the checkpoint. + checkpoint[PREPROCESSOR_KEY] = self.preprocessor + + checkpoint_obj = Checkpoint.from_dict(checkpoint) + # If inside a Tune Trainable, then checkpoint with Tune. + with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: + checkpoint_obj.to_directory(path=checkpoint_dir) + + @property + def latest_checkpoint_dir(self) -> Optional[Path]: + raise NotImplementedError + @DeveloperAPI class DataParallelTrainer(Trainer): @@ -181,6 +202,10 @@ def __init__(self, train_loop_per_worker, my_backend_config: resume_from_checkpoint: A checkpoint to resume training from. """ + _checkpoint_manager_cls: Type[ + TuneCheckpointManager + ] = _DataParallelCheckpointManager + def __init__( self, *, @@ -249,9 +274,6 @@ def _validate_train_loop_per_worker( f"but it accepts {num_params} arguments instead." ) - def _get_checkpoint_manager(self) -> TuneCheckpointManager: - return _DataParallelCheckpointManager() - def training_loop(self) -> None: scaling_config_dataclass = ScalingConfigDataClass(**self.scaling_config) @@ -274,7 +296,7 @@ def training_loop(self) -> None: max_retries=0, ) - checkpoint_manager = self._get_checkpoint_manager() + checkpoint_manager = self._checkpoint_manager_cls() checkpoint_manager.on_init(preprocessor=self.preprocessor) # Start the remote actors. @@ -319,25 +341,3 @@ def training_loop(self) -> None: # Shutdown workers. backend_executor.shutdown() - - -# TODO(team-ml): Refactor checkpoint management along with Tune. -class _DataParallelCheckpointManager(TuneCheckpointManager): - def on_init(self, preprocessor: Preprocessor): - self.preprocessor = preprocessor - super(_DataParallelCheckpointManager, self).on_init() - - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) - - # Add the preprocessor to the checkpoint. - checkpoint[PREPROCESSOR_KEY] = self.preprocessor - - checkpoint_obj = Checkpoint.from_dict(checkpoint) - # If inside a Tune Trainable, then checkpoint with Tune. - with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: - checkpoint_obj.to_directory(path=checkpoint_dir) - - @property - def latest_checkpoint_dir(self) -> Optional[Path]: - raise NotImplementedError diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 1fcdf4f9304a..7f1d2a4131a1 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -45,6 +45,52 @@ CHECKPOINT_PATH_ON_NODE_KEY = "checkpoint_path_on_node" +# TODO(team-ml): Refactor checkpoint management along with Tune. +class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): + """Same as _DataParallelCheckpointManager, but syncs the dir instead + of serializing it.""" + + def add_tune_checkpoint_id(self, path: str): + # Store the checkpoint_id in the file so that the Tune trial can be + # resumed after failure or cancellation. + with open(Path(path).joinpath(TUNE_CHECKPOINT_ID), "w") as f: + f.write(str(self._latest_checkpoint_id)) + + def on_init(self, preprocessor: Preprocessor): + self.preprocessor = preprocessor + super(_DataParallelSyncingCheckpointManager, self).on_init() + + def write_checkpoint(self, checkpoint: Dict): + # If inside a Tune Trainable, then checkpoint with Tune. + with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: + source_ip = checkpoint[NODE_IP_KEY] + source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + if source_ip == target_ip: + # Move contents of source_path, but not source_path + # itself. shutil.move is already recursive. + for path in Path(source_path).iterdir(): + shutil.move(str(path.absolute()), checkpoint_dir) + shutil.rmtree(source_path, ignore_errors=True) + else: + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=checkpoint_dir, + return_futures=False, + max_size_bytes=None, + ) + delete_on_node(node_ip=source_ip, path=source_path) + with open(Path(checkpoint_dir).joinpath(PREPROCESSOR_KEY), "wb") as f: + cpickle.dump(self.preprocessor, f) + self.add_tune_checkpoint_id(checkpoint_dir) + + @property + def latest_checkpoint_dir(self) -> Optional[Path]: + raise NotImplementedError + + @PublicAPI(stability="alpha") class HuggingFaceTrainer(TorchTrainer): """A Trainer for data parallel HuggingFace Transformers on PyTorch training. @@ -186,6 +232,10 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): resume_from_checkpoint: A checkpoint to resume training from. """ + _checkpoint_manager_cls: Type[ + TuneCheckpointManager + ] = _DataParallelSyncingCheckpointManager + def __init__( self, *, @@ -277,9 +327,6 @@ def _validate_train_loop_per_worker( # trainer_init_per_worker instead. pass - def _get_checkpoint_manager(self) -> TuneCheckpointManager: - return _DataParallelSyncingCheckpointManager() - def _create_train_func( self, trainer_init_per_worker: Callable[ @@ -522,43 +569,3 @@ def _process_dataset_for_hf( if count: torch_dataset = _HFIterableDatasetWithLen(torch_dataset, count) return torch_dataset - - -# TODO(team-ml): Refactor checkpoint management along with Tune. -class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): - """Same as _DataParallelCheckpointManager, but syncs the dir instead - of serializing it.""" - - def add_tune_checkpoint_id(self, path: str): - # Store the checkpoint_id in the file so that the Tune trial can be - # resumed after failure or cancellation. - with open(Path(path).joinpath(TUNE_CHECKPOINT_ID), "w") as f: - f.write(str(self._latest_checkpoint_id)) - - def on_init(self, preprocessor: Preprocessor): - self.preprocessor = preprocessor - super(_DataParallelSyncingCheckpointManager, self).on_init() - - def write_checkpoint(self, checkpoint: Dict): - # If inside a Tune Trainable, then checkpoint with Tune. - with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: - source_ip = checkpoint[NODE_IP_KEY] - source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] - target_ip = get_node_ip_address() - sync_dir_between_nodes( - source_ip=source_ip, - source_path=source_path, - target_ip=target_ip, - target_path=checkpoint_dir, - return_futures=False, - max_size_bytes=None, - ) - if source_ip == target_ip: - shutil.rmtree(source_path, ignore_errors=True) - with open(Path(checkpoint_dir).joinpath(PREPROCESSOR_KEY), "wb") as f: - cpickle.dump(self.preprocessor, f) - self.add_tune_checkpoint_id(checkpoint_dir) - - @property - def latest_checkpoint_dir(self) -> Optional[Path]: - raise NotImplementedError From 28ae105de204cad23b0fea0b1ff69be3e1d2d1f1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 18:27:53 +0000 Subject: [PATCH 18/75] Improve checkpointing --- .../huggingface/huggingface_trainer.py | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 7f1d2a4131a1..f578bc309ec3 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -28,7 +28,7 @@ from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint from ray.util import PublicAPI, get_node_ip_address -from ray.tune.utils.file_transfer import sync_dir_between_nodes +from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY, PREPROCESSOR_KEY # This trainer uses a special checkpoint syncing logic. @@ -288,35 +288,43 @@ def __new__(cls, *args, **kwargs): # This is required to ensure that the dir syncing logic # is used instead of serializing several gigabytes of data # when a Checkpoint is sent to a Ray Actor. - if ( - "resume_from_checkpoint" in kwargs - and kwargs["resume_from_checkpoint"]._local_path - ): - checkpoint_path = kwargs["resume_from_checkpoint"].to_directory() - - # Load checkpoint from path. - checkpoint_path = Path(checkpoint_path).expanduser().absolute() - if not checkpoint_path.exists(): - raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") - with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: - tune_checkpoint_id = int(f.read()) - - kwargs["resume_from_checkpoint"] = Checkpoint.from_dict( - { - NODE_IP_KEY: get_node_ip_address(), - CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), - TUNE_CHECKPOINT_ID: tune_checkpoint_id, - } - ) + if "resume_from_checkpoint" in kwargs: + resume_from_checkpoint: Checkpoint = kwargs["resume_from_checkpoint"] + ( + checkpoint_type, + checkpoint_path, + ) = resume_from_checkpoint.get_internal_representation() + if checkpoint_type != "local_path": + raise ValueError( + "Unexpected checkpoint type in `resume_from_checkpoint`. " + f"Expected 'local_path', got '{checkpoint_type}'" + ) + if checkpoint_path: + # Load checkpoint from path. + checkpoint_path = Path(checkpoint_path).expanduser().absolute() + if not checkpoint_path.exists(): + raise ValueError( + f"Checkpoint path {checkpoint_path} does not exist." + ) + with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: + tune_checkpoint_id = int(f.read()) + + kwargs["resume_from_checkpoint"] = Checkpoint.from_dict( + { + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + TUNE_CHECKPOINT_ID: tune_checkpoint_id, + } + ) return super(HuggingFaceTrainer, cls).__new__(cls, *args, **kwargs) def _validate_trainer_init_per_worker( self, trainer_init_per_worker: Callable, fn_name: str ) -> None: num_params = len(inspect.signature(trainer_init_per_worker).parameters) - if num_params != 3: + if num_params < 3: raise ValueError( - f"{fn_name} should take in 3 arguments, " + f"{fn_name} should take in at least 3 arguments, " f"but it accepts {num_params} arguments instead." ) From 3263bd4179583901e262ee6515c227cc171d40c2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 18:34:38 +0000 Subject: [PATCH 19/75] cleanup checkpoint --- .../integrations/huggingface/huggingface_trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index f578bc309ec3..65298cadf0d4 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -6,8 +6,6 @@ import inspect import gc from unittest.mock import patch -from ray.train.checkpoint import TuneCheckpointManager -from ray.train.constants import TUNE_CHECKPOINT_ID import torch import transformers.trainer @@ -28,6 +26,8 @@ from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint from ray.util import PublicAPI, get_node_ip_address +from ray.train.checkpoint import TuneCheckpointManager +from ray.train.constants import TUNE_CHECKPOINT_ID from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY, PREPROCESSOR_KEY @@ -476,6 +476,7 @@ def _save(self, *args, **kwargs): checkpoint = train.load_checkpoint() checkpoint_path = None + remove_checkpoint_path = False if checkpoint: source_ip = checkpoint[NODE_IP_KEY] source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] @@ -484,7 +485,10 @@ def _save(self, *args, **kwargs): checkpoint_path = source_path else: # TODO(yard1): Confirm if tempdir is the right approach here. - checkpoint_path = tempfile.mkdtemp() + checkpoint_path = tempfile.mkdtemp( + suffix=Path(trainer.args.output_dir).name + ) + remove_checkpoint_path = True sync_dir_between_nodes( source_ip=source_ip, source_path=source_path, @@ -494,6 +498,8 @@ def _save(self, *args, **kwargs): max_size_bytes=None, ) trainer.train(resume_from_checkpoint=checkpoint_path) + if remove_checkpoint_path: + shutil.rmtree(checkpoint_path, ignore_errors=True) return train_loop_per_worker From 35cc359ef04643ee26ebb72af8baa0f1ff8e07a5 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 18:36:10 +0000 Subject: [PATCH 20/75] Sort imports --- .../huggingface/huggingface_trainer.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 65298cadf0d4..f5179a67f92e 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,35 +1,36 @@ -from pathlib import Path +import gc +import inspect +import os import shutil import tempfile -from typing import Any, Callable, Generator, Iterator, List, Optional, Dict, Type -import os -import inspect -import gc +from pathlib import Path +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type from unittest.mock import patch import torch import transformers.trainer -from transformers.training_args import TrainingArguments +import ray.cloudpickle as cpickle +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import IterableDataset from transformers.trainer_callback import TrainerCallback -from torch.utils.data import Dataset as TorchDataset, IterableDataset, DataLoader - +from transformers.training_args import TrainingArguments -from ray import train -from ray import tune -import ray.cloudpickle as cpickle +from ray import train, tune +from ray.util import PublicAPI, get_node_ip_address from ray.data.dataset import Dataset -from ray.train.torch import TorchConfig -from ray.train.session import get_session -from ray.ml.trainer import GenDataset -from ray.ml.train.integrations.torch import TorchTrainer -from ray.ml.config import ScalingConfig, RunConfig -from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint -from ray.util import PublicAPI, get_node_ip_address +from ray.ml.config import RunConfig, ScalingConfig +from ray.ml.constants import EVALUATION_DATASET_KEY, PREPROCESSOR_KEY, TRAIN_DATASET_KEY +from ray.ml.preprocessor import Preprocessor +from ray.ml.train.integrations.torch import TorchTrainer +from ray.ml.trainer import GenDataset from ray.train.checkpoint import TuneCheckpointManager from ray.train.constants import TUNE_CHECKPOINT_ID +from ray.train.session import get_session +from ray.train.torch import TorchConfig from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes -from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY, PREPROCESSOR_KEY + # This trainer uses a special checkpoint syncing logic. # Because HF checkpoints are very large dirs (at least several GBs), From 7920c22a12c2b787306a942bf87cae7aa2ceb47f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 18:45:39 +0000 Subject: [PATCH 21/75] Remove monkey patching for callbacks --- .../huggingface/huggingface_trainer.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index f5179a67f92e..2bdb3f5e799f 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -5,7 +5,6 @@ import tempfile from pathlib import Path from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type -from unittest.mock import patch import torch import transformers.trainer @@ -389,16 +388,10 @@ def train_loop_per_worker(config): else: eval_torch_dataset = None - # ensure no HF logging callbacks are added - # aside from doubling functionality with our callbacks, - # the Wandb callbacks causes training to freeze - # TODO(yard1): Automatically set `no_cuda` - with patch( - "transformers.trainer.get_reporting_integration_callbacks", lambda x: [] - ): - trainer: transformers.trainer.Trainer = trainer_init_per_worker( - train_torch_dataset, eval_torch_dataset, **config - ) + # TODO(yard1): Automatically set `no_cuda` somehow + trainer: transformers.trainer.Trainer = trainer_init_per_worker( + train_torch_dataset, eval_torch_dataset, **config + ) if trainer.args.local_rank != train.local_rank(): raise RuntimeError( @@ -471,6 +464,18 @@ def _save(self, *args, **kwargs): trainer.args.__class__ = RayTrainingArguments trainer.args.no_cuda = not torch.cuda.is_available() trainer.args.save_on_each_node = True + + # ensure no HF logging callbacks are added + # aside from doubling functionality with our callbacks, + # the Wandb callbacks causes training to freeze + integration_callbacks = ( + transformers.trainer.get_reporting_integration_callbacks( + trainer.args.report_to + ) + ) + for callback in integration_callbacks: + trainer.pop_callback(callback) + trainer.add_callback(_TrainReportCallback) if trainer.args.device.type == "cuda": torch.cuda.set_device(trainer.args.device) From 552d0fe59716c257275a410085252eccec99a2b8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 18:53:15 +0000 Subject: [PATCH 22/75] Move to utils --- python/ray/ml/train/data_parallel_trainer.py | 1 + .../huggingface/huggingface_trainer.py | 108 +++--------------- python/ray/ml/utils/huggingface_utils.py | 94 +++++++++++++++ 3 files changed, 108 insertions(+), 95 deletions(-) create mode 100644 python/ray/ml/utils/huggingface_utils.py diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index c199fb276e76..688d2408dafc 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + # TODO(team-ml): Refactor checkpoint management along with Tune. class _DataParallelCheckpointManager(TuneCheckpointManager): def on_init(self, preprocessor: Preprocessor): diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 2bdb3f5e799f..b1b4260c8897 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,29 +1,31 @@ -import gc import inspect import os import shutil import tempfile from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type +from typing import Any, Callable, Dict, Optional, Type import torch import transformers.trainer import ray.cloudpickle as cpickle -from torch.utils.data import DataLoader -from torch.utils.data import Dataset as TorchDataset -from torch.utils.data import IterableDataset -from transformers.trainer_callback import TrainerCallback +from torch.utils.data import DataLoader, Dataset as TorchDataset from transformers.training_args import TrainingArguments -from ray import train, tune +from ray import train +from ray import tune from ray.util import PublicAPI, get_node_ip_address -from ray.data.dataset import Dataset from ray.ml.checkpoint import Checkpoint from ray.ml.config import RunConfig, ScalingConfig from ray.ml.constants import EVALUATION_DATASET_KEY, PREPROCESSOR_KEY, TRAIN_DATASET_KEY from ray.ml.preprocessor import Preprocessor from ray.ml.train.integrations.torch import TorchTrainer from ray.ml.trainer import GenDataset +from ray.ml.utils.huggingface_utils import ( + CHECKPOINT_PATH_ON_NODE_KEY, + NODE_IP_KEY, + process_dataset_for_hf, + TrainReportCallback, +) from ray.train.checkpoint import TuneCheckpointManager from ray.train.constants import TUNE_CHECKPOINT_ID from ray.train.session import get_session @@ -41,9 +43,6 @@ # made generic. # TODO(ml-team): Make dir syncing checkpoint logic generic. -NODE_IP_KEY = "node_ip" -CHECKPOINT_PATH_ON_NODE_KEY = "checkpoint_path_on_node" - # TODO(team-ml): Refactor checkpoint management along with Tune. class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): @@ -377,12 +376,12 @@ def train_loop_per_worker(config): # desired size inside transformers.Trainer. Possible optimization # in the future batch_size = 1 - train_torch_dataset = _process_dataset_for_hf( + train_torch_dataset = process_dataset_for_hf( train_dataset, feature_columns, batch_size=batch_size ) if eval_dataset: - eval_torch_dataset = _process_dataset_for_hf( + eval_torch_dataset = process_dataset_for_hf( eval_dataset, feature_columns, batch_size=batch_size ) else: @@ -476,7 +475,7 @@ def _save(self, *args, **kwargs): for callback in integration_callbacks: trainer.pop_callback(callback) - trainer.add_callback(_TrainReportCallback) + trainer.add_callback(TrainReportCallback) if trainer.args.device.type == "cuda": torch.cuda.set_device(trainer.args.device) @@ -508,84 +507,3 @@ def _save(self, *args, **kwargs): shutil.rmtree(checkpoint_path, ignore_errors=True) return train_loop_per_worker - - -class _HFIterableDatasetWithLen(IterableDataset): - """Special Torch IterableDataset with preset length.""" - - def __init__(self, generator: Generator, length: int): - self.generator = generator - self._len = length - - def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - it = self.generator - for x in it: - # HF-specific format - yield {**x[0], "labels": x[1]} - - def __len__(self): - return self._len - - -class _TrainReportCallback(TrainerCallback): - """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" - - def __init__(self) -> None: - # HF first logs metrics, and then checkpoints. With Ray AIR, we need the - # opposite. Therefore, if we detect that a checkpoint will be created, - # we delay the train.report call after the checkpoint is reported - # to Ray Train. - self.delayed_report = None - # Avoid double reporting at the end. - # TODO(yard1): Train statistics are only reported at the end. Combine - # the second to last report and the last report somehow. We want - # steps/epochs to match the training iteration. - self.last_step = None - super().__init__() - - def on_log(self, args, state, control, model=None, logs=None, **kwargs): - if state.global_step == self.last_step: - return - self.last_step = state.global_step - report = {**logs, "step": state.global_step, "epoch": state.epoch} - if control.should_save: - self.delayed_report = report - else: - train.report(**report) - - def on_save(self, args, state, control, **kwargs): - checkpoint_path = Path( - transformers.trainer.get_last_checkpoint(args.output_dir) - ).absolute() - if checkpoint_path: - train.save_checkpoint( - **{ - NODE_IP_KEY: get_node_ip_address(), - CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), - } - ) - if self.delayed_report: - train.report(**self.delayed_report) - self.delayed_report = None - gc.collect() - - -def _process_dataset_for_hf( - dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 -) -> IterableDataset: - """Converts a Ray Dataset into a HF-compatible Torch Dataset.""" - torch_dataset = dataset.to_torch( - batch_size=batch_size, - feature_columns=feature_columns, - label_column="labels", - unsqueeze_label_tensor=False, - unsqueeze_feature_tensors=False, - ) - try: - count = dataset.count() - except ValueError: - # pipeline case - count = None - if count: - torch_dataset = _HFIterableDatasetWithLen(torch_dataset, count) - return torch_dataset diff --git a/python/ray/ml/utils/huggingface_utils.py b/python/ray/ml/utils/huggingface_utils.py new file mode 100644 index 000000000000..6df58e3a499f --- /dev/null +++ b/python/ray/ml/utils/huggingface_utils.py @@ -0,0 +1,94 @@ +from pathlib import Path +from typing import Dict, Generator, Iterator, List + +import torch +import transformers.trainer +from torch.utils.data import IterableDataset +from transformers.trainer_callback import TrainerCallback + +from ray import train +from ray.util import get_node_ip_address +from ray.data.dataset import Dataset + +CHECKPOINT_PATH_ON_NODE_KEY = "checkpoint_path_on_node" +NODE_IP_KEY = "node_ip" + + +class HFIterableDatasetWithLen(IterableDataset): + """Special Torch IterableDataset with preset length.""" + + def __init__(self, generator: Generator, length: int): + self.generator = generator + self._len = length + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + it = self.generator + for x in it: + # HF-specific format + yield {**x[0], "labels": x[1]} + + def __len__(self): + return self._len + + +def process_dataset_for_hf( + dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 +) -> IterableDataset: + """Converts a Ray Dataset into a HF-compatible Torch Dataset.""" + torch_dataset = dataset.to_torch( + batch_size=batch_size, + feature_columns=feature_columns, + label_column="labels", + unsqueeze_label_tensor=False, + unsqueeze_feature_tensors=False, + ) + try: + count = dataset.count() + except ValueError: + # pipeline case + count = None + if count: + torch_dataset = HFIterableDatasetWithLen(torch_dataset, count) + return torch_dataset + + +class TrainReportCallback(TrainerCallback): + """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" + + def __init__(self) -> None: + # HF first logs metrics, and then checkpoints. With Ray AIR, we need the + # opposite. Therefore, if we detect that a checkpoint will be created, + # we delay the train.report call after the checkpoint is reported + # to Ray Train. + self.delayed_report = None + # Avoid double reporting at the end. + # TODO(yard1): Train statistics are only reported at the end. Combine + # the second to last report and the last report somehow. We want + # steps/epochs to match the training iteration. + self.last_step = None + super().__init__() + + def on_log(self, args, state, control, model=None, logs=None, **kwargs): + if state.global_step == self.last_step: + return + self.last_step = state.global_step + report = {**logs, "step": state.global_step, "epoch": state.epoch} + if control.should_save: + self.delayed_report = report + else: + train.report(**report) + + def on_save(self, args, state, control, **kwargs): + checkpoint_path = Path( + transformers.trainer.get_last_checkpoint(args.output_dir) + ).absolute() + if checkpoint_path: + train.save_checkpoint( + **{ + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + } + ) + if self.delayed_report: + train.report(**self.delayed_report) + self.delayed_report = None From ac2bd18db069a8ab99f80c37e7da54bf8dd9ed73 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 20:34:43 +0000 Subject: [PATCH 23/75] Lint fix, check transformers version --- doc/source/custom_directives.py | 2 ++ .../huggingface_basic_language_modelling_example.py | 4 ++++ .../integrations/huggingface/huggingface_trainer.py | 12 +++++++++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index aed1fb996054..f5c47d0ee665 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -92,6 +92,7 @@ def update_context(app, pagename, templatename, context, doctree): "blist", "ConfigSpace", "dask.distributed", + "datasets", "gym", "gym.spaces", "horovod", @@ -129,6 +130,7 @@ def update_context(app, pagename, templatename, context, doctree): "tensorflow", "tensorflow.contrib", "tensorflow.contrib.all_reduce", + "transformers", "tree", "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py index a33c2ddead4e..0c947ef62061 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py @@ -12,6 +12,8 @@ TrainingArguments, ) +import torch + import ray import ray.data from ray.ml.train.integrations.huggingface import HuggingFaceTrainer @@ -86,6 +88,8 @@ def train_function(train_dataset, eval_dataset=None, **config): weight_decay=0.01, disable_tqdm=True, save_strategy="epoch", + # Required to avoid an exception + no_cuda=not torch.cuda.is_available(), ) print("Initializing Trainer...") trainer = Trainer( diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index b1b4260c8897..ed65eb235bf2 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -2,10 +2,12 @@ import os import shutil import tempfile +from distutils.version import LooseVersion from pathlib import Path from typing import Any, Callable, Dict, Optional, Type import torch +import transformers import transformers.trainer import ray.cloudpickle as cpickle from torch.utils.data import DataLoader, Dataset as TorchDataset @@ -32,7 +34,6 @@ from ray.train.torch import TorchConfig from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes - # This trainer uses a special checkpoint syncing logic. # Because HF checkpoints are very large dirs (at least several GBs), # we use directory checkpoints that are synced between nodes when @@ -250,6 +251,15 @@ def __init__( resume_from_checkpoint: Optional[Checkpoint] = None, ): + # Functionality required for HuggingFaceTrainer only added in this + # version + if LooseVersion(transformers.__version__) < LooseVersion("4.18.0"): + raise RuntimeError( + "HuggingFaceTrainer requires transformers>=4.18.0, but you " + f"have {transformers.__version__} which is incompatible. " + "Update on all nodes with `pip install -U 'transformers>=4.18.0'`." + ) + self._validate_trainer_init_per_worker( trainer_init_per_worker, "trainer_init_per_worker" ) From af271cf4f8e6632f4d802a1ffd0ae3bc123b00b7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 20:36:22 +0000 Subject: [PATCH 24/75] Bump transformers version in requirements --- python/requirements/ml/requirements_train.txt | 2 +- python/requirements/ml/requirements_tune.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/requirements/ml/requirements_train.txt b/python/requirements/ml/requirements_train.txt index 792f79f66147..6497f8fb0a94 100644 --- a/python/requirements/ml/requirements_train.txt +++ b/python/requirements/ml/requirements_train.txt @@ -7,7 +7,7 @@ tensorboardX==2.4.1 # Dependencies for Hugging Face examples: # `python/ray/train/examples/transformers/transformers_example.py` -transformers==4.10.0 +transformers==4.18.0 accelerate==0.5.1 datasets==1.14.0 sentencepiece==0.1.96 diff --git a/python/requirements/ml/requirements_tune.txt b/python/requirements/ml/requirements_tune.txt index 892e9646543a..775ea7c04bec 100644 --- a/python/requirements/ml/requirements_tune.txt +++ b/python/requirements/ml/requirements_tune.txt @@ -34,7 +34,7 @@ scikit-learn==0.24.2 scikit-optimize==0.8.1 sigopt==7.5.0 timm==0.4.5 -transformers==4.10.0 +transformers==4.18.0 wandb==0.12.5 xgboost==1.3.3 zoopt==0.4.1 From 542e0073f0612e7419af2d856cf05289c3c983bd Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 20:52:19 +0000 Subject: [PATCH 25/75] Fix checkpoint loading --- .../ml/train/integrations/huggingface/huggingface_trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index ed65eb235bf2..04c5f53fcf7c 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -304,10 +304,7 @@ def __new__(cls, *args, **kwargs): checkpoint_path, ) = resume_from_checkpoint.get_internal_representation() if checkpoint_type != "local_path": - raise ValueError( - "Unexpected checkpoint type in `resume_from_checkpoint`. " - f"Expected 'local_path', got '{checkpoint_type}'" - ) + checkpoint_path = resume_from_checkpoint.to_directory() if checkpoint_path: # Load checkpoint from path. checkpoint_path = Path(checkpoint_path).expanduser().absolute() From 04a6d1d896c58c48fa28f261b025bc71df45d72d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 20:52:41 +0000 Subject: [PATCH 26/75] Add `HuggingFacePredictor` --- .../integrations/huggingface/__init__.py | 5 + .../huggingface/huggingface_predictor.py | 157 ++++++++++++++++++ .../integrations/torch/torch_predictor.py | 3 +- 3 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 python/ray/ml/predictors/integrations/huggingface/__init__.py create mode 100644 python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py diff --git a/python/ray/ml/predictors/integrations/huggingface/__init__.py b/python/ray/ml/predictors/integrations/huggingface/__init__.py new file mode 100644 index 000000000000..617063387e1d --- /dev/null +++ b/python/ray/ml/predictors/integrations/huggingface/__init__.py @@ -0,0 +1,5 @@ +from ray.ml.predictors.integrations.huggingface.huggingface_predictor import ( + HuggingFacePredictor, +) + +__all__ = ["HuggingFacePredictor"] diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py new file mode 100644 index 000000000000..93543e475bd9 --- /dev/null +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -0,0 +1,157 @@ +import os +from typing import Optional, Type, Union, List + +import numpy as np +import pandas as pd +import torch +from transformers.modeling_utils import PreTrainedModel +from transformers.trainer import WEIGHTS_NAME + +import ray.cloudpickle as cpickle +from ray.ml.predictor import Predictor, DataBatchType +from ray.ml.preprocessor import Preprocessor +from ray.ml.checkpoint import Checkpoint +from ray.ml.utils.torch_utils import load_torch_model, convert_pandas_to_torch_tensor +from ray.ml.constants import PREPROCESSOR_KEY + + +class HuggingFacePredictor(Predictor): + """A predictor for HuggingFace Transformers PyTorch models. + + Args: + model: The Transformers model to use for predictions. + preprocessor: A preprocessor used to transform data batches prior + to prediction. + """ + + def __init__( + self, + model: Union[PreTrainedModel, torch.nn.Module], + preprocessor: Optional[Preprocessor] = None, + ): + self.model = model + self.preprocessor = preprocessor + + @classmethod + def from_checkpoint( + cls, + checkpoint: Checkpoint, + model: Union[Type[PreTrainedModel], torch.nn.Module], + **pretrained_model_kwargs, + ) -> "HuggingFacePredictor": + """Instantiate the predictor from a Checkpoint. + + The checkpoint is expected to be a result of ``HuggingFaceTrainer``. + + Args: + checkpoint: The checkpoint to load the model and + preprocessor from. It is expected to be from the result of a + ``HuggingFaceTrainer`` run. + model: Either a ``transformers.PreTrainedModel`` class, or a + PyTorch model to load the weights to. This should be the + same model used for training. + **pretrained_model_kwargs: Any kwargs to pass to the ``model.from_pretrained()`` + call. Only used if ``model`` is a ``PreTrainerModel``. + """ + ( + checkpoint_type, + checkpoint_path, + ) = checkpoint.get_internal_representation() + if checkpoint_type != "local_path": + checkpoint_path = checkpoint.to_directory() + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + if issubclass(model, PreTrainedModel): + model = PreTrainedModel.from_pretrained( + checkpoint_path, **pretrained_model_kwargs + ) + else: + state_dict = torch.load( + os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" + ) + model = load_torch_model(saved_model=state_dict, model_definition=model) + return HuggingFacePredictor(model=model, preprocessor=preprocessor) + + def predict( + self, + data: DataBatchType, + feature_columns: Optional[ + Union[List[str], List[List[str]], List[int], List[List[int]]] + ] = None, + dtype: Optional[torch.dtype] = None, + ) -> DataBatchType: + """Run inference on data batch. + + The data is converted into a torch Tensor before being inputted to + the model. + + Args: + data: A batch of input data. Either a pandas DataFrame or numpy + array. + feature_columns: The names or indices of the columns in the + data to use as features to predict on. If this arg is a + list of lists, then the data batch will be converted into a + multiple tensors which are then concatenated before feeding + into the model. This is useful for multi-input models. If + None, then use all columns in ``data``. + dtype: The torch dtype to use when creating the torch tensor. + If set to None, then automatically infer the dtype. + + Examples: + + .. code-block:: python + + import numpy as np + import torch + from ray.ml.predictors.torch import TorchPredictor + + model = torch.nn.Linear(1, 1) + predictor = TorchPredictor(model=model) + + data = np.array([[1, 2], [3, 4]]) + predictions = predictor.predict(data) + + # Only use first column as the feature + predictions = predictor.predict(data, feature_columns=[0]) + + .. code-block:: python + + import pandas as pd + import torch + from ray.ml.predictors.torch import TorchPredictor + + model = torch.nn.Linear(1, 1) + predictor = TorchPredictor(model=model) + + # Pandas dataframe. + data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + + predictions = predictor.predict(data) + + # Only use first column as the feature + predictions = predictor.predict(data, feature_columns=["A"]) + + + Returns: + DataBatchType: Prediction result. + """ + self.model.eval() + + if self.preprocessor: + data = self.preprocessor.transform_batch(data) + + if isinstance(data, np.ndarray): + # If numpy array, then convert to pandas dataframe. + data = pd.DataFrame(data) + + # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. + # Reduce conversion cost if input is in Numpy + tensor = convert_pandas_to_torch_tensor( + data, columns=feature_columns, column_dtypes=dtype, unsqueeze=False + ) + prediction = self.model(tensor).cpu().detach().numpy() + return pd.DataFrame(prediction, columns=["predictions"]) diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index a9e8eedffa39..9a7e8436571b 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -24,7 +24,6 @@ def __init__( self, model: torch.nn.Module, preprocessor: Optional[Preprocessor] = None ): self.model = model - self.model.eval() self.preprocessor = preprocessor @classmethod @@ -119,6 +118,8 @@ def predict( Returns: DataBatchType: Prediction result. """ + self.model.eval() + if self.preprocessor: data = self.preprocessor.transform_batch(data) From d4f98cfa0369c7724bb5f17b6490f71a824dfc07 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Apr 2022 21:03:12 +0000 Subject: [PATCH 27/75] Fix predictor columns --- .../huggingface/huggingface_predictor.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 93543e475bd9..ab4d30f4b44a 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -50,8 +50,9 @@ def from_checkpoint( model: Either a ``transformers.PreTrainedModel`` class, or a PyTorch model to load the weights to. This should be the same model used for training. - **pretrained_model_kwargs: Any kwargs to pass to the ``model.from_pretrained()`` - call. Only used if ``model`` is a ``PreTrainerModel``. + **pretrained_model_kwargs: Any kwargs to pass to the + ``model.from_pretrained()`` call. Only used if + ``model`` is a ``PreTrainerModel`` class. """ ( checkpoint_type, @@ -79,9 +80,7 @@ def from_checkpoint( def predict( self, data: DataBatchType, - feature_columns: Optional[ - Union[List[str], List[List[str]], List[int], List[List[int]]] - ] = None, + feature_columns: Optional[List[str]] = None, dtype: Optional[torch.dtype] = None, ) -> DataBatchType: """Run inference on data batch. @@ -93,10 +92,7 @@ def predict( data: A batch of input data. Either a pandas DataFrame or numpy array. feature_columns: The names or indices of the columns in the - data to use as features to predict on. If this arg is a - list of lists, then the data batch will be converted into a - multiple tensors which are then concatenated before feeding - into the model. This is useful for multi-input models. If + data to use as features to predict on. If None, then use all columns in ``data``. dtype: The torch dtype to use when creating the torch tensor. If set to None, then automatically infer the dtype. @@ -148,6 +144,12 @@ def predict( # If numpy array, then convert to pandas dataframe. data = pd.DataFrame(data) + if not feature_columns: + feature_columns = data.columns + + # HF-supported format + feature_columns = {column: [column] for column in feature_columns} + # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. # Reduce conversion cost if input is in Numpy tensor = convert_pandas_to_torch_tensor( From e9ba551cc1839d1a2e172683beaf6b98906b6187 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 05:06:52 +0000 Subject: [PATCH 28/75] Address some comments --- .../huggingface/huggingface_trainer.py | 62 +------------------ python/ray/ml/utils/huggingface_utils.py | 24 +++++-- 2 files changed, 19 insertions(+), 67 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 04c5f53fcf7c..94fe782b1457 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -30,7 +30,6 @@ ) from ray.train.checkpoint import TuneCheckpointManager from ray.train.constants import TUNE_CHECKPOINT_ID -from ray.train.session import get_session from ray.train.torch import TorchConfig from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes @@ -356,32 +355,10 @@ def train_loop_per_worker(config): train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) train_columns = set(train_dataset.schema(fetch_if_missing=True).names) - if "labels" not in train_columns: - raise ValueError( - "'labels' column must be present in the training dataset!" - ) - train_columns.remove("labels") - if eval_dataset: - eval_columns = set(eval_dataset.schema(fetch_if_missing=True).names) - if "labels" not in eval_columns: - raise ValueError( - "'labels' column must be present in the evaluation dataset!" - ) - eval_columns.remove("labels") - - if not eval_columns.issuperset(train_columns): - raise ValueError( - "Evaluation dataset must have a superset of the columns in " - "the training dataset. " - f"Missing columns: {list(train_columns - eval_columns)}" - ) - # HF-supported format + # HF-specific format. See transformers.Trainer._prepare_inputs feature_columns = {column: [column] for column in train_columns} - # we use batch size 1 here, as it will be converted to - # desired size inside transformers.Trainer. Possible optimization - # in the future batch_size = 1 train_torch_dataset = process_dataset_for_hf( train_dataset, feature_columns, batch_size=batch_size @@ -399,14 +376,6 @@ def train_loop_per_worker(config): train_torch_dataset, eval_torch_dataset, **config ) - if trainer.args.local_rank != train.local_rank(): - raise RuntimeError( - "local_rank set in TrainingArguments doesn't match " - "Ray Train local_rank " - f"({trainer.args.local_rank} != {train.local_rank()}. " - "Ensure you are not setting local_rank manually." - ) - base_training_arguments_class: Type[ TrainingArguments ] = trainer.args.__class__ @@ -414,16 +383,12 @@ def train_loop_per_worker(config): class RayTrainingArguments(base_training_arguments_class): @property def device(self) -> "torch.device": - if get_session() is None: - return super().device return train.torch.get_device() base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ class RayTrainer(base_trainer_class): def get_train_dataloader(self): - if get_session() is None: - return super().get_train_dataloader() return DataLoader( self.train_dataset, batch_size=self.args.per_device_train_batch_size, @@ -432,31 +397,6 @@ def get_train_dataloader(self): pin_memory=self.args.dataloader_pin_memory, ) - def _wrap_model(self, model, training=True): - if get_session() is None: - return super()._wrap_model(model, training=training) - - if not training: - return model - kwargs = {} - # same logic as in transformers.Trainer - if self.args.ddp_find_unused_parameters is not None: - kwargs[ - "find_unused_parameters" - ] = self.args.ddp_find_unused_parameters - elif isinstance(model, transformers.trainer.PreTrainedModel): - # find_unused_parameters breaks checkpointing as per - # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - kwargs[ - "find_unused_parameters" - ] = not model.is_gradient_checkpointing - else: - kwargs["find_unused_parameters"] = True - - if self.args.ddp_bucket_cap_mb is not None: - kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb - return train.torch.prepare_model(model, ddp_kwargs=kwargs) - def _save(self, *args, **kwargs): # Workaround for RayTrainingArguments not being # pickleable due to it being defined in a local diff --git a/python/ray/ml/utils/huggingface_utils.py b/python/ray/ml/utils/huggingface_utils.py index 6df58e3a499f..9964e0caeb9c 100644 --- a/python/ray/ml/utils/huggingface_utils.py +++ b/python/ray/ml/utils/huggingface_utils.py @@ -14,18 +14,25 @@ NODE_IP_KEY = "node_ip" -class HFIterableDatasetWithLen(IterableDataset): +class HFIterableDataset(IterableDataset): """Special Torch IterableDataset with preset length.""" - def __init__(self, generator: Generator, length: int): + def __init__(self, generator: Generator): self.generator = generator - self._len = length def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: it = self.generator for x in it: - # HF-specific format - yield {**x[0], "labels": x[1]} + # HF-specific format. See transformers.Trainer._prepare_inputs + yield {**x[0]} + + +class HFIterableDatasetWithLen(HFIterableDataset): + """Special Torch IterableDataset with preset length.""" + + def __init__(self, generator: Generator, length: int): + self.generator = generator + self._len = length def __len__(self): return self._len @@ -38,7 +45,7 @@ def process_dataset_for_hf( torch_dataset = dataset.to_torch( batch_size=batch_size, feature_columns=feature_columns, - label_column="labels", + label_column=None, unsqueeze_label_tensor=False, unsqueeze_feature_tensors=False, ) @@ -48,7 +55,12 @@ def process_dataset_for_hf( # pipeline case count = None if count: + # By adding length to the dataset we let HF calculate steps per epoch + # and other such values. Without length, it's not possible to use + # epochs as the evaluation strategy. torch_dataset = HFIterableDatasetWithLen(torch_dataset, count) + else: + torch_dataset = HFIterableDataset(torch_dataset) return torch_dataset From d55f84324a25719716c3585846f3093b4a12179e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 05:20:17 +0000 Subject: [PATCH 29/75] add tune checkpoint id --- .../integrations/huggingface/huggingface_trainer.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 94fe782b1457..77af9ee7c8ea 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -49,12 +49,6 @@ class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): """Same as _DataParallelCheckpointManager, but syncs the dir instead of serializing it.""" - def add_tune_checkpoint_id(self, path: str): - # Store the checkpoint_id in the file so that the Tune trial can be - # resumed after failure or cancellation. - with open(Path(path).joinpath(TUNE_CHECKPOINT_ID), "w") as f: - f.write(str(self._latest_checkpoint_id)) - def on_init(self, preprocessor: Preprocessor): self.preprocessor = preprocessor super(_DataParallelSyncingCheckpointManager, self).on_init() @@ -81,9 +75,12 @@ def write_checkpoint(self, checkpoint: Dict): max_size_bytes=None, ) delete_on_node(node_ip=source_ip, path=source_path) - with open(Path(checkpoint_dir).joinpath(PREPROCESSOR_KEY), "wb") as f: + checkpoint_dir = Path(checkpoint_dir) + with open(checkpoint_dir.joinpath(PREPROCESSOR_KEY), "wb") as f: cpickle.dump(self.preprocessor, f) - self.add_tune_checkpoint_id(checkpoint_dir) + # add tune checkpoint id + with open(checkpoint_dir.joinpath(TUNE_CHECKPOINT_ID), "w") as f: + f.write(str(self._latest_checkpoint_id)) @property def latest_checkpoint_dir(self) -> Optional[Path]: From e36bfcef0dbfe701cc94a794f4d5df6768f7ddda Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 08:38:36 +0200 Subject: [PATCH 30/75] Update python/ray/ml/utils/huggingface_utils.py --- python/ray/ml/utils/huggingface_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/ml/utils/huggingface_utils.py b/python/ray/ml/utils/huggingface_utils.py index 9964e0caeb9c..244b4338a141 100644 --- a/python/ray/ml/utils/huggingface_utils.py +++ b/python/ray/ml/utils/huggingface_utils.py @@ -15,7 +15,7 @@ class HFIterableDataset(IterableDataset): - """Special Torch IterableDataset with preset length.""" + """Special Torch IterableDataset with HF format.""" def __init__(self, generator: Generator): self.generator = generator From 0d7f14d511ac0a4812e9cb29fc0f67f9501ca3f8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 06:49:12 +0000 Subject: [PATCH 31/75] Use an external func --- .../huggingface/huggingface_trainer.py | 248 +++++++++--------- 1 file changed, 121 insertions(+), 127 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 77af9ee7c8ea..88ab03395b74 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -260,22 +260,15 @@ def __init__( trainer_init_per_worker, "trainer_init_per_worker" ) - if TRAIN_DATASET_KEY not in datasets: - raise KeyError( - f"'{TRAIN_DATASET_KEY}' key must be preset in `datasets`. " - f"Got {list(self.datasets.keys())}" - ) - if not all( - key in (TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY) for key in datasets - ): - raise KeyError( - f"Only '{TRAIN_DATASET_KEY}' and '{EVALUATION_DATASET_KEY}' " - "keys can be preset in `datasets`. " - f"Got {list(self.datasets.keys())}" + trainer_init_config = trainer_init_config.copy() if trainer_init_config else {} + if "_trainer_init_per_worker" in trainer_init_config: + raise ValueError( + "'_trainer_init_per_worker' is a reserved key in `trainer_init_config`." ) + trainer_init_config["_trainer_init_per_worker"] = trainer_init_per_worker super().__init__( - train_loop_per_worker=self._create_train_func(trainer_init_per_worker), + train_loop_per_worker=_huggingface_train_loop_per_worker, train_loop_config=trainer_init_config, torch_config=torch_config, scaling_config=scaling_config, @@ -330,124 +323,125 @@ def _validate_trainer_init_per_worker( f"but it accepts {num_params} arguments instead." ) - def _validate_train_loop_per_worker( - self, train_loop_per_worker: Callable, fn_name: str - ) -> None: - # Do not validate train_loop_per_worker. We validate - # trainer_init_per_worker instead. - pass - - def _create_train_func( - self, - trainer_init_per_worker: Callable[ - [TorchDataset, Optional[TorchDataset], Any], transformers.trainer.Trainer - ], - ): - def train_loop_per_worker(config): - # Env vars necessary for HF to setup DDP - os.environ["RANK"] = str(train.world_rank()) - os.environ["WORLD_SIZE"] = str(train.world_size()) - os.environ["LOCAL_RANK"] = str(train.local_rank()) - - train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) - eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) - train_columns = set(train_dataset.schema(fetch_if_missing=True).names) - - # HF-specific format. See transformers.Trainer._prepare_inputs - feature_columns = {column: [column] for column in train_columns} - - batch_size = 1 - train_torch_dataset = process_dataset_for_hf( - train_dataset, feature_columns, batch_size=batch_size + def _validate_attributes(self): + if TRAIN_DATASET_KEY not in self.datasets: + raise KeyError( + f"'{TRAIN_DATASET_KEY}' key must be preset in `datasets`. " + f"Got {list(self.datasets.keys())}" ) + if not all( + key in (TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY) for key in self.datasets + ): + raise KeyError( + f"Only '{TRAIN_DATASET_KEY}' and '{EVALUATION_DATASET_KEY}' " + "keys can be preset in `datasets`. " + f"Got {list(self.datasets.keys())}" + ) + return super()._validate_attributes() - if eval_dataset: - eval_torch_dataset = process_dataset_for_hf( - eval_dataset, feature_columns, batch_size=batch_size - ) - else: - eval_torch_dataset = None - # TODO(yard1): Automatically set `no_cuda` somehow - trainer: transformers.trainer.Trainer = trainer_init_per_worker( - train_torch_dataset, eval_torch_dataset, **config - ) +def _huggingface_train_loop_per_worker(config): + """Per-worker training loop for HuggingFace Transformers.""" + trainer_init_per_worker = config.pop("_trainer_init_per_worker") - base_training_arguments_class: Type[ - TrainingArguments - ] = trainer.args.__class__ - - class RayTrainingArguments(base_training_arguments_class): - @property - def device(self) -> "torch.device": - return train.torch.get_device() - - base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ - - class RayTrainer(base_trainer_class): - def get_train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + # Env vars necessary for HF to setup DDP + os.environ["RANK"] = str(train.world_rank()) + os.environ["WORLD_SIZE"] = str(train.world_size()) + os.environ["LOCAL_RANK"] = str(train.local_rank()) - def _save(self, *args, **kwargs): - # Workaround for RayTrainingArguments not being - # pickleable due to it being defined in a local - # scope - self.args.__class__ = base_training_arguments_class - ret = super()._save(*args, **kwargs) - self.args.__class__ = RayTrainingArguments - return ret - - trainer.__class__ = RayTrainer - trainer.args.__class__ = RayTrainingArguments - trainer.args.no_cuda = not torch.cuda.is_available() - trainer.args.save_on_each_node = True - - # ensure no HF logging callbacks are added - # aside from doubling functionality with our callbacks, - # the Wandb callbacks causes training to freeze - integration_callbacks = ( - transformers.trainer.get_reporting_integration_callbacks( - trainer.args.report_to - ) + train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) + eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) + train_columns = set(train_dataset.schema(fetch_if_missing=True).names) + + # HF-specific format. See transformers.Trainer._prepare_inputs + feature_columns = {column: [column] for column in train_columns} + + batch_size = 1 + train_torch_dataset = process_dataset_for_hf( + train_dataset, feature_columns, batch_size=batch_size + ) + + if eval_dataset: + eval_torch_dataset = process_dataset_for_hf( + eval_dataset, feature_columns, batch_size=batch_size + ) + else: + eval_torch_dataset = None + + # TODO(yard1): Automatically set `no_cuda` somehow + trainer: transformers.trainer.Trainer = trainer_init_per_worker( + train_torch_dataset, eval_torch_dataset, **config + ) + + base_training_arguments_class: Type[TrainingArguments] = trainer.args.__class__ + + class RayTrainingArguments(base_training_arguments_class): + @property + def device(self) -> "torch.device": + return train.torch.get_device() + + base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ + + class RayTrainer(base_trainer_class): + def get_train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, ) - for callback in integration_callbacks: - trainer.pop_callback(callback) - - trainer.add_callback(TrainReportCallback) - if trainer.args.device.type == "cuda": - torch.cuda.set_device(trainer.args.device) - - checkpoint = train.load_checkpoint() - checkpoint_path = None - remove_checkpoint_path = False - if checkpoint: - source_ip = checkpoint[NODE_IP_KEY] - source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] - target_ip = get_node_ip_address() - if source_ip == target_ip: - checkpoint_path = source_path - else: - # TODO(yard1): Confirm if tempdir is the right approach here. - checkpoint_path = tempfile.mkdtemp( - suffix=Path(trainer.args.output_dir).name - ) - remove_checkpoint_path = True - sync_dir_between_nodes( - source_ip=source_ip, - source_path=source_path, - target_ip=target_ip, - target_path=checkpoint_path, - return_futures=False, - max_size_bytes=None, - ) - trainer.train(resume_from_checkpoint=checkpoint_path) - if remove_checkpoint_path: - shutil.rmtree(checkpoint_path, ignore_errors=True) - return train_loop_per_worker + def _save(self, *args, **kwargs): + # Workaround for RayTrainingArguments not being + # pickleable due to it being defined in a local + # scope + self.args.__class__ = base_training_arguments_class + ret = super()._save(*args, **kwargs) + self.args.__class__ = RayTrainingArguments + return ret + + trainer.__class__ = RayTrainer + trainer.args.__class__ = RayTrainingArguments + trainer.args.no_cuda = not torch.cuda.is_available() + trainer.args.save_on_each_node = True + + # ensure no HF logging callbacks are added + # aside from doubling functionality with our callbacks, + # the Wandb callbacks causes training to freeze + integration_callbacks = transformers.trainer.get_reporting_integration_callbacks( + trainer.args.report_to + ) + for callback in integration_callbacks: + trainer.pop_callback(callback) + + trainer.add_callback(TrainReportCallback) + if trainer.args.device.type == "cuda": + torch.cuda.set_device(trainer.args.device) + + checkpoint = train.load_checkpoint() + checkpoint_path = None + remove_checkpoint_path = False + if checkpoint: + source_ip = checkpoint[NODE_IP_KEY] + source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + if source_ip == target_ip: + checkpoint_path = source_path + else: + # TODO(yard1): Confirm if tempdir is the right approach here. + checkpoint_path = tempfile.mkdtemp( + suffix=Path(trainer.args.output_dir).name + ) + remove_checkpoint_path = True + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=checkpoint_path, + return_futures=False, + max_size_bytes=None, + ) + trainer.train(resume_from_checkpoint=checkpoint_path) + if remove_checkpoint_path: + shutil.rmtree(checkpoint_path, ignore_errors=True) From 890f84ea2a11d05b2849839786c7c7a276142105 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 08:51:23 +0200 Subject: [PATCH 32/75] Rename huggingface_basic_language_modelling_example.py to huggingface_basic_language_modeling_example.py --- ..._example.py => huggingface_basic_language_modeling_example.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/ray/ml/examples/huggingface/{huggingface_basic_language_modelling_example.py => huggingface_basic_language_modeling_example.py} (100%) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py similarity index 100% rename from python/ray/ml/examples/huggingface/huggingface_basic_language_modelling_example.py rename to python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py From fd815f6ddd9ca8f5395a6133ef16dac6fb0ceb3e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 23:58:20 +0000 Subject: [PATCH 33/75] Do not override __new__ --- .../huggingface/huggingface_trainer.py | 71 +++++++++---------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 88ab03395b74..46f4649d297d 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -31,6 +31,7 @@ from ray.train.checkpoint import TuneCheckpointManager from ray.train.constants import TUNE_CHECKPOINT_ID from ray.train.torch import TorchConfig +from ray.tune.trainable import Trainable from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes # This trainer uses a special checkpoint syncing logic. @@ -44,6 +45,8 @@ # TODO(ml-team): Make dir syncing checkpoint logic generic. +# The checkpoint is turned into a dict with node ip & path +# in HuggingFaceTrainer.as_trainable # TODO(team-ml): Refactor checkpoint management along with Tune. class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): """Same as _DataParallelCheckpointManager, but syncs the dir instead @@ -278,41 +281,6 @@ def __init__( resume_from_checkpoint=resume_from_checkpoint, ) - def __new__(cls, *args, **kwargs): - """Store the init args as attributes so this can be merged with Tune hparams.""" - # This if will be entered in the driver-side Trainer. - # The Trainer inside the trainable will have a dict - # checkpoint created here. - # This is required to ensure that the dir syncing logic - # is used instead of serializing several gigabytes of data - # when a Checkpoint is sent to a Ray Actor. - if "resume_from_checkpoint" in kwargs: - resume_from_checkpoint: Checkpoint = kwargs["resume_from_checkpoint"] - ( - checkpoint_type, - checkpoint_path, - ) = resume_from_checkpoint.get_internal_representation() - if checkpoint_type != "local_path": - checkpoint_path = resume_from_checkpoint.to_directory() - if checkpoint_path: - # Load checkpoint from path. - checkpoint_path = Path(checkpoint_path).expanduser().absolute() - if not checkpoint_path.exists(): - raise ValueError( - f"Checkpoint path {checkpoint_path} does not exist." - ) - with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: - tune_checkpoint_id = int(f.read()) - - kwargs["resume_from_checkpoint"] = Checkpoint.from_dict( - { - NODE_IP_KEY: get_node_ip_address(), - CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), - TUNE_CHECKPOINT_ID: tune_checkpoint_id, - } - ) - return super(HuggingFaceTrainer, cls).__new__(cls, *args, **kwargs) - def _validate_trainer_init_per_worker( self, trainer_init_per_worker: Callable, fn_name: str ) -> None: @@ -339,6 +307,38 @@ def _validate_attributes(self): ) return super()._validate_attributes() + def as_trainable(self) -> Type[Trainable]: + # Replace the directory checkpoint with a node ip & path dict checkpoint + # used to sync the directory. If we use a directry checkpoint directly, + # it will get deepcopied & serialized unnecessarily + original_param_dict = self._param_dict.copy() + resume_from_checkpoint: Optional[Checkpoint] = self._param_dict.get( + "resume_from_checkpoint", None + ) + if resume_from_checkpoint: + with resume_from_checkpoint.as_directory() as checkpoint_path: + # Load checkpoint from path. + checkpoint_path = Path(checkpoint_path).expanduser().absolute() + if not checkpoint_path.exists(): + raise ValueError( + f"Checkpoint path {checkpoint_path} does not exist." + ) + with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: + tune_checkpoint_id = int(f.read()) + + self._param_dict["resume_from_checkpoint"] = Checkpoint.from_dict( + { + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + TUNE_CHECKPOINT_ID: tune_checkpoint_id, + } + ) + try: + ret = super().as_trainable() + finally: + self._param_dict = original_param_dict + return ret + def _huggingface_train_loop_per_worker(config): """Per-worker training loop for HuggingFace Transformers.""" @@ -404,7 +404,6 @@ def _save(self, *args, **kwargs): trainer.__class__ = RayTrainer trainer.args.__class__ = RayTrainingArguments trainer.args.no_cuda = not torch.cuda.is_available() - trainer.args.save_on_each_node = True # ensure no HF logging callbacks are added # aside from doubling functionality with our callbacks, From d020fca18c9b5f9d58e688417cdb3139f39958e0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 13:04:45 +0000 Subject: [PATCH 34/75] HuggingFacePredictor inherits from TorchPredictor --- .../huggingface/huggingface_predictor.py | 119 ++++-------------- .../integrations/torch/torch_predictor.py | 18 ++- 2 files changed, 42 insertions(+), 95 deletions(-) diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index ab4d30f4b44a..59df3d94ae01 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -8,14 +8,15 @@ from transformers.trainer import WEIGHTS_NAME import ray.cloudpickle as cpickle -from ray.ml.predictor import Predictor, DataBatchType +from ray.ml.predictor import DataBatchType +from ray.ml.predictors.integrations.torch import TorchPredictor from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint from ray.ml.utils.torch_utils import load_torch_model, convert_pandas_to_torch_tensor from ray.ml.constants import PREPROCESSOR_KEY -class HuggingFacePredictor(Predictor): +class HuggingFacePredictor(TorchPredictor): """A predictor for HuggingFace Transformers PyTorch models. Args: @@ -54,96 +55,32 @@ def from_checkpoint( ``model.from_pretrained()`` call. Only used if ``model`` is a ``PreTrainerModel`` class. """ - ( - checkpoint_type, - checkpoint_path, - ) = checkpoint.get_internal_representation() - if checkpoint_type != "local_path": - checkpoint_path = checkpoint.to_directory() - preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) - if os.path.exists(preprocessor_path): - with open(preprocessor_path, "rb") as f: - preprocessor = cpickle.load(f) - else: - preprocessor = None - if issubclass(model, PreTrainedModel): - model = PreTrainedModel.from_pretrained( - checkpoint_path, **pretrained_model_kwargs - ) - else: - state_dict = torch.load( - os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" - ) - model = load_torch_model(saved_model=state_dict, model_definition=model) + with checkpoint.as_directory() as checkpoint_path: + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + if issubclass(model, PreTrainedModel): + model = PreTrainedModel.from_pretrained( + checkpoint_path, **pretrained_model_kwargs + ) + else: + state_dict = torch.load( + os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" + ) + model = load_torch_model(saved_model=state_dict, model_definition=model) return HuggingFacePredictor(model=model, preprocessor=preprocessor) - def predict( + def _convert_to_tensor( self, - data: DataBatchType, - feature_columns: Optional[List[str]] = None, - dtype: Optional[torch.dtype] = None, - ) -> DataBatchType: - """Run inference on data batch. - - The data is converted into a torch Tensor before being inputted to - the model. - - Args: - data: A batch of input data. Either a pandas DataFrame or numpy - array. - feature_columns: The names or indices of the columns in the - data to use as features to predict on. If - None, then use all columns in ``data``. - dtype: The torch dtype to use when creating the torch tensor. - If set to None, then automatically infer the dtype. - - Examples: - - .. code-block:: python - - import numpy as np - import torch - from ray.ml.predictors.torch import TorchPredictor - - model = torch.nn.Linear(1, 1) - predictor = TorchPredictor(model=model) - - data = np.array([[1, 2], [3, 4]]) - predictions = predictor.predict(data) - - # Only use first column as the feature - predictions = predictor.predict(data, feature_columns=[0]) - - .. code-block:: python - - import pandas as pd - import torch - from ray.ml.predictors.torch import TorchPredictor - - model = torch.nn.Linear(1, 1) - predictor = TorchPredictor(model=model) - - # Pandas dataframe. - data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) - - predictions = predictor.predict(data) - - # Only use first column as the feature - predictions = predictor.predict(data, feature_columns=["A"]) - - - Returns: - DataBatchType: Prediction result. - """ - self.model.eval() - - if self.preprocessor: - data = self.preprocessor.transform_batch(data) - - if isinstance(data, np.ndarray): - # If numpy array, then convert to pandas dataframe. - data = pd.DataFrame(data) - + data: pd.DataFrame, + feature_columns: Optional[ + Union[List[str], List[List[str]], List[int], List[List[int]]] + ], + dtype: Optional[torch.dtype], + ): if not feature_columns: feature_columns = data.columns @@ -152,8 +89,6 @@ def predict( # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. # Reduce conversion cost if input is in Numpy - tensor = convert_pandas_to_torch_tensor( + return convert_pandas_to_torch_tensor( data, columns=feature_columns, column_dtypes=dtype, unsqueeze=False ) - prediction = self.model(tensor).cpu().detach().numpy() - return pd.DataFrame(prediction, columns=["predictions"]) diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index 9a7e8436571b..7fb4cb9e9f76 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -55,6 +55,20 @@ def from_checkpoint( ) return TorchPredictor(model=model, preprocessor=preprocessor) + def _convert_to_tensor( + self, + data: pd.DataFrame, + feature_columns: Optional[ + Union[List[str], List[List[str]], List[int], List[List[int]]] + ], + dtype: Optional[torch.dtype], + ): + # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. + # Reduce conversion cost if input is in Numpy + return convert_pandas_to_torch_tensor( + data, columns=feature_columns, column_dtypes=dtype + ) + def predict( self, data: DataBatchType, @@ -127,9 +141,7 @@ def predict( # If numpy array, then convert to pandas dataframe. data = pd.DataFrame(data) - # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. - # Reduce conversion cost if input is in Numpy - tensor = convert_pandas_to_torch_tensor( + tensor = self._convert_to_tensor( data, columns=feature_columns, column_dtypes=dtype ) prediction = self.model(tensor).cpu().detach().numpy() From 9cdadb7db548cf0b5203008b7f44ec0f0cc54c40 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 13:07:59 +0000 Subject: [PATCH 35/75] Move utils to hf folder --- .../ml/train/integrations/huggingface/huggingface_trainer.py | 2 +- .../integrations/huggingface}/huggingface_utils.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/ray/ml/{utils => train/integrations/huggingface}/huggingface_utils.py (100%) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 46f4649d297d..cd28673ec75e 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -22,7 +22,7 @@ from ray.ml.preprocessor import Preprocessor from ray.ml.train.integrations.torch import TorchTrainer from ray.ml.trainer import GenDataset -from ray.ml.utils.huggingface_utils import ( +from ray.ml.train.integrations.huggingface.huggingface_utils import ( CHECKPOINT_PATH_ON_NODE_KEY, NODE_IP_KEY, process_dataset_for_hf, diff --git a/python/ray/ml/utils/huggingface_utils.py b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py similarity index 100% rename from python/ray/ml/utils/huggingface_utils.py rename to python/ray/ml/train/integrations/huggingface/huggingface_utils.py From 13ac55f1afdb3fd42fbb7d7cd6c48dd0f82c48c7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 13:11:54 +0000 Subject: [PATCH 36/75] Inheritance tweak --- .../huggingface/huggingface_predictor.py | 2 -- .../huggingface/huggingface_trainer.py | 16 +++------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 59df3d94ae01..1e2849588f46 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -1,14 +1,12 @@ import os from typing import Optional, Type, Union, List -import numpy as np import pandas as pd import torch from transformers.modeling_utils import PreTrainedModel from transformers.trainer import WEIGHTS_NAME import ray.cloudpickle as cpickle -from ray.ml.predictor import DataBatchType from ray.ml.predictors.integrations.torch import TorchPredictor from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index cd28673ec75e..eb051f77c9aa 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -22,13 +22,13 @@ from ray.ml.preprocessor import Preprocessor from ray.ml.train.integrations.torch import TorchTrainer from ray.ml.trainer import GenDataset +from ray.ml.train.data_parallel_trainer import _DataParallelCheckpointManager from ray.ml.train.integrations.huggingface.huggingface_utils import ( CHECKPOINT_PATH_ON_NODE_KEY, NODE_IP_KEY, process_dataset_for_hf, TrainReportCallback, ) -from ray.train.checkpoint import TuneCheckpointManager from ray.train.constants import TUNE_CHECKPOINT_ID from ray.train.torch import TorchConfig from ray.tune.trainable import Trainable @@ -48,14 +48,10 @@ # The checkpoint is turned into a dict with node ip & path # in HuggingFaceTrainer.as_trainable # TODO(team-ml): Refactor checkpoint management along with Tune. -class _DataParallelSyncingCheckpointManager(TuneCheckpointManager): +class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): """Same as _DataParallelCheckpointManager, but syncs the dir instead of serializing it.""" - def on_init(self, preprocessor: Preprocessor): - self.preprocessor = preprocessor - super(_DataParallelSyncingCheckpointManager, self).on_init() - def write_checkpoint(self, checkpoint: Dict): # If inside a Tune Trainable, then checkpoint with Tune. with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: @@ -85,10 +81,6 @@ def write_checkpoint(self, checkpoint: Dict): with open(checkpoint_dir.joinpath(TUNE_CHECKPOINT_ID), "w") as f: f.write(str(self._latest_checkpoint_id)) - @property - def latest_checkpoint_dir(self) -> Optional[Path]: - raise NotImplementedError - @PublicAPI(stability="alpha") class HuggingFaceTrainer(TorchTrainer): @@ -231,9 +223,7 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): resume_from_checkpoint: A checkpoint to resume training from. """ - _checkpoint_manager_cls: Type[ - TuneCheckpointManager - ] = _DataParallelSyncingCheckpointManager + _checkpoint_manager_cls = _DataParallelSyncingCheckpointManager def __init__( self, From 7f3403a3a715280d42ee380bf38649070ac5ab3f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 14:26:26 +0000 Subject: [PATCH 37/75] Doc fix --- doc/source/custom_directives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index b1c4c32e0344..ef13431184d4 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -130,6 +130,7 @@ def update_context(app, pagename, templatename, context, doctree): "tensorflow.contrib", "tensorflow.contrib.all_reduce", "transformers", + "transformers.trainer", "tree", "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", From c3d80507df5891b046ddce05a58eca390b29e675 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 15:32:26 +0000 Subject: [PATCH 38/75] Improve tensorize --- python/ray/ml/utils/torch_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index 66eb64fb68c1..92b3bceb28e3 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -53,10 +53,15 @@ def convert_pandas_to_torch_tensor( def tensorize(vals, dtype): """This recursive function allows to convert pyarrow List dtypes to multi-dimensional tensors.""" - if vals.dtype == np.object: + try: + return torch.as_tensor(vals, dtype=dtype) + except TypeError: + # This exception will be raised if vals is of object dtype + # or otherwise cannot be made into a tensor directly. + # We assume it's a sequence in that case. + # This is more robust than checking for dtype. # TODO(yard1): clarify if this should be cat or stack - return torch.stack([tensorize(x, dtype) for x in vals]) - return torch.as_tensor(vals, dtype=dtype) + return torch.cat([tensorize(x, dtype) for x in vals]) def get_tensor_for_columns(columns, dtype): feature_tensors = [] From d4851e4ceed5fd8f25d75d779be30f032cbf7556 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 19:02:02 +0000 Subject: [PATCH 39/75] Fix docs --- doc/source/custom_directives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index ef13431184d4..95b6a5128021 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -131,6 +131,7 @@ def update_context(app, pagename, templatename, context, doctree): "tensorflow.contrib.all_reduce", "transformers", "transformers.trainer", + "transformers.training_args", "tree", "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", From 31850097158c065ed6e81ecb7add19be9a051a79 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 19:02:37 +0000 Subject: [PATCH 40/75] Stack after all --- python/ray/ml/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index 92b3bceb28e3..fb1d8a154a51 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -61,7 +61,7 @@ def tensorize(vals, dtype): # We assume it's a sequence in that case. # This is more robust than checking for dtype. # TODO(yard1): clarify if this should be cat or stack - return torch.cat([tensorize(x, dtype) for x in vals]) + return torch.stack([tensorize(x, dtype) for x in vals]) def get_tensor_for_columns(columns, dtype): feature_tensors = [] From 814b8894fca69b603cb9187bcb7bdca24d58b55d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 19 Apr 2022 19:21:23 +0000 Subject: [PATCH 41/75] Add tests, predictor work --- .../huggingface/huggingface_predictor.py | 127 +++++++++++++++--- .../integrations/torch/torch_predictor.py | 50 +++++-- .../ray/ml/tests/huggingface_data/train.json | 1 + .../ml/tests/huggingface_data/validation.json | 1 + .../ray/ml/tests/test_huggingface_trainer.py | 124 +++++++++++++++++ .../huggingface/huggingface_trainer.py | 23 +--- .../huggingface/huggingface_utils.py | 46 ++++++- python/ray/ml/utils/torch_utils.py | 10 +- 8 files changed, 331 insertions(+), 51 deletions(-) create mode 100644 python/ray/ml/tests/huggingface_data/train.json create mode 100644 python/ray/ml/tests/huggingface_data/validation.json create mode 100644 python/ray/ml/tests/test_huggingface_trainer.py diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 1e2849588f46..a989a2abd54e 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -1,16 +1,22 @@ import os -from typing import Optional, Type, Union, List +from typing import Dict, Optional, Type, Union, List import pandas as pd +import numpy as np +from ray.ml.train.integrations.huggingface.huggingface_utils import ( + HFIterableDatasetWithLen, +) import torch from transformers.modeling_utils import PreTrainedModel -from transformers.trainer import WEIGHTS_NAME +from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME, Trainer as HFTrainer +from transformers import TrainingArguments import ray.cloudpickle as cpickle +from ray.ml.predictor import DataBatchType from ray.ml.predictors.integrations.torch import TorchPredictor from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint -from ray.ml.utils.torch_utils import load_torch_model, convert_pandas_to_torch_tensor +from ray.ml.utils.torch_utils import load_torch_model from ray.ml.constants import PREPROCESSOR_KEY @@ -27,9 +33,11 @@ def __init__( self, model: Union[PreTrainedModel, torch.nn.Module], preprocessor: Optional[Preprocessor] = None, + training_args: Optional[TrainingArguments] = None, ): self.model = model self.preprocessor = preprocessor + self.training_args = training_args @classmethod def from_checkpoint( @@ -46,9 +54,9 @@ def from_checkpoint( checkpoint: The checkpoint to load the model and preprocessor from. It is expected to be from the result of a ``HuggingFaceTrainer`` run. - model: Either a ``transformers.PreTrainedModel`` class, or a - PyTorch model to load the weights to. This should be the - same model used for training. + model: Either a ``transformers.PreTrainedModel`` class + (eg. ``AutoModelForCausalLM``), or a PyTorch model to load the + weights to. This should be the same model used for training. **pretrained_model_kwargs: Any kwargs to pass to the ``model.from_pretrained()`` call. Only used if ``model`` is a ``PreTrainerModel`` class. @@ -60,33 +68,112 @@ def from_checkpoint( preprocessor = cpickle.load(f) else: preprocessor = None - if issubclass(model, PreTrainedModel): - model = PreTrainedModel.from_pretrained( - checkpoint_path, **pretrained_model_kwargs - ) - else: + if isinstance(model, torch.nn.Module): state_dict = torch.load( os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" ) model = load_torch_model(saved_model=state_dict, model_definition=model) - return HuggingFacePredictor(model=model, preprocessor=preprocessor) + else: + model = model.from_pretrained( + checkpoint_path, **pretrained_model_kwargs + ) + training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME) + if os.path.exists(training_args_path): + with open(training_args_path, "rb") as f: + training_args = torch.load(f, map_location="cpu") + else: + training_args = None + return HuggingFacePredictor( + model=model, preprocessor=preprocessor, training_args=training_args + ) def _convert_to_tensor( self, data: pd.DataFrame, feature_columns: Optional[ Union[List[str], List[List[str]], List[int], List[List[int]]] - ], - dtype: Optional[torch.dtype], - ): + ] = None, + dtypes: Optional[torch.dtype] = None, + unsqueeze: bool = False, + ) -> Dict[str, torch.Tensor]: if not feature_columns: feature_columns = data.columns # HF-supported format - feature_columns = {column: [column] for column in feature_columns} + if not isinstance(feature_columns, dict): + feature_columns = {column: [column] for column in feature_columns} + + return super()._convert_to_tensor( + data, feature_columns=feature_columns, dtypes=dtypes, unsqueeze=unsqueeze + ) + + def _predict(self, tensor: Dict[str, torch.Tensor]) -> np.ndarray: + self.training_args.local_rank = -1 + trainer = HFTrainer(model=self.model, args=self.training_args) + dataset = HFIterableDatasetWithLen([tensor], 1) + return trainer.predict(dataset).predictions - # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. - # Reduce conversion cost if input is in Numpy - return convert_pandas_to_torch_tensor( - data, columns=feature_columns, column_dtypes=dtype, unsqueeze=False + def predict( + self, + data: DataBatchType, + feature_columns: Optional[List[str]] = None, + dtype: Optional[Union[Dict[str, torch.dtype], torch.dtype]] = None, + ) -> DataBatchType: + """Run inference on data batch. + + The data is converted into a dict of torch Tensors before being inputted to + the model. + + Args: + data: A batch of input data. Either a pandas DataFrame or numpy + array. + feature_columns: The names or indices of the columns in the + data to use as features to predict on. If None, use all + columns. + dtype: The torch dtypes to use when creating the torch tensor. + Can be either a single dtype or a dict of ``column:dtype``. + If set to None, then automatically infer the dtype. + + Examples: + + .. code-block:: python + + import numpy as np + import torch + from ray.ml.predictors.torch import TorchPredictor + + model = torch.nn.Linear(1, 1) + predictor = TorchPredictor(model=model) + + data = np.array([[1, 2], [3, 4]]) + predictions = predictor.predict(data) + + # Only use first column as the feature + predictions = predictor.predict(data, feature_columns=[0]) + + .. code-block:: python + + import pandas as pd + import torch + from ray.ml.predictors.torch import TorchPredictor + + model = torch.nn.Linear(1, 1) + predictor = TorchPredictor(model=model) + + # Pandas dataframe. + data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + + predictions = predictor.predict(data) + + # Only use first column as the feature + predictions = predictor.predict(data, feature_columns=["A"]) + + + Returns: + DataBatchType: Prediction result. + """ + # We are just changing the signature and docstring. + print(data) + return super().predict( + data, feature_columns=feature_columns, dtype=dtype, unsqueeze=False ) diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index 7fb4cb9e9f76..b0b1197fb15f 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -55,19 +55,39 @@ def from_checkpoint( ) return TorchPredictor(model=model, preprocessor=preprocessor) + # parity with Datset.to_torch def _convert_to_tensor( self, data: pd.DataFrame, feature_columns: Optional[ Union[List[str], List[List[str]], List[int], List[List[int]]] - ], - dtype: Optional[torch.dtype], - ): + ] = None, + dtypes: Optional[torch.dtype] = None, + unsqueeze: bool = True, + ) -> torch.Tensor: # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. # Reduce conversion cost if input is in Numpy - return convert_pandas_to_torch_tensor( - data, columns=feature_columns, column_dtypes=dtype - ) + if isinstance(feature_columns, dict): + features_tensor = { + key: convert_pandas_to_torch_tensor( + data, + feature_columns[key], + dtypes[key] if isinstance(dtypes, dict) else dtypes, + unsqueeze=unsqueeze, + ) + for key in feature_columns + } + else: + features_tensor = convert_pandas_to_torch_tensor( + data, + columns=feature_columns, + column_dtypes=dtypes, + unsqueeze=unsqueeze, + ) + return features_tensor + + def _predict(self, tensor: torch.Tensor) -> np.ndarray: + return self.model(tensor).cpu().detach().numpy() def predict( self, @@ -76,6 +96,7 @@ def predict( Union[List[str], List[List[str]], List[int], List[List[int]]] ] = None, dtype: Optional[torch.dtype] = None, + unsqueeze: bool = True, ) -> DataBatchType: """Run inference on data batch. @@ -87,12 +108,19 @@ def predict( array. feature_columns: The names or indices of the columns in the data to use as features to predict on. If this arg is a - list of lists, then the data batch will be converted into a + list of lists or a dict of string-list pairs, then the + data batch will be converted into a multiple tensors which are then concatenated before feeding into the model. This is useful for multi-input models. If None, then use all columns in ``data``. - dtype: The torch dtype to use when creating the torch tensor. - If set to None, then automatically infer the dtype. + dtype: The dtypes to use for the tensors. This should match the + format of ``feature_columns``, or be a single dtype, in which + case it will be applied to all tensors. + If None, then automatically infer the dtype. + unsqueeze_feature_tensors (bool): If set to True, the features tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final features tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. Examples: @@ -142,7 +170,7 @@ def predict( data = pd.DataFrame(data) tensor = self._convert_to_tensor( - data, columns=feature_columns, column_dtypes=dtype + data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze ) - prediction = self.model(tensor).cpu().detach().numpy() + prediction = self._predict(tensor) return pd.DataFrame(prediction, columns=["predictions"]) diff --git a/python/ray/ml/tests/huggingface_data/train.json b/python/ray/ml/tests/huggingface_data/train.json new file mode 100644 index 000000000000..a9c2e7302ae5 --- /dev/null +++ b/python/ray/ml/tests/huggingface_data/train.json @@ -0,0 +1 @@ +{"input_ids":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]}} \ No newline at end of file diff --git a/python/ray/ml/tests/huggingface_data/validation.json b/python/ray/ml/tests/huggingface_data/validation.json new file mode 100644 index 000000000000..d252663aa6f8 --- /dev/null +++ b/python/ray/ml/tests/huggingface_data/validation.json @@ -0,0 +1 @@ +{"input_ids":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]}} \ No newline at end of file diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py new file mode 100644 index 000000000000..e49099fee038 --- /dev/null +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -0,0 +1,124 @@ +import pandas as pd +import pytest +import torch +from datasets.arrow_dataset import Dataset +from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments + +import ray.data +from ray.ml.train.integrations.huggingface import HuggingFaceTrainer +from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor +from ray.ml.train.integrations.huggingface.huggingface_utils import process_datasets + +# 16 first rows of tokenized wikitext-2-raw-v1 training & validation +train_df = pd.read_json("./huggingface_data/train.json") +validation_df = pd.read_json("./huggingface_data/validation.json") + +# We are only testing Casual Language Modelling here + +model_checkpoint = "sshleifer/tiny-gpt2" + + +@pytest.fixture +def ray_start_4_cpus(): + address_info = ray.init(num_cpus=4) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +def train_function(train_dataset, eval_dataset=None, **config): + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + training_args = TrainingArguments( + f"{model_checkpoint}-wikitext2", + evaluation_strategy="epoch", + num_train_epochs=config.get("epochs", 3), + learning_rate=2e-5, + weight_decay=0.01, + disable_tqdm=True, + no_cuda=True, + save_strategy=config.get("save_strategy", "no"), + ) + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + return trainer + + +@pytest.mark.parametrize("save_strategy", ["no", "epoch"]) +def test_e2e(ray_start_4_cpus, save_strategy): + ray_train = ray.data.from_pandas(train_df) + ray_validation = ray.data.from_pandas(validation_df) + scaling_config = {"num_workers": 2, "use_gpu": False} + trainer = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + trainer_init_config={"epochs": 3, "save_strategy": save_strategy}, + scaling_config=scaling_config, + datasets={"train": ray_train, "evaluation": ray_validation}, + ) + result = trainer.fit() + + assert result.metrics["epoch"] == 3 + assert result.checkpoint + + trainer2 = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + trainer_init_config={"epochs": 4}, # this will train for 1 epoch: 4 - 3 = 1 + scaling_config=scaling_config, + datasets={"train": ray_train, "evaluation": ray_validation}, + resume_from_checkpoint=result.checkpoint, + ) + result2 = trainer2.fit() + + assert result2.metrics["epoch"] == 4 + assert result2.checkpoint + + class HuggingFaceScorer: + def __init__(self): + self.pred = HuggingFacePredictor.from_checkpoint( + result2.checkpoint, AutoModelForCausalLM + ) + + def __call__(self, x): + return self.pred.predict(x) + + predictions = ray_validation.map_batches( + HuggingFaceScorer, batch_size=8, batch_format="pandas", compute="actors" + ) + assert predictions.count() == 16 + + +def test_same_data_format(ray_start_4_cpus): + train_hf_dataset = Dataset.from_pandas(train_df) + validation_hf_dataset = Dataset.from_pandas(validation_df) + hf_trainer = train_function(train_hf_dataset, validation_hf_dataset) + hf_trainer._get_train_sampler = lambda: None # No randomness + hf_train_dataloader = hf_trainer.get_train_dataloader() + + ray_train = ray.data.from_pandas(train_df) + ray_validation = ray.data.from_pandas(validation_df) + ray_train, ray_validation = process_datasets(ray_train, ray_validation) + ray_trainer = train_function(ray_train, ray_validation) + ray_train_dataloader = ray_trainer.get_train_dataloader() + + hf_train_dataloader_inputs = [ + hf_trainer._prepare_inputs(inputs) for inputs in hf_train_dataloader + ] + ray_train_dataloader_inputs = [ + ray_trainer._prepare_inputs(inputs) for inputs in ray_train_dataloader + ] + + def equal_or_exception(a: torch.Tensor, b: torch.Tensor): + if not torch.equal(a, b): + raise AssertionError( + f"Tensor A ({a.shape}) doesn't equal tensor B ({b.shape}):" + f"\n{a}\n{b}\n" + ) + + [ + [equal_or_exception(a[k], b[k]) for k in a] + for a, b in zip(hf_train_dataloader_inputs, ray_train_dataloader_inputs) + ] diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index eb051f77c9aa..2c85440be0e6 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -26,7 +26,7 @@ from ray.ml.train.integrations.huggingface.huggingface_utils import ( CHECKPOINT_PATH_ON_NODE_KEY, NODE_IP_KEY, - process_dataset_for_hf, + process_datasets, TrainReportCallback, ) from ray.train.constants import TUNE_CHECKPOINT_ID @@ -101,6 +101,10 @@ class HuggingFaceTrainer(TorchTrainer): shards, with each Actor training on a single shard. All the other datasets will not be split. + The datasets will NOT be shuffled by default. Call ``Dataset.random_shuffle()`` + on the "train" dataset you are passing in ``datasets`` if you wish for the + training data to be shuffled. + Please note that if you use a custom ``transformers.Trainer`` subclass, the ``get_train_dataloader`` method will be overriden to disable distributed sampling, as the dataset will already be sharded. @@ -341,23 +345,10 @@ def _huggingface_train_loop_per_worker(config): train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) - train_columns = set(train_dataset.schema(fetch_if_missing=True).names) - - # HF-specific format. See transformers.Trainer._prepare_inputs - feature_columns = {column: [column] for column in train_columns} - - batch_size = 1 - train_torch_dataset = process_dataset_for_hf( - train_dataset, feature_columns, batch_size=batch_size + train_torch_dataset, eval_torch_dataset = process_datasets( + train_dataset, eval_dataset ) - if eval_dataset: - eval_torch_dataset = process_dataset_for_hf( - eval_dataset, feature_columns, batch_size=batch_size - ) - else: - eval_torch_dataset = None - # TODO(yard1): Automatically set `no_cuda` somehow trainer: transformers.trainer.Trainer = trainer_init_per_worker( train_torch_dataset, eval_torch_dataset, **config diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py index 244b4338a141..21b3ecaa4aa3 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Generator, Iterator, List +from typing import Dict, Generator, Iterator, List, Tuple import torch import transformers.trainer @@ -24,7 +24,12 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: it = self.generator for x in it: # HF-specific format. See transformers.Trainer._prepare_inputs - yield {**x[0]} + if isinstance(x, dict): + # Just features + yield x + else: + # Features and labels + yield x[0] class HFIterableDatasetWithLen(HFIterableDataset): @@ -64,6 +69,37 @@ def process_dataset_for_hf( return torch_dataset +def process_datasets( + train_dataset: Dataset, eval_dataset: Dataset +) -> Tuple[IterableDataset, IterableDataset]: + """Convert Ray train and validation to HF-friendly IterableDatasets.""" + train_columns = set(train_dataset.schema(fetch_if_missing=True).names) + + # HF-specific format. See transformers.Trainer._prepare_inputs + feature_columns = {column: [column] for column in train_columns} + + # This is set to 1 to ensure that the model input format + # is the same as with HF's Dataset. If we were to pass + # an n>1 batch obtained from to_torch to HF Trainer, + # the format will differ, and the example count calculation + # will be messed up (as it assumes that it will always get + # just one row per output of the IterableDataset). + # TODO (yard1): Investigate if we can work around this. + batch_size = 1 + train_torch_dataset = process_dataset_for_hf( + train_dataset, feature_columns, batch_size=batch_size + ) + + if eval_dataset: + eval_torch_dataset = process_dataset_for_hf( + eval_dataset, feature_columns, batch_size=batch_size + ) + else: + eval_torch_dataset = None + + return train_torch_dataset, eval_torch_dataset + + class TrainReportCallback(TrainerCallback): """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" @@ -80,6 +116,12 @@ def __init__(self) -> None: self.last_step = None super().__init__() + def on_step_end(self, args, state, control, **kwargs): + if control.should_training_stop: + # always save at end + control.should_save = True + return control + def on_log(self, args, state, control, model=None, logs=None, **kwargs): if state.global_step == self.last_step: return diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index fb1d8a154a51..f38c53913406 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -60,8 +60,14 @@ def tensorize(vals, dtype): # or otherwise cannot be made into a tensor directly. # We assume it's a sequence in that case. # This is more robust than checking for dtype. - # TODO(yard1): clarify if this should be cat or stack - return torch.stack([tensorize(x, dtype) for x in vals]) + + # TODO(yard1): clarify if we should always stack or only + # if len(tensorized) > 1. The latter gives the same + # output as huggingface for batch_size==1 + tensorized = [tensorize(x, dtype) for x in vals] + if len(tensorized) > 1: + return torch.stack([tensorize(x, dtype) for x in vals]) + return tensorized[0] def get_tensor_for_columns(columns, dtype): feature_tensors = [] From 25d3bda9d1c527cc182f84b37550ac1726aa2c03 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Apr 2022 18:45:59 +0000 Subject: [PATCH 42/75] Lint --- python/ray/ml/utils/torch_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index f38c53913406..7da04c14bc6e 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -1,7 +1,6 @@ from typing import Optional, Union, List, Dict import pandas as pd -import numpy as np import torch From 8a743409df6f38344f387cbac1f29627d888b462 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Apr 2022 21:40:28 +0000 Subject: [PATCH 43/75] Make tests work --- python/ray/ml/BUILD | 8 ++++++++ .../integrations/huggingface/huggingface_predictor.py | 11 +++++++++-- .../predictors/integrations/torch/torch_predictor.py | 8 ++++---- python/ray/ml/tests/test_huggingface_trainer.py | 5 ++++- python/ray/ml/utils/torch_utils.py | 11 ++--------- 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/python/ray/ml/BUILD b/python/ray/ml/BUILD index e14cd7eb645b..cd4c394010f7 100644 --- a/python/ray/ml/BUILD +++ b/python/ray/ml/BUILD @@ -172,6 +172,14 @@ py_test( deps = [":ml_lib"] ) +py_test( + name = "test_huggingface_trainer", + size = "medium", + srcs = ["tests/test_huggingface_trainer.py"], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + py_test( name = "test_lightgbm_predictor", size = "small", diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index a989a2abd54e..ffb50e82fecf 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -107,11 +107,18 @@ def _convert_to_tensor( data, feature_columns=feature_columns, dtypes=dtypes, unsqueeze=unsqueeze ) - def _predict(self, tensor: Dict[str, torch.Tensor]) -> np.ndarray: + def _predict(self, tensor: Dict[str, torch.Tensor]) -> pd.DataFrame: self.training_args.local_rank = -1 trainer = HFTrainer(model=self.model, args=self.training_args) dataset = HFIterableDatasetWithLen([tensor], 1) - return trainer.predict(dataset).predictions + # squeeze out the extra dimension added by torch.stack + # inside the HF data collator + ret = trainer.predict(dataset).predictions.squeeze() + # TODO(yard1): Return just a numpy array once that's supported + # by Ray Datasets + df = pd.DataFrame([ret.tolist()]).T + df.columns = ["predictions"] + return df def predict( self, diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index b0b1197fb15f..609cd04fc542 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -86,8 +86,9 @@ def _convert_to_tensor( ) return features_tensor - def _predict(self, tensor: torch.Tensor) -> np.ndarray: - return self.model(tensor).cpu().detach().numpy() + def _predict(self, tensor: torch.Tensor) -> pd.DataFrame: + prediction = self.model(tensor).cpu().detach().numpy() + return pd.DataFrame(prediction, columns=["predictions"]) def predict( self, @@ -172,5 +173,4 @@ def predict( tensor = self._convert_to_tensor( data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze ) - prediction = self._predict(tensor) - return pd.DataFrame(prediction, columns=["predictions"]) + return self._predict(tensor) diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py index e49099fee038..a7f6f262d863 100644 --- a/python/ray/ml/tests/test_huggingface_trainer.py +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -118,7 +118,10 @@ def equal_or_exception(a: torch.Tensor, b: torch.Tensor): f"\n{a}\n{b}\n" ) + # We squeeze to get rid of the extra dimension added by the HF + # torch_default_data_collator. The models seem to train and predict + # fine with that extra dimension. [ - [equal_or_exception(a[k], b[k]) for k in a] + [equal_or_exception(a[k], b[k].squeeze()) for k in a] for a, b in zip(hf_train_dataloader_inputs, ray_train_dataloader_inputs) ] diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index 7da04c14bc6e..18634a2cac9f 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -59,14 +59,7 @@ def tensorize(vals, dtype): # or otherwise cannot be made into a tensor directly. # We assume it's a sequence in that case. # This is more robust than checking for dtype. - - # TODO(yard1): clarify if we should always stack or only - # if len(tensorized) > 1. The latter gives the same - # output as huggingface for batch_size==1 - tensorized = [tensorize(x, dtype) for x in vals] - if len(tensorized) > 1: - return torch.stack([tensorize(x, dtype) for x in vals]) - return tensorized[0] + return torch.stack([tensorize(x, dtype) for x in vals]) def get_tensor_for_columns(columns, dtype): feature_tensors = [] @@ -80,7 +73,7 @@ def get_tensor_for_columns(columns, dtype): col_vals = batch[col].values t = tensorize(col_vals, dtype=dtype) if unsqueeze: - t = t.view(-1, 1) + t = t.unsqueeze(1) feature_tensors.append(t) if len(feature_tensors) > 1: From 97d99ace9a4d7c9c70445734509967f9bd261f33 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Apr 2022 21:40:41 +0000 Subject: [PATCH 44/75] Add n>1 gpus warning --- .../integrations/huggingface/huggingface_trainer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 2c85440be0e6..76d84bc55ee2 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -5,6 +5,7 @@ from distutils.version import LooseVersion from pathlib import Path from typing import Any, Callable, Dict, Optional, Type +import warnings import torch import transformers @@ -286,6 +287,7 @@ def _validate_trainer_init_per_worker( ) def _validate_attributes(self): + # exceptions first if TRAIN_DATASET_KEY not in self.datasets: raise KeyError( f"'{TRAIN_DATASET_KEY}' key must be preset in `datasets`. " @@ -299,7 +301,15 @@ def _validate_attributes(self): "keys can be preset in `datasets`. " f"Got {list(self.datasets.keys())}" ) - return super()._validate_attributes() + super()._validate_attributes() + gpus_per_worker = self.scaling_config.get("num_gpus_per_worker", 0) + if gpus_per_worker > 1: + warnings.warn( + f"You have assigned {gpus_per_worker} GPUs per worker. " + "This is not supported by HuggingFace, which expects " + "one GPU per worker in DDP mode and will not be " + "able to make use of more GPUs." + ) def as_trainable(self) -> Type[Trainable]: # Replace the directory checkpoint with a node ip & path dict checkpoint From 2188e40375bf36691ae70d924288a7c379385b23 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Apr 2022 21:43:35 +0000 Subject: [PATCH 45/75] Raise exception instead of warning --- .../integrations/huggingface/huggingface_trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 76d84bc55ee2..e25662c368ac 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -301,16 +301,17 @@ def _validate_attributes(self): "keys can be preset in `datasets`. " f"Got {list(self.datasets.keys())}" ) - super()._validate_attributes() gpus_per_worker = self.scaling_config.get("num_gpus_per_worker", 0) if gpus_per_worker > 1: - warnings.warn( + raise ValueError( f"You have assigned {gpus_per_worker} GPUs per worker. " "This is not supported by HuggingFace, which expects " - "one GPU per worker in DDP mode and will not be " - "able to make use of more GPUs." + "one GPU per worker in DDP mode and will fail " + "if more are assigned." ) + super()._validate_attributes() + def as_trainable(self) -> Type[Trainable]: # Replace the directory checkpoint with a node ip & path dict checkpoint # used to sync the directory. If we use a directry checkpoint directly, From 6028d4f699549d73426570a568e40d55b6db7e75 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Apr 2022 21:45:05 +0000 Subject: [PATCH 46/75] Add predictor doc --- doc/source/ray-air/getting-started.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/source/ray-air/getting-started.rst b/doc/source/ray-air/getting-started.rst index 1b823ebb9d17..034171c5a8ac 100644 --- a/doc/source/ray-air/getting-started.rst +++ b/doc/source/ray-air/getting-started.rst @@ -116,6 +116,10 @@ Predictors :members: :show-inheritance: +.. automodule:: ray.ml.predictors.integrations.huggingface + :members: + :show-inheritance: + .. _air-serve-integration: Serving From 7699f7928c8d1ce00a28691baf67859392e08748 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Apr 2022 15:35:58 +0000 Subject: [PATCH 47/75] Bump train requirements --- python/requirements/ml/requirements_train.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/requirements/ml/requirements_train.txt b/python/requirements/ml/requirements_train.txt index 6497f8fb0a94..f1a37ad85c36 100644 --- a/python/requirements/ml/requirements_train.txt +++ b/python/requirements/ml/requirements_train.txt @@ -5,9 +5,11 @@ mlflow==1.21.0 tensorboardX==2.4.1 -# Dependencies for Hugging Face examples: +# Dependencies for Hugging Face examples & tests: # `python/ray/train/examples/transformers/transformers_example.py` +# `python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py` +# `python/ray/ml/tests/test_huggingface_trainer.py` transformers==4.18.0 accelerate==0.5.1 -datasets==1.14.0 +datasets==2.0.0 sentencepiece==0.1.96 From e85295900dc06506732b9a467786a6c83e887884 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Apr 2022 15:36:20 +0000 Subject: [PATCH 48/75] Lint --- .../predictors/integrations/huggingface/huggingface_predictor.py | 1 - .../ray/ml/train/integrations/huggingface/huggingface_trainer.py | 1 - 2 files changed, 2 deletions(-) diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index ffb50e82fecf..3dad01464a34 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -2,7 +2,6 @@ from typing import Dict, Optional, Type, Union, List import pandas as pd -import numpy as np from ray.ml.train.integrations.huggingface.huggingface_utils import ( HFIterableDatasetWithLen, ) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index e25662c368ac..9f13159dd09e 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -5,7 +5,6 @@ from distutils.version import LooseVersion from pathlib import Path from typing import Any, Callable, Dict, Optional, Type -import warnings import torch import transformers From a66da31b5e21f2978aa5104097ed36d6ee647ced Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Apr 2022 16:31:43 +0000 Subject: [PATCH 49/75] Add more mocks to docs --- doc/source/custom_directives.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index 95b6a5128021..80cfd311ab78 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -130,8 +130,12 @@ def update_context(app, pagename, templatename, context, doctree): "tensorflow.contrib", "tensorflow.contrib.all_reduce", "transformers", + "transformers.modeling_utils", "transformers.trainer", "transformers.training_args", + "transformers.trainer_callback", + "transformers.utils", + "transformers.utils.logging", "tree", "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", From f2627bbe9fcb43c67887238edb49a1fed0cb919f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Apr 2022 16:48:18 +0000 Subject: [PATCH 50/75] Put data into file --- python/ray/ml/tests/huggingface_data/train.json | 1 - python/ray/ml/tests/huggingface_data/validation.json | 1 - python/ray/ml/tests/test_huggingface_trainer.py | 12 ++++++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) delete mode 100644 python/ray/ml/tests/huggingface_data/train.json delete mode 100644 python/ray/ml/tests/huggingface_data/validation.json diff --git a/python/ray/ml/tests/huggingface_data/train.json b/python/ray/ml/tests/huggingface_data/train.json deleted file mode 100644 index a9c2e7302ae5..000000000000 --- a/python/ray/ml/tests/huggingface_data/train.json +++ /dev/null @@ -1 +0,0 @@ -{"input_ids":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]}} \ No newline at end of file diff --git a/python/ray/ml/tests/huggingface_data/validation.json b/python/ray/ml/tests/huggingface_data/validation.json deleted file mode 100644 index d252663aa6f8..000000000000 --- a/python/ray/ml/tests/huggingface_data/validation.json +++ /dev/null @@ -1 +0,0 @@ -{"input_ids":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]}} \ No newline at end of file diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py index a7f6f262d863..4660e4a3ae54 100644 --- a/python/ray/ml/tests/test_huggingface_trainer.py +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -9,9 +9,17 @@ from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor from ray.ml.train.integrations.huggingface.huggingface_utils import process_datasets +train_data = """ +{"input_ids":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]}} +""" + +validation_data = """ +{"input_ids":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]}} +""" + # 16 first rows of tokenized wikitext-2-raw-v1 training & validation -train_df = pd.read_json("./huggingface_data/train.json") -validation_df = pd.read_json("./huggingface_data/validation.json") +train_df = pd.read_json(train_data) +validation_df = pd.read_json(validation_data) # We are only testing Casual Language Modelling here From 5617524e284cf2db5b6a47dbe4bb10fba601ed1c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Apr 2022 18:31:47 +0000 Subject: [PATCH 51/75] Add predictor test, fix small issues --- python/ray/ml/BUILD | 17 ++++ .../huggingface/huggingface_predictor.py | 84 ++++++++++++------- python/ray/ml/tests/_huggingface_data.py | 7 ++ .../ml/tests/test_huggingface_predictor.py | 45 ++++++++++ .../ray/ml/tests/test_huggingface_trainer.py | 24 ++---- .../huggingface/huggingface_trainer.py | 1 + 6 files changed, 130 insertions(+), 48 deletions(-) create mode 100644 python/ray/ml/tests/_huggingface_data.py create mode 100644 python/ray/ml/tests/test_huggingface_predictor.py diff --git a/python/ray/ml/BUILD b/python/ray/ml/BUILD index cd4c394010f7..5439b23af967 100644 --- a/python/ray/ml/BUILD +++ b/python/ray/ml/BUILD @@ -11,6 +11,15 @@ py_test( deps = [":ml_lib"] ) +py_test ( + name = "huggingface_basic_language_modeling_example", + size = "medium", + srcs = ["examples/huggingface/huggingface_basic_language_modeling_example.py"], + args = ["--smoke-test", "--num-epochs 3"], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + py_test ( name = "lightgbm_example", size = "medium", @@ -172,6 +181,14 @@ py_test( deps = [":ml_lib"] ) +py_test( + name = "test_huggingface_predictor", + size = "small", + srcs = ["tests/test_huggingface_predictor.py"], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + py_test( name = "test_huggingface_trainer", size = "medium", diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 3dad01464a34..0b6ba6eaa884 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -107,7 +107,8 @@ def _convert_to_tensor( ) def _predict(self, tensor: Dict[str, torch.Tensor]) -> pd.DataFrame: - self.training_args.local_rank = -1 + if self.training_args: + self.training_args.local_rank = -1 trainer = HFTrainer(model=self.model, args=self.training_args) dataset = HFIterableDatasetWithLen([tensor], 1) # squeeze out the extra dimension added by torch.stack @@ -145,41 +146,64 @@ def predict( .. code-block:: python import numpy as np - import torch - from ray.ml.predictors.torch import TorchPredictor - - model = torch.nn.Linear(1, 1) - predictor = TorchPredictor(model=model) - - data = np.array([[1, 2], [3, 4]]) - predictions = predictor.predict(data) - - # Only use first column as the feature - predictions = predictor.predict(data, feature_columns=[0]) - - .. code-block:: python - - import pandas as pd - import torch - from ray.ml.predictors.torch import TorchPredictor - - model = torch.nn.Linear(1, 1) - predictor = TorchPredictor(model=model) - - # Pandas dataframe. - data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) - - predictions = predictor.predict(data) - - # Only use first column as the feature - predictions = predictor.predict(data, feature_columns=["A"]) + from datasets import load_dataset + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + from ray.ml.predictors.huggingface import HuggingFacePredictor + + model_checkpoint = "gpt2" + tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" + block_size = 128 + + datasets = load_dataset("wikitext", "wikitext-2-raw-v1") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + + def tokenize_function(examples): + return tokenizer(examples["text"]) + + tokenized_datasets = datasets.map( + tokenize_function, batched=True, num_proc=1, remove_columns=["text"] + ) + + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = { + k: sum(examples[k], []) for k in examples.keys() + } + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model + # supported it. + # instead of this drop, you can customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [ + t[i : i + block_size] + for i in range(0, total_length, block_size) + ] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + batch_size=1000, + num_proc=1, + ) + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + predictor = HuggingFacePredictor( + model=model, preprocessor=preprocessor + ) + + predictions = predictor.predict(lm_datasets["validation"].to_pandas()) Returns: DataBatchType: Prediction result. """ # We are just changing the signature and docstring. - print(data) return super().predict( data, feature_columns=feature_columns, dtype=dtype, unsqueeze=False ) diff --git a/python/ray/ml/tests/_huggingface_data.py b/python/ray/ml/tests/_huggingface_data.py new file mode 100644 index 000000000000..670cc7102e14 --- /dev/null +++ b/python/ray/ml/tests/_huggingface_data.py @@ -0,0 +1,7 @@ +train_data = """ +{"input_ids":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]}} +""" + +validation_data = """ +{"input_ids":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]}} +""" diff --git a/python/ray/ml/tests/test_huggingface_predictor.py b/python/ray/ml/tests/test_huggingface_predictor.py new file mode 100644 index 000000000000..7d4dd19404fc --- /dev/null +++ b/python/ray/ml/tests/test_huggingface_predictor.py @@ -0,0 +1,45 @@ +import pandas as pd +import pytest +from transformers import AutoConfig, AutoModelForCausalLM, TrainingArguments + +from ray.ml.preprocessor import Preprocessor +from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor + +from ._huggingface_data import validation_data + +# 16 first rows of tokenized wikitext-2-raw-v1 validation +validation_df = pd.read_json(validation_data) + +# We are only testing Casual Language Modelling here + +model_checkpoint = "sshleifer/tiny-gpt2" + + +class DummyPreprocessor(Preprocessor): + def transform_batch(self, df): + self._batch_transformed = True + return df + + +@pytest.mark.parametrize("preprocessor", [True, False]) +@pytest.mark.parametrize("training_args", [True, False]) +def test_predict(preprocessor, training_args, tmpdir): + if preprocessor: + preprocessor = DummyPreprocessor() + else: + preprocessor = None + if training_args: + training_args = TrainingArguments(tmpdir) + else: + training_args = None + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + predictor = HuggingFacePredictor( + model=model, preprocessor=preprocessor, training_args=training_args + ) + + predictions = predictor.predict(validation_df) + + assert len(predictions) == 16 + if preprocessor: + assert hasattr(predictor.preprocessor, "_batch_transformed") diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py index 4660e4a3ae54..e492f972c90d 100644 --- a/python/ray/ml/tests/test_huggingface_trainer.py +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -8,14 +8,9 @@ from ray.ml.train.integrations.huggingface import HuggingFaceTrainer from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor from ray.ml.train.integrations.huggingface.huggingface_utils import process_datasets +from ray.ml.batch_predictor import BatchPredictor -train_data = """ -{"input_ids":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]}} -""" - -validation_data = """ -{"input_ids":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]}} -""" +from ._huggingface_data import train_data, validation_data # 16 first rows of tokenized wikitext-2-raw-v1 training & validation train_df = pd.read_json(train_data) @@ -84,18 +79,11 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result2.metrics["epoch"] == 4 assert result2.checkpoint - class HuggingFaceScorer: - def __init__(self): - self.pred = HuggingFacePredictor.from_checkpoint( - result2.checkpoint, AutoModelForCausalLM - ) - - def __call__(self, x): - return self.pred.predict(x) - - predictions = ray_validation.map_batches( - HuggingFaceScorer, batch_size=8, batch_format="pandas", compute="actors" + predictor = BatchPredictor.from_checkpoint( + result2.checkpoint, HuggingFacePredictor, model=AutoModelForCausalLM ) + + predictions = predictor.predict(ray_validation) assert predictions.count() == 16 diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 9f13159dd09e..7810b8bcaf86 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -117,6 +117,7 @@ class HuggingFaceTrainer(TorchTrainer): Example: .. code-block:: python + # Based on # huggingface/notebooks/examples/language_modeling_from_scratch.ipynb From 5f7fda1a88890893defa014cd9b9edd1bc69e4dd Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 22 Apr 2022 17:51:40 +0000 Subject: [PATCH 52/75] CI fixes --- doc/source/custom_directives.py | 1 + python/ray/ml/tests/test_huggingface_predictor.py | 2 +- python/ray/ml/tests/test_huggingface_trainer.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index 80cfd311ab78..e477cdae5225 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -131,6 +131,7 @@ def update_context(app, pagename, templatename, context, doctree): "tensorflow.contrib.all_reduce", "transformers", "transformers.modeling_utils", + "transformers.models", "transformers.trainer", "transformers.training_args", "transformers.trainer_callback", diff --git a/python/ray/ml/tests/test_huggingface_predictor.py b/python/ray/ml/tests/test_huggingface_predictor.py index 7d4dd19404fc..4d0455ef3f4a 100644 --- a/python/ray/ml/tests/test_huggingface_predictor.py +++ b/python/ray/ml/tests/test_huggingface_predictor.py @@ -5,7 +5,7 @@ from ray.ml.preprocessor import Preprocessor from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor -from ._huggingface_data import validation_data +from ray.ml.tests._huggingface_data import validation_data # 16 first rows of tokenized wikitext-2-raw-v1 validation validation_df = pd.read_json(validation_data) diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py index e492f972c90d..59f457b673ba 100644 --- a/python/ray/ml/tests/test_huggingface_trainer.py +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -10,7 +10,7 @@ from ray.ml.train.integrations.huggingface.huggingface_utils import process_datasets from ray.ml.batch_predictor import BatchPredictor -from ._huggingface_data import train_data, validation_data +from ray.ml.tests._huggingface_data import train_data, validation_data # 16 first rows of tokenized wikitext-2-raw-v1 training & validation train_df = pd.read_json(train_data) From bdc738789810c261c21e6f383b4a8d48f8f74073 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 22 Apr 2022 19:10:35 +0000 Subject: [PATCH 53/75] Fix docs --- doc/source/custom_directives.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index e477cdae5225..efd26278396d 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -132,11 +132,13 @@ def update_context(app, pagename, templatename, context, doctree): "transformers", "transformers.modeling_utils", "transformers.models", + "transformers.models.auto", "transformers.trainer", "transformers.training_args", "transformers.trainer_callback", "transformers.utils", "transformers.utils.logging", + "transformers.utils.versions", "tree", "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", From 3ef8d1bb627cf5eb986857d29ff5293ee679530d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 26 Apr 2022 16:07:02 +0000 Subject: [PATCH 54/75] Expand predictor --- .../huggingface/huggingface_predictor.py | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 0b6ba6eaa884..85eb2cecdc31 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -26,23 +26,32 @@ class HuggingFacePredictor(TorchPredictor): model: The Transformers model to use for predictions. preprocessor: A preprocessor used to transform data batches prior to prediction. + training_args: ``transformers.TrainingArguments`` to use for the prediction. + trainer_class: ``transformers.Trainer`` subclass to use for prediction. + Defaults to ``transformers.Trainer``. """ def __init__( self, model: Union[PreTrainedModel, torch.nn.Module], preprocessor: Optional[Preprocessor] = None, + *, training_args: Optional[TrainingArguments] = None, + trainer_class: HFTrainer = HFTrainer, ): self.model = model self.preprocessor = preprocessor self.training_args = training_args + self.trainer_class = trainer_class @classmethod def from_checkpoint( cls, checkpoint: Checkpoint, model: Union[Type[PreTrainedModel], torch.nn.Module], + *, + training_args: Optional[TrainingArguments] = None, + trainer_class: HFTrainer = HFTrainer, **pretrained_model_kwargs, ) -> "HuggingFacePredictor": """Instantiate the predictor from a Checkpoint. @@ -56,6 +65,10 @@ def from_checkpoint( model: Either a ``transformers.PreTrainedModel`` class (eg. ``AutoModelForCausalLM``), or a PyTorch model to load the weights to. This should be the same model used for training. + training_args: ``transformers.TrainingArguments`` to use for the prediction. + Defaults to training arguments saved inside the checkpoint. + trainer_class: ``transformers.Trainer`` subclass to use for prediction. + Defaults to ``transformers.Trainer``. **pretrained_model_kwargs: Any kwargs to pass to the ``model.from_pretrained()`` call. Only used if ``model`` is a ``PreTrainerModel`` class. @@ -83,8 +96,26 @@ def from_checkpoint( else: training_args = None return HuggingFacePredictor( - model=model, preprocessor=preprocessor, training_args=training_args + model=model, + preprocessor=preprocessor, + training_args=training_args, + trainer_class=trainer_class, + ) + + def to_transformers_trainer(self, **trainer_kwargs) -> HFTrainer: + """Converts this predictor to a fitted ``transformers.Trainer``. + + Args: + **trainer_kwargs: Any kwargs to pass to the + ``trainer_class`` initialization. ``model`` and + ``args`` are preset. + """ + if self.training_args: + self.training_args.local_rank = -1 + trainer = self.trainer_class( + model=self.model, args=self.training_args, **trainer_kwargs ) + return trainer def _convert_to_tensor( self, @@ -107,9 +138,7 @@ def _convert_to_tensor( ) def _predict(self, tensor: Dict[str, torch.Tensor]) -> pd.DataFrame: - if self.training_args: - self.training_args.local_rank = -1 - trainer = HFTrainer(model=self.model, args=self.training_args) + trainer = self.to_transformers_trainer() dataset = HFIterableDatasetWithLen([tensor], 1) # squeeze out the extra dimension added by torch.stack # inside the HF data collator From c0da17e5b18995acf0f89beafcb472a395c423d9 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 26 Apr 2022 21:37:04 +0200 Subject: [PATCH 55/75] Apply suggestions from code review Co-authored-by: Amog Kamsetty --- .../huggingface_basic_language_modeling_example.py | 5 ++--- .../ml/train/integrations/huggingface/huggingface_trainer.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index 0c947ef62061..db55941ab868 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -49,9 +49,8 @@ def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported - # it instead of this drop, you can - # customize this part to your needs. + # We drop the small remainder. We could add padding if the model supported + # it instead of this drop. You can customize this part to your needs. total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 7810b8bcaf86..be7faa90aaf4 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -314,7 +314,7 @@ def _validate_attributes(self): def as_trainable(self) -> Type[Trainable]: # Replace the directory checkpoint with a node ip & path dict checkpoint - # used to sync the directory. If we use a directry checkpoint directly, + # used to sync the directory. If we use a directory checkpoint directly, # it will get deepcopied & serialized unnecessarily original_param_dict = self._param_dict.copy() resume_from_checkpoint: Optional[Checkpoint] = self._param_dict.get( From 8c7a6eb991663e2a9e248d98ab058919455f5847 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 27 Apr 2022 17:00:07 +0000 Subject: [PATCH 56/75] WIP --- .../huggingface/huggingface_predictor.py | 98 +++----- .../integrations/torch/torch_predictor.py | 4 + .../ray/ml/tests/test_huggingface_trainer.py | 37 +-- .../integrations/huggingface/__init__.py | 3 +- .../huggingface/huggingface_trainer.py | 113 +++++---- .../huggingface/huggingface_utils.py | 214 +++++++++++------- .../ml/utils/huggingface_checkpoint_utils.py | 54 +++++ python/ray/ml/utils/torch_utils.py | 6 +- 8 files changed, 286 insertions(+), 243 deletions(-) create mode 100644 python/ray/ml/utils/huggingface_checkpoint_utils.py diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 85eb2cecdc31..81f41efe7529 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -1,25 +1,21 @@ -import os from typing import Dict, Optional, Type, Union, List +import numpy as np import pandas as pd -from ray.ml.train.integrations.huggingface.huggingface_utils import ( - HFIterableDatasetWithLen, -) + import torch +from datasets import Dataset as HFDataset from transformers.modeling_utils import PreTrainedModel -from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME, Trainer as HFTrainer +from transformers.trainer import Trainer as HFTrainer from transformers import TrainingArguments -import ray.cloudpickle as cpickle -from ray.ml.predictor import DataBatchType -from ray.ml.predictors.integrations.torch import TorchPredictor +from ray.ml.predictor import DataBatchType, Predictor from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint -from ray.ml.utils.torch_utils import load_torch_model -from ray.ml.constants import PREPROCESSOR_KEY +from ray.ml.utils.huggingface_checkpoint_utils import load_huggingface_checkpoint -class HuggingFacePredictor(TorchPredictor): +class HuggingFacePredictor(Predictor): """A predictor for HuggingFace Transformers PyTorch models. Args: @@ -73,28 +69,10 @@ def from_checkpoint( ``model.from_pretrained()`` call. Only used if ``model`` is a ``PreTrainerModel`` class. """ - with checkpoint.as_directory() as checkpoint_path: - preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) - if os.path.exists(preprocessor_path): - with open(preprocessor_path, "rb") as f: - preprocessor = cpickle.load(f) - else: - preprocessor = None - if isinstance(model, torch.nn.Module): - state_dict = torch.load( - os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" - ) - model = load_torch_model(saved_model=state_dict, model_definition=model) - else: - model = model.from_pretrained( - checkpoint_path, **pretrained_model_kwargs - ) - training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME) - if os.path.exists(training_args_path): - with open(training_args_path, "rb") as f: - training_args = torch.load(f, map_location="cpu") - else: - training_args = None + model, preprocessor, loaded_training_args = load_huggingface_checkpoint( + checkpoint, model, **pretrained_model_kwargs + ) + training_args = training_args or loaded_training_args return HuggingFacePredictor( model=model, preprocessor=preprocessor, @@ -103,7 +81,7 @@ def from_checkpoint( ) def to_transformers_trainer(self, **trainer_kwargs) -> HFTrainer: - """Converts this predictor to a fitted ``transformers.Trainer``. + """Converts this predictor to a ``transformers.Trainer``. Args: **trainer_kwargs: Any kwargs to pass to the @@ -117,32 +95,9 @@ def to_transformers_trainer(self, **trainer_kwargs) -> HFTrainer: ) return trainer - def _convert_to_tensor( - self, - data: pd.DataFrame, - feature_columns: Optional[ - Union[List[str], List[List[str]], List[int], List[List[int]]] - ] = None, - dtypes: Optional[torch.dtype] = None, - unsqueeze: bool = False, - ) -> Dict[str, torch.Tensor]: - if not feature_columns: - feature_columns = data.columns - - # HF-supported format - if not isinstance(feature_columns, dict): - feature_columns = {column: [column] for column in feature_columns} - - return super()._convert_to_tensor( - data, feature_columns=feature_columns, dtypes=dtypes, unsqueeze=unsqueeze - ) - - def _predict(self, tensor: Dict[str, torch.Tensor]) -> pd.DataFrame: + def _predict(self, dataset: HFDataset) -> pd.DataFrame: trainer = self.to_transformers_trainer() - dataset = HFIterableDatasetWithLen([tensor], 1) - # squeeze out the extra dimension added by torch.stack - # inside the HF data collator - ret = trainer.predict(dataset).predictions.squeeze() + ret = trainer.predict(dataset).predictions # TODO(yard1): Return just a numpy array once that's supported # by Ray Datasets df = pd.DataFrame([ret.tolist()]).T @@ -153,12 +108,12 @@ def predict( self, data: DataBatchType, feature_columns: Optional[List[str]] = None, - dtype: Optional[Union[Dict[str, torch.dtype], torch.dtype]] = None, + dtype: Optional[Union[Dict[str, np.dtype], np.dtype]] = None, ) -> DataBatchType: """Run inference on data batch. - The data is converted into a dict of torch Tensors before being inputted to - the model. + The data is converted into a HuggingFace ``datasets.Dataset`` + and passed to a ``transformers.Trainer.predict()`` method. Args: data: A batch of input data. Either a pandas DataFrame or numpy @@ -166,7 +121,7 @@ def predict( feature_columns: The names or indices of the columns in the data to use as features to predict on. If None, use all columns. - dtype: The torch dtypes to use when creating the torch tensor. + dtype: The numpy dtypes to cast the data to. Can be either a single dtype or a dict of ``column:dtype``. If set to None, then automatically infer the dtype. @@ -232,7 +187,16 @@ def group_texts(examples): Returns: DataBatchType: Prediction result. """ - # We are just changing the signature and docstring. - return super().predict( - data, feature_columns=feature_columns, dtype=dtype, unsqueeze=False - ) + if self.preprocessor: + data = self.preprocessor.transform_batch(data) + + if isinstance(data, np.ndarray): + # If numpy array, then convert to pandas dataframe. + data = pd.DataFrame(data) + + data = data[feature_columns] if feature_columns else data + if dtype: + data = data.astype(dtype) + + dataset = HFDataset.from_pandas(data) + return self._predict(dataset) diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index 609cd04fc542..c03493cbca8e 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -65,6 +65,9 @@ def _convert_to_tensor( dtypes: Optional[torch.dtype] = None, unsqueeze: bool = True, ) -> torch.Tensor: + """Handle conversion of data to tensor. + + Same arguments as in ``convert_pandas_to_torch_tensor``.""" # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type. # Reduce conversion cost if input is in Numpy if isinstance(feature_columns, dict): @@ -87,6 +90,7 @@ def _convert_to_tensor( return features_tensor def _predict(self, tensor: torch.Tensor) -> pd.DataFrame: + """Handle actual prediction.""" prediction = self.model(tensor).cpu().detach().numpy() return pd.DataFrame(prediction, columns=["predictions"]) diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py index 59f457b673ba..55bc3a4c04b6 100644 --- a/python/ray/ml/tests/test_huggingface_trainer.py +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -65,6 +65,7 @@ def test_e2e(ray_start_4_cpus, save_strategy): result = trainer.fit() assert result.metrics["epoch"] == 3 + assert result.metrics["training_iteration"] == 3 assert result.checkpoint trainer2 = HuggingFaceTrainer( @@ -85,39 +86,3 @@ def test_e2e(ray_start_4_cpus, save_strategy): predictions = predictor.predict(ray_validation) assert predictions.count() == 16 - - -def test_same_data_format(ray_start_4_cpus): - train_hf_dataset = Dataset.from_pandas(train_df) - validation_hf_dataset = Dataset.from_pandas(validation_df) - hf_trainer = train_function(train_hf_dataset, validation_hf_dataset) - hf_trainer._get_train_sampler = lambda: None # No randomness - hf_train_dataloader = hf_trainer.get_train_dataloader() - - ray_train = ray.data.from_pandas(train_df) - ray_validation = ray.data.from_pandas(validation_df) - ray_train, ray_validation = process_datasets(ray_train, ray_validation) - ray_trainer = train_function(ray_train, ray_validation) - ray_train_dataloader = ray_trainer.get_train_dataloader() - - hf_train_dataloader_inputs = [ - hf_trainer._prepare_inputs(inputs) for inputs in hf_train_dataloader - ] - ray_train_dataloader_inputs = [ - ray_trainer._prepare_inputs(inputs) for inputs in ray_train_dataloader - ] - - def equal_or_exception(a: torch.Tensor, b: torch.Tensor): - if not torch.equal(a, b): - raise AssertionError( - f"Tensor A ({a.shape}) doesn't equal tensor B ({b.shape}):" - f"\n{a}\n{b}\n" - ) - - # We squeeze to get rid of the extra dimension added by the HF - # torch_default_data_collator. The models seem to train and predict - # fine with that extra dimension. - [ - [equal_or_exception(a[k], b[k].squeeze()) for k in a] - for a, b in zip(hf_train_dataloader_inputs, ray_train_dataloader_inputs) - ] diff --git a/python/ray/ml/train/integrations/huggingface/__init__.py b/python/ray/ml/train/integrations/huggingface/__init__.py index b2344f98dcb3..37d23738ba3e 100644 --- a/python/ray/ml/train/integrations/huggingface/__init__.py +++ b/python/ray/ml/train/integrations/huggingface/__init__.py @@ -1,5 +1,6 @@ from ray.ml.train.integrations.huggingface.huggingface_trainer import ( HuggingFaceTrainer, ) +from ray.ml.utils.huggingface_checkpoint_utils import load_huggingface_checkpoint -__all__ = ["HuggingFaceTrainer"] +__all__ = ["HuggingFaceTrainer", "load_huggingface_checkpoint"] diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index be7faa90aaf4..8ce3ac494729 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -6,12 +6,10 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Type -import torch import transformers import transformers.trainer import ray.cloudpickle as cpickle -from torch.utils.data import DataLoader, Dataset as TorchDataset -from transformers.training_args import TrainingArguments +from torch.utils.data import Dataset as TorchDataset from ray import train from ray import tune @@ -28,6 +26,7 @@ NODE_IP_KEY, process_datasets, TrainReportCallback, + wrap_transformers_trainer, ) from ray.train.constants import TUNE_CHECKPOINT_ID from ray.train.torch import TorchConfig @@ -49,8 +48,7 @@ # in HuggingFaceTrainer.as_trainable # TODO(team-ml): Refactor checkpoint management along with Tune. class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): - """Same as _DataParallelCheckpointManager, but syncs the dir instead - of serializing it.""" + """As _DataParallelCheckpointManager, but syncs the dir instead of serializing.""" def write_checkpoint(self, checkpoint: Dict): # If inside a Tune Trainable, then checkpoint with Tune. @@ -309,35 +307,55 @@ def _validate_attributes(self): "one GPU per worker in DDP mode and will fail " "if more are assigned." ) + if gpus_per_worker != int(gpus_per_worker): + raise ValueError( + f"You have assigned {gpus_per_worker} GPUs per worker, " + "but fractional GPUs are not supported by HuggingFace." + ) super()._validate_attributes() + def _convert_directory_checkpoint_to_sync( + self, checkpoint: Checkpoint + ) -> Checkpoint: + """Replace the directory checkpoint with a node ip & path dict checkpoint + used to sync the directory. If we use a directory checkpoint directly, + it will get deepcopied & serialized unnecessarily.""" + with checkpoint.as_directory() as checkpoint_path: + # Load checkpoint from path. + checkpoint_path = Path(checkpoint_path).expanduser().absolute() + if not checkpoint_path.exists(): + raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") + with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: + tune_checkpoint_id = int(f.read()) + + return Checkpoint.from_dict( + { + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + TUNE_CHECKPOINT_ID: tune_checkpoint_id, + } + ) + + def setup(self) -> None: + if ( + self.resume_from_checkpoint + and self.resume_from_checkpoint.get_internal_representation() + == "local_path" + ): + self.resume_from_checkpoint = self._convert_directory_checkpoint_to_sync( + self.resume_from_checkpoint + ) + def as_trainable(self) -> Type[Trainable]: - # Replace the directory checkpoint with a node ip & path dict checkpoint - # used to sync the directory. If we use a directory checkpoint directly, - # it will get deepcopied & serialized unnecessarily original_param_dict = self._param_dict.copy() resume_from_checkpoint: Optional[Checkpoint] = self._param_dict.get( "resume_from_checkpoint", None ) if resume_from_checkpoint: - with resume_from_checkpoint.as_directory() as checkpoint_path: - # Load checkpoint from path. - checkpoint_path = Path(checkpoint_path).expanduser().absolute() - if not checkpoint_path.exists(): - raise ValueError( - f"Checkpoint path {checkpoint_path} does not exist." - ) - with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: - tune_checkpoint_id = int(f.read()) - - self._param_dict["resume_from_checkpoint"] = Checkpoint.from_dict( - { - NODE_IP_KEY: get_node_ip_address(), - CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), - TUNE_CHECKPOINT_ID: tune_checkpoint_id, - } - ) + self._param_dict[ + "resume_from_checkpoint" + ] = self._convert_directory_checkpoint_to_sync(resume_from_checkpoint) try: ret = super().as_trainable() finally: @@ -356,8 +374,10 @@ def _huggingface_train_loop_per_worker(config): train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) + train_torch_dataset, eval_torch_dataset = process_datasets( - train_dataset, eval_dataset + train_dataset, + eval_dataset, ) # TODO(yard1): Automatically set `no_cuda` somehow @@ -365,37 +385,16 @@ def _huggingface_train_loop_per_worker(config): train_torch_dataset, eval_torch_dataset, **config ) - base_training_arguments_class: Type[TrainingArguments] = trainer.args.__class__ - - class RayTrainingArguments(base_training_arguments_class): - @property - def device(self) -> "torch.device": - return train.torch.get_device() - - base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ - - class RayTrainer(base_trainer_class): - def get_train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) - - def _save(self, *args, **kwargs): - # Workaround for RayTrainingArguments not being - # pickleable due to it being defined in a local - # scope - self.args.__class__ = base_training_arguments_class - ret = super()._save(*args, **kwargs) - self.args.__class__ = RayTrainingArguments - return ret + if trainer.args.push_to_hub: + raise ValueError( + "`push_to_hub` parameter in `TrainingArgs` is not supported by " + "`HuggingFaceTrainer`. If you would like to push your model to hub " + "after training, use the `load_huggingface_checkpoint` function " + "to obtain the model from a returned checkpoint, and use it to " + "instantiate the `transformers.Trainer` class." + ) - trainer.__class__ = RayTrainer - trainer.args.__class__ = RayTrainingArguments - trainer.args.no_cuda = not torch.cuda.is_available() + trainer = wrap_transformers_trainer(trainer) # ensure no HF logging callbacks are added # aside from doubling functionality with our callbacks, @@ -407,8 +406,6 @@ def _save(self, *args, **kwargs): trainer.pop_callback(callback) trainer.add_callback(TrainReportCallback) - if trainer.args.device.type == "cuda": - torch.cuda.set_device(trainer.args.device) checkpoint = train.load_checkpoint() checkpoint_path = None diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py index 21b3ecaa4aa3..503c5b672fce 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py @@ -1,99 +1,143 @@ from pathlib import Path -from typing import Dict, Generator, Iterator, List, Tuple +from typing import Any, Callable, Optional, Tuple, Type -import torch +import datasets.iterable_dataset import transformers.trainer -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, DataLoader from transformers.trainer_callback import TrainerCallback from ray import train from ray.util import get_node_ip_address from ray.data.dataset import Dataset +# Constants for the sync checkpoint dict. See huggingface_trainer.py CHECKPOINT_PATH_ON_NODE_KEY = "checkpoint_path_on_node" NODE_IP_KEY = "node_ip" -class HFIterableDataset(IterableDataset): - """Special Torch IterableDataset with HF format.""" +def maybe_add_length(obj: Any, length: Optional[int]) -> Any: + """Change the class of obj to a subclass with predefined __len__ if needed.""" + # By adding length to the dataset we let HF calculate steps per epoch + # and other such values. Without length, it's not possible to use + # epochs as the evaluation strategy, which makes for poor UX. - def __init__(self, generator: Generator): - self.generator = generator + if not length or hasattr(obj, "__len__"): + return obj - def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: - it = self.generator - for x in it: - # HF-specific format. See transformers.Trainer._prepare_inputs - if isinstance(x, dict): - # Just features - yield x + def __len__(self): + return length + + new_class = type( + f"{obj.__class__.__name__}WithLength", (obj.__class__,), {"__len__": __len__} + ) + obj.__class__ = new_class + return obj + + +def wrap_transformers_trainer( + trainer: transformers.trainer.Trainer, +) -> transformers.trainer.Trainer: + """Change the class of trainer to a subclass implementing Ray-specific logic.""" + base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ + + class RayTrainer(base_trainer_class): + def _prepare_data_collator(self): + try: + self._remove_unused_columns(None, description="nan") + except AttributeError: + pass + + self.data_collator = self._get_remove_columns_data_collator() + + def _get_remove_columns_data_collator(self) -> Callable: + if self._signature_columns and not hasattr(self, "_original_data_collator"): + + self._original_data_collator = self.data_collator + + def remove_columns_collator(features): + features = [ + { + k: v + for k, v in feature.items() + if k in self._signature_columns + } + for feature in features + ] + return self._original_data_collator(features) + + collator = remove_columns_collator else: - # Features and labels - yield x[0] + collator = self.data_collator + return collator + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") -class HFIterableDatasetWithLen(HFIterableDataset): - """Special Torch IterableDataset with preset length.""" + train_dataset = self.train_dataset - def __init__(self, generator: Generator, length: int): - self.generator = generator - self._len = length + # While we are not sharding the datasets again, this + # class ensures that the last batch has a consistent size. + train_dataset = transformers.trainer.IterableDatasetShard( + train_dataset, + batch_size=self.args.train_batch_size, + drop_last=self.args.dataloader_drop_last, + ) + + return DataLoader( + train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + trainer.__class__ = RayTrainer + trainer._prepare_data_collator() + return trainer + + +class RayDatasetHFIterable(datasets.iterable_dataset.ExamplesIterable): + """HF ExamplesIterable backed by a Ray Dataset.""" + def __init__(self, dataset: Dataset) -> None: + self.dataset = dataset + self.generate_examples_fn = self.dataset.iter_rows + + # Required for the superclass + self.kwargs = {} + + def __iter__(self): + for row in self.generate_examples_fn(**self.kwargs): + yield (0, {k: v for k, v in row.as_pydict().items()}) + + +def process_dataset_for_hf(dataset: Dataset) -> IterableDataset: + """Converts a Ray Dataset into a HF IterableDataset.""" + hf_iterable = RayDatasetHFIterable(dataset) + + iterable_dataset = datasets.iterable_dataset.IterableDataset( + hf_iterable, format_type="torch" + ).with_format("torch") - def __len__(self): - return self._len - - -def process_dataset_for_hf( - dataset: Dataset, feature_columns: Dict[str, List[str]], batch_size: int = 1 -) -> IterableDataset: - """Converts a Ray Dataset into a HF-compatible Torch Dataset.""" - torch_dataset = dataset.to_torch( - batch_size=batch_size, - feature_columns=feature_columns, - label_column=None, - unsqueeze_label_tensor=False, - unsqueeze_feature_tensors=False, - ) try: - count = dataset.count() + dataset_length = dataset.count() except ValueError: # pipeline case - count = None - if count: - # By adding length to the dataset we let HF calculate steps per epoch - # and other such values. Without length, it's not possible to use - # epochs as the evaluation strategy. - torch_dataset = HFIterableDatasetWithLen(torch_dataset, count) - else: - torch_dataset = HFIterableDataset(torch_dataset) - return torch_dataset + dataset_length = None + + iterable_dataset = maybe_add_length(iterable_dataset, dataset_length) + return iterable_dataset def process_datasets( - train_dataset: Dataset, eval_dataset: Dataset + train_dataset: Dataset, + eval_dataset: Dataset, ) -> Tuple[IterableDataset, IterableDataset]: """Convert Ray train and validation to HF-friendly IterableDatasets.""" - train_columns = set(train_dataset.schema(fetch_if_missing=True).names) - - # HF-specific format. See transformers.Trainer._prepare_inputs - feature_columns = {column: [column] for column in train_columns} - - # This is set to 1 to ensure that the model input format - # is the same as with HF's Dataset. If we were to pass - # an n>1 batch obtained from to_torch to HF Trainer, - # the format will differ, and the example count calculation - # will be messed up (as it assumes that it will always get - # just one row per output of the IterableDataset). - # TODO (yard1): Investigate if we can work around this. - batch_size = 1 - train_torch_dataset = process_dataset_for_hf( - train_dataset, feature_columns, batch_size=batch_size - ) + train_torch_dataset = process_dataset_for_hf(train_dataset) if eval_dataset: - eval_torch_dataset = process_dataset_for_hf( - eval_dataset, feature_columns, batch_size=batch_size - ) + eval_torch_dataset = process_dataset_for_hf(eval_dataset) else: eval_torch_dataset = None @@ -105,15 +149,12 @@ class TrainReportCallback(TrainerCallback): def __init__(self) -> None: # HF first logs metrics, and then checkpoints. With Ray AIR, we need the - # opposite. Therefore, if we detect that a checkpoint will be created, + # opposite. Furthermore, some metrics are logged in several calls. + # Therefore, if we detect that a checkpoint will be created, # we delay the train.report call after the checkpoint is reported # to Ray Train. - self.delayed_report = None - # Avoid double reporting at the end. - # TODO(yard1): Train statistics are only reported at the end. Combine - # the second to last report and the last report somehow. We want - # steps/epochs to match the training iteration. - self.last_step = None + self.delayed_report = {} + self.first_report_keys = None super().__init__() def on_step_end(self, args, state, control, **kwargs): @@ -123,15 +164,30 @@ def on_step_end(self, args, state, control, **kwargs): return control def on_log(self, args, state, control, model=None, logs=None, **kwargs): - if state.global_step == self.last_step: - return - self.last_step = state.global_step report = {**logs, "step": state.global_step, "epoch": state.epoch} - if control.should_save: - self.delayed_report = report + if not self.first_report_keys: + self.first_report_keys = set(report) + # if saving or evaluation is coming, delay reporting + if ( + control.should_save + or control.should_evaluate + or not set(report).issuperset(self.first_report_keys) + ): + self.delayed_report.update(report) else: train.report(**report) + def on_evaluate(self, args, state, control, **kwargs): + # saving comes after evaluation, so report if we + # aren't going to save + if ( + self.delayed_report + and set(self.delayed_report).issuperset(self.first_report_keys) + and not control.should_save + ): + train.report(**self.delayed_report) + self.delayed_report = {} + def on_save(self, args, state, control, **kwargs): checkpoint_path = Path( transformers.trainer.get_last_checkpoint(args.output_dir) @@ -145,4 +201,4 @@ def on_save(self, args, state, control, **kwargs): ) if self.delayed_report: train.report(**self.delayed_report) - self.delayed_report = None + self.delayed_report = {} diff --git a/python/ray/ml/utils/huggingface_checkpoint_utils.py b/python/ray/ml/utils/huggingface_checkpoint_utils.py new file mode 100644 index 000000000000..46cd2102ef54 --- /dev/null +++ b/python/ray/ml/utils/huggingface_checkpoint_utils.py @@ -0,0 +1,54 @@ +import os +from typing import Tuple, Type, Union + +import torch +from transformers.modeling_utils import PreTrainedModel +from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME +from transformers import TrainingArguments + +import ray.cloudpickle as cpickle +from ray.ml.preprocessor import Preprocessor +from ray.ml.checkpoint import Checkpoint +from ray.ml.utils.torch_utils import load_torch_model +from ray.ml.constants import PREPROCESSOR_KEY +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +def load_huggingface_checkpoint( + checkpoint: Checkpoint, + model: Union[Type[PreTrainedModel], torch.nn.Module], + **pretrained_model_kwargs +) -> Tuple[Union[PreTrainedModel, torch.nn.Module], Preprocessor, TrainingArguments]: + """Load a Checkpoint from ``HuggingFaceTrainer`` and return the + model, preprocessor and ``TrainingArguments`` contained within. + + Args: + checkpoint: The checkpoint to load the model and + preprocessor from. It is expected to be from the result of a + ``HuggingFaceTrainer`` run. + model: Either a ``transformers.PreTrainedModel`` class + (eg. ``AutoModelForCausalLM``), or a PyTorch model to load the + weights to. This should be the same model used for training. + """ + with checkpoint.as_directory() as checkpoint_path: + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + if isinstance(model, torch.nn.Module): + state_dict = torch.load( + os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" + ) + model = load_torch_model(saved_model=state_dict, model_definition=model) + else: + model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs) + training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME) + if os.path.exists(training_args_path): + with open(training_args_path, "rb") as f: + training_args = torch.load(f, map_location="cpu") + else: + training_args = None + return model, preprocessor, training_args diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index 18634a2cac9f..963fca711109 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -28,8 +28,10 @@ def convert_pandas_to_torch_tensor( column_dtype (Optional[Union[torch.dtype, List[torch.dtype]): The torch dtype to use for the tensor. If set to None, then automatically infer the dtype. - unsqueeze: Whether to unsqueeze (reshape to a 2d, 1 column tensor) - the columns or not. + unsqueeze: If set to True, the tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. Returns: Either a torch tensor of size (N, len(columns)) where N is the From 2d5f94e1936aebea4f0598ba7e47d651e1aa3236 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 27 Apr 2022 21:06:04 +0000 Subject: [PATCH 57/75] Complete refactor --- ...ingface_basic_language_modeling_example.py | 12 ++ .../huggingface/huggingface_predictor.py | 194 ++++++++---------- .../ml/tests/test_huggingface_predictor.py | 34 +-- .../ray/ml/tests/test_huggingface_trainer.py | 25 ++- .../huggingface/huggingface_trainer.py | 5 +- .../huggingface/huggingface_utils.py | 42 ++-- 6 files changed, 153 insertions(+), 159 deletions(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index db55941ab868..9d8615f6d226 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -12,11 +12,13 @@ TrainingArguments, ) +import pandas as pd import torch import ray import ray.data from ray.ml.train.integrations.huggingface import HuggingFaceTrainer +from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor def main( @@ -112,6 +114,16 @@ def train_function(train_dataset, eval_dataset=None, **config): results = trainer.fit() print(results.metrics) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + prompt = ["My text: Complete me..."] + predictor = HuggingFacePredictor.from_checkpoint( + results.checkpoint, task="text-generation", tokenizer=tokenizer + ) + prediction = predictor.predict(pd.DataFrame(prompt, columns=["prompt"])) + prediction = prediction.iloc[0]["generated_text"] + + print(f"Generated text for prompt '{prompt}': '{prediction}'") + if __name__ == "__main__": # Training settings diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index 81f41efe7529..a2f4700162ac 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -1,54 +1,47 @@ -from typing import Dict, Optional, Type, Union, List +import os +from typing import Optional, Type, Union, List import numpy as np import pandas as pd +from ray.ml.constants import PREPROCESSOR_KEY -import torch -from datasets import Dataset as HFDataset -from transformers.modeling_utils import PreTrainedModel -from transformers.trainer import Trainer as HFTrainer -from transformers import TrainingArguments +from transformers.pipelines import Pipeline, pipeline as pipeline_factory +from transformers.pipelines.table_question_answering import ( + TableQuestionAnsweringPipeline, +) +import ray.cloudpickle as cpickle from ray.ml.predictor import DataBatchType, Predictor from ray.ml.preprocessor import Preprocessor from ray.ml.checkpoint import Checkpoint -from ray.ml.utils.huggingface_checkpoint_utils import load_huggingface_checkpoint class HuggingFacePredictor(Predictor): """A predictor for HuggingFace Transformers PyTorch models. + This predictor uses Transformers Pipelines for inference. + Args: - model: The Transformers model to use for predictions. + pipeline: The Transformers pipeline to use for inference. preprocessor: A preprocessor used to transform data batches prior to prediction. - training_args: ``transformers.TrainingArguments`` to use for the prediction. - trainer_class: ``transformers.Trainer`` subclass to use for prediction. - Defaults to ``transformers.Trainer``. """ def __init__( self, - model: Union[PreTrainedModel, torch.nn.Module], + pipeline: Optional[Pipeline] = None, preprocessor: Optional[Preprocessor] = None, - *, - training_args: Optional[TrainingArguments] = None, - trainer_class: HFTrainer = HFTrainer, ): - self.model = model + self.pipeline = pipeline self.preprocessor = preprocessor - self.training_args = training_args - self.trainer_class = trainer_class @classmethod def from_checkpoint( cls, checkpoint: Checkpoint, - model: Union[Type[PreTrainedModel], torch.nn.Module], *, - training_args: Optional[TrainingArguments] = None, - trainer_class: HFTrainer = HFTrainer, - **pretrained_model_kwargs, + pipeline: Optional[Type[Pipeline]] = None, + **pipeline_kwargs, ) -> "HuggingFacePredictor": """Instantiate the predictor from a Checkpoint. @@ -58,62 +51,72 @@ def from_checkpoint( checkpoint: The checkpoint to load the model and preprocessor from. It is expected to be from the result of a ``HuggingFaceTrainer`` run. - model: Either a ``transformers.PreTrainedModel`` class - (eg. ``AutoModelForCausalLM``), or a PyTorch model to load the - weights to. This should be the same model used for training. - training_args: ``transformers.TrainingArguments`` to use for the prediction. - Defaults to training arguments saved inside the checkpoint. - trainer_class: ``transformers.Trainer`` subclass to use for prediction. - Defaults to ``transformers.Trainer``. - **pretrained_model_kwargs: Any kwargs to pass to the - ``model.from_pretrained()`` call. Only used if - ``model`` is a ``PreTrainerModel`` class. + pipeline: A ``transformers.pipelines.Pipeline`` class to use. + If not specified, will use the ``pipeline`` abstraction + wrapper. + **pipeline_kwargs: Any kwargs to pass to the pipeline + initialization. If ``pipeline`` is None, this must contain + the 'task' argument. Cannot contain 'model'. """ - model, preprocessor, loaded_training_args = load_huggingface_checkpoint( - checkpoint, model, **pretrained_model_kwargs - ) - training_args = training_args or loaded_training_args + if not pipeline and "task" not in pipeline_kwargs: + raise ValueError( + "If `pipeline` is not specified, 'task' must be passed as a " "kwarg." + ) + pipeline = pipeline or pipeline_factory + with checkpoint.as_directory() as checkpoint_path: + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + pipeline = pipeline(model=checkpoint_path, **pipeline_kwargs) return HuggingFacePredictor( - model=model, + pipeline=pipeline, preprocessor=preprocessor, - training_args=training_args, - trainer_class=trainer_class, ) - def to_transformers_trainer(self, **trainer_kwargs) -> HFTrainer: - """Converts this predictor to a ``transformers.Trainer``. - - Args: - **trainer_kwargs: Any kwargs to pass to the - ``trainer_class`` initialization. ``model`` and - ``args`` are preset. - """ - if self.training_args: - self.training_args.local_rank = -1 - trainer = self.trainer_class( - model=self.model, args=self.training_args, **trainer_kwargs - ) - return trainer - - def _predict(self, dataset: HFDataset) -> pd.DataFrame: - trainer = self.to_transformers_trainer() - ret = trainer.predict(dataset).predictions - # TODO(yard1): Return just a numpy array once that's supported - # by Ray Datasets - df = pd.DataFrame([ret.tolist()]).T - df.columns = ["predictions"] + def _predict( + self, data: Union[list, pd.DataFrame], **pipeline_call_kwargs + ) -> pd.DataFrame: + ret = self.pipeline(data, **pipeline_call_kwargs) + # Remove unnecessary lists + try: + new_ret = [x[0] if isinstance(x, list) and len(x) == 1 else x for x in ret] + df = pd.DataFrame(new_ret) + except Exception: + # if we fail for any reason, just give up + df = pd.DataFrame(ret) + df.columns = [str(col) for col in df.columns] return df + def _convert_data_for_pipeline( + self, data: pd.DataFrame + ) -> Union[list, pd.DataFrame]: + """Convert the data into a format accepted by the pipeline. + + In most cases, this format is a list of strings.""" + # Special case + if isinstance(self.pipeline, TableQuestionAnsweringPipeline): + return data + # Otherwise, a list of columns as lists + columns = [data[col].to_list() for col in data.columns] + # Flatten if it's only one column + if len(columns) == 1: + columns = columns[0] + return columns + def predict( self, data: DataBatchType, feature_columns: Optional[List[str]] = None, - dtype: Optional[Union[Dict[str, np.dtype], np.dtype]] = None, + **pipeline_call_kwargs, ) -> DataBatchType: """Run inference on data batch. - The data is converted into a HuggingFace ``datasets.Dataset`` - and passed to a ``transformers.Trainer.predict()`` method. + The data is converted into a list (unless ``pipeline`` is a + ``TableQuestionAnsweringPipeline``) and passed to the ``pipeline`` + object. Args: data: A batch of input data. Either a pandas DataFrame or numpy @@ -121,67 +124,34 @@ def predict( feature_columns: The names or indices of the columns in the data to use as features to predict on. If None, use all columns. - dtype: The numpy dtypes to cast the data to. - Can be either a single dtype or a dict of ``column:dtype``. - If set to None, then automatically infer the dtype. + **pipeline_call_kwargs: additional kwargs to pass to the + ``pipeline`` object. Examples: .. code-block:: python - import numpy as np - from datasets import load_dataset + import pandas as pd from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer - from ray.ml.predictors.huggingface import HuggingFacePredictor + from transformers.pipelines import pipeline + from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor model_checkpoint = "gpt2" tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" - block_size = 128 - - datasets = load_dataset("wikitext", "wikitext-2-raw-v1") tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) - def tokenize_function(examples): - return tokenizer(examples["text"]) - - tokenized_datasets = datasets.map( - tokenize_function, batched=True, num_proc=1, remove_columns=["text"] - ) - - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = { - k: sum(examples[k], []) for k in examples.keys() - } - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model - # supported it. - # instead of this drop, you can customize this part to your needs. - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. - result = { - k: [ - t[i : i + block_size] - for i in range(0, total_length, block_size) - ] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - batch_size=1000, - num_proc=1, - ) model_config = AutoConfig.from_pretrained(model_checkpoint) model = AutoModelForCausalLM.from_config(model_config) predictor = HuggingFacePredictor( - model=model, preprocessor=preprocessor + pipeline=pipeline( + task="text-generation", model=model, tokenizer=tokenizer + ) ) - predictions = predictor.predict(lm_datasets["validation"].to_pandas()) + prompts = pd.DataFrame( + ["Complete me", "And me", "Please complete"], columns=["sentences"] + ) + predictions = predictor.predict(prompts) Returns: @@ -195,8 +165,6 @@ def group_texts(examples): data = pd.DataFrame(data) data = data[feature_columns] if feature_columns else data - if dtype: - data = data.astype(dtype) - dataset = HFDataset.from_pandas(data) - return self._predict(dataset) + data = self._convert_data_for_pipeline(data) + return self._predict(data, **pipeline_call_kwargs) diff --git a/python/ray/ml/tests/test_huggingface_predictor.py b/python/ray/ml/tests/test_huggingface_predictor.py index 4d0455ef3f4a..2b760bbc9a57 100644 --- a/python/ray/ml/tests/test_huggingface_predictor.py +++ b/python/ray/ml/tests/test_huggingface_predictor.py @@ -1,18 +1,24 @@ import pandas as pd import pytest -from transformers import AutoConfig, AutoModelForCausalLM, TrainingArguments + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, +) +from transformers.pipelines import pipeline from ray.ml.preprocessor import Preprocessor from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor -from ray.ml.tests._huggingface_data import validation_data - -# 16 first rows of tokenized wikitext-2-raw-v1 validation -validation_df = pd.read_json(validation_data) +prompts = pd.DataFrame( + ["Complete me", "And me", "Please complete"], columns=["sentences"] +) # We are only testing Casual Language Modelling here model_checkpoint = "sshleifer/tiny-gpt2" +tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" class DummyPreprocessor(Preprocessor): @@ -22,24 +28,24 @@ def transform_batch(self, df): @pytest.mark.parametrize("preprocessor", [True, False]) -@pytest.mark.parametrize("training_args", [True, False]) -def test_predict(preprocessor, training_args, tmpdir): +def test_predict(preprocessor, tmpdir): if preprocessor: preprocessor = DummyPreprocessor() else: preprocessor = None - if training_args: - training_args = TrainingArguments(tmpdir) - else: - training_args = None model_config = AutoConfig.from_pretrained(model_checkpoint) model = AutoModelForCausalLM.from_config(model_config) predictor = HuggingFacePredictor( - model=model, preprocessor=preprocessor, training_args=training_args + pipeline=pipeline( + task="text-generation", + model=model, + tokenizer=AutoTokenizer.from_pretrained(tokenizer_checkpoint), + ), + preprocessor=preprocessor, ) - predictions = predictor.predict(validation_df) + predictions = predictor.predict(prompts) - assert len(predictions) == 16 + assert len(predictions) == 3 if preprocessor: assert hasattr(predictor.preprocessor, "_batch_transformed") diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py index 55bc3a4c04b6..9b9c382455c5 100644 --- a/python/ray/ml/tests/test_huggingface_trainer.py +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -1,13 +1,17 @@ import pandas as pd import pytest -import torch -from datasets.arrow_dataset import Dataset -from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + Trainer, + TrainingArguments, +) import ray.data from ray.ml.train.integrations.huggingface import HuggingFaceTrainer from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor -from ray.ml.train.integrations.huggingface.huggingface_utils import process_datasets from ray.ml.batch_predictor import BatchPredictor from ray.ml.tests._huggingface_data import train_data, validation_data @@ -15,10 +19,14 @@ # 16 first rows of tokenized wikitext-2-raw-v1 training & validation train_df = pd.read_json(train_data) validation_df = pd.read_json(validation_data) +prompts = pd.DataFrame( + ["Complete me", "And me", "Please complete"], columns=["sentences"] +) # We are only testing Casual Language Modelling here model_checkpoint = "sshleifer/tiny-gpt2" +tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" @pytest.fixture @@ -81,8 +89,11 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result2.checkpoint predictor = BatchPredictor.from_checkpoint( - result2.checkpoint, HuggingFacePredictor, model=AutoModelForCausalLM + result2.checkpoint, + HuggingFacePredictor, + task="text-generation", + tokenizer=AutoTokenizer.from_pretrained(tokenizer_checkpoint), ) - predictions = predictor.predict(ray_validation) - assert predictions.count() == 16 + predictions = predictor.predict(ray.data.from_pandas(prompts)) + assert predictions.count() == 3 diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 8ce3ac494729..3d2278909db4 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -107,11 +107,11 @@ class HuggingFaceTrainer(TorchTrainer): the ``get_train_dataloader`` method will be overriden to disable distributed sampling, as the dataset will already be sharded. - Hugging Face loggers will be automatically disabled, and the ``local_rank`` + HuggingFace loggers will be automatically disabled, and the ``local_rank`` argument in ``TrainingArguments`` will be automatically set. Please note that if you want to use CPU training, you will need to set the ``no_cuda`` argument in ``TrainingArguments`` manually - otherwise, an exception - may be thrown. + (segfault) may be thrown. Example: .. code-block:: python @@ -380,7 +380,6 @@ def _huggingface_train_loop_per_worker(config): eval_dataset, ) - # TODO(yard1): Automatically set `no_cuda` somehow trainer: transformers.trainer.Trainer = trainer_init_per_worker( train_torch_dataset, eval_torch_dataset, **config ) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py index 503c5b672fce..a18b0efcab6a 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py @@ -41,7 +41,10 @@ def wrap_transformers_trainer( base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ class RayTrainer(base_trainer_class): + # TODO(yard1): Upstream data collator removing unused columns to + # transformers. def _prepare_data_collator(self): + # Hack to set the self._signature_columns attribute. try: self._remove_unused_columns(None, description="nan") except AttributeError: @@ -97,8 +100,10 @@ def get_train_dataloader(self): return trainer +# TODO(ml-team): Replace with a Ray Datasets-HuggingFace integration when available. class RayDatasetHFIterable(datasets.iterable_dataset.ExamplesIterable): """HF ExamplesIterable backed by a Ray Dataset.""" + def __init__(self, dataset: Dataset) -> None: self.dataset = dataset self.generate_examples_fn = self.dataset.iter_rows @@ -133,7 +138,7 @@ def process_datasets( train_dataset: Dataset, eval_dataset: Dataset, ) -> Tuple[IterableDataset, IterableDataset]: - """Convert Ray train and validation to HF-friendly IterableDatasets.""" + """Convert Ray train and validation to HF IterableDatasets.""" train_torch_dataset = process_dataset_for_hf(train_dataset) if eval_dataset: @@ -149,7 +154,7 @@ class TrainReportCallback(TrainerCallback): def __init__(self) -> None: # HF first logs metrics, and then checkpoints. With Ray AIR, we need the - # opposite. Furthermore, some metrics are logged in several calls. + # opposite. Furthermore, some metrics are logged just at the end. # Therefore, if we detect that a checkpoint will be created, # we delay the train.report call after the checkpoint is reported # to Ray Train. @@ -159,36 +164,23 @@ def __init__(self) -> None: def on_step_end(self, args, state, control, **kwargs): if control.should_training_stop: - # always save at end + # Always save at the end. control.should_save = True return control def on_log(self, args, state, control, model=None, logs=None, **kwargs): + # Log is called in multiple places (evaluation, train metrics). report = {**logs, "step": state.global_step, "epoch": state.epoch} if not self.first_report_keys: self.first_report_keys = set(report) - # if saving or evaluation is coming, delay reporting - if ( - control.should_save - or control.should_evaluate - or not set(report).issuperset(self.first_report_keys) - ): - self.delayed_report.update(report) - else: + # if saving or training end is coming, delay reporting + if not control.should_save and not control.should_training_stop: train.report(**report) - - def on_evaluate(self, args, state, control, **kwargs): - # saving comes after evaluation, so report if we - # aren't going to save - if ( - self.delayed_report - and set(self.delayed_report).issuperset(self.first_report_keys) - and not control.should_save - ): - train.report(**self.delayed_report) - self.delayed_report = {} + else: + self.delayed_report.update(report) def on_save(self, args, state, control, **kwargs): + # Save is called after evaluation. checkpoint_path = Path( transformers.trainer.get_last_checkpoint(args.output_dir) ).absolute() @@ -199,6 +191,12 @@ def on_save(self, args, state, control, **kwargs): CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), } ) + if self.delayed_report and not control.should_training_stop: + train.report(**self.delayed_report) + self.delayed_report = {} + + def on_train_end(self, args, state, control, **kwargs): + # Final callback. Train metrics are logged right before this. if self.delayed_report: train.report(**self.delayed_report) self.delayed_report = {} From 8f085b443717e5eeb71605dbaa92ff9eead739fa Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 27 Apr 2022 21:14:36 +0000 Subject: [PATCH 58/75] Clarify --- .../integrations/huggingface/huggingface_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py index a18b0efcab6a..5cc48578f0e5 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Tuple, Type +from typing import Any, Optional, Tuple, Type import datasets.iterable_dataset import transformers.trainer @@ -43,18 +43,18 @@ def wrap_transformers_trainer( class RayTrainer(base_trainer_class): # TODO(yard1): Upstream data collator removing unused columns to # transformers. + # This is necessary to provide the same experience as with a + # non-iterable HuggingFace Dataset, which can remove columns + # not supported by the model. def _prepare_data_collator(self): + """Wrap the data collator in a function removing superflous columns.""" # Hack to set the self._signature_columns attribute. try: self._remove_unused_columns(None, description="nan") except AttributeError: pass - self.data_collator = self._get_remove_columns_data_collator() - - def _get_remove_columns_data_collator(self) -> Callable: if self._signature_columns and not hasattr(self, "_original_data_collator"): - self._original_data_collator = self.data_collator def remove_columns_collator(features): @@ -71,7 +71,8 @@ def remove_columns_collator(features): collator = remove_columns_collator else: collator = self.data_collator - return collator + + self.data_collator = collator def get_train_dataloader(self): if self.train_dataset is None: @@ -79,7 +80,7 @@ def get_train_dataloader(self): train_dataset = self.train_dataset - # While we are not sharding the datasets again, this + # While we are not sharding the train dataset again, this # class ensures that the last batch has a consistent size. train_dataset = transformers.trainer.IterableDatasetShard( train_dataset, From 2126556e011c8a85b44b8971e67eaf6e8f53f5f2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 27 Apr 2022 21:33:52 +0000 Subject: [PATCH 59/75] Remove shuffle mention from docstring --- .../ml/train/integrations/huggingface/huggingface_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 3d2278909db4..7cf1380532ac 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -99,10 +99,6 @@ class HuggingFaceTrainer(TorchTrainer): shards, with each Actor training on a single shard. All the other datasets will not be split. - The datasets will NOT be shuffled by default. Call ``Dataset.random_shuffle()`` - on the "train" dataset you are passing in ``datasets`` if you wish for the - training data to be shuffled. - Please note that if you use a custom ``transformers.Trainer`` subclass, the ``get_train_dataloader`` method will be overriden to disable distributed sampling, as the dataset will already be sharded. From 0abab40a6172acd76b06126214e1527ee0af34ec Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 28 Apr 2022 20:33:23 +0000 Subject: [PATCH 60/75] Doc fix --- doc/source/custom_directives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index efd26278396d..8ef5f4b4bac4 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -92,6 +92,7 @@ def update_context(app, pagename, templatename, context, doctree): "ConfigSpace", "dask.distributed", "datasets", + "datasets.iterable_dataset", "gym", "gym.spaces", "horovod", From 3c6936771821a02046c585e5ecd86c53c0dbc7ee Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 28 Apr 2022 20:33:30 +0000 Subject: [PATCH 61/75] Upgrade torch version --- python/requirements/ml/requirements_dl.txt | 8 ++++---- python/requirements_ml_docker.txt | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/requirements/ml/requirements_dl.txt b/python/requirements/ml/requirements_dl.txt index 04d41897a8be..7ac6c4b2462d 100644 --- a/python/requirements/ml/requirements_dl.txt +++ b/python/requirements/ml/requirements_dl.txt @@ -6,10 +6,10 @@ tensorflow-probability==0.14.1 # If you make changes to the torch versions below, please also make the corresponding changes to `requirements_ml_docker.txt`! -torch==1.9.0;sys_platform=="darwin" -torchvision==0.10.0;sys_platform=="darwin" +torch==1.11.0;sys_platform=="darwin" +torchvision==0.12.0;sys_platform=="darwin" # On non-OSX machines only install CPU version of torch and torchvision -f https://download.pytorch.org/whl/torch_stable.html -torch==1.9.0+cpu;sys_platform!="darwin" -torchvision==0.10.0+cpu;sys_platform!="darwin" +torch==1.11.0+cpu;sys_platform!="darwin" +torchvision==0.12.0+cpu;sys_platform!="darwin" diff --git a/python/requirements_ml_docker.txt b/python/requirements_ml_docker.txt index c38150ad56c2..01fa9b478314 100644 --- a/python/requirements_ml_docker.txt +++ b/python/requirements_ml_docker.txt @@ -5,8 +5,8 @@ tblib # If you make changes to the torch versions, please also make the corresponding changes to `requirements_dl.txt`! -f https://download.pytorch.org/whl/torch_stable.html -torch==1.9.0+cu111 -torchvision==0.10.0+cu111 +torch==1.11.0+cu111 +torchvision==0.12.0+cu111 -f https://data.pyg.org/whl/torch-1.9.0+cu111.html torch-scatter==2.0.9 From dc4fb411f2a4ae3fb7cd4c092c31d136c0ec8fe4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 28 Apr 2022 23:47:53 +0200 Subject: [PATCH 62/75] Update requirements_ml_docker.txt --- python/requirements_ml_docker.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/requirements_ml_docker.txt b/python/requirements_ml_docker.txt index 01fa9b478314..236e0fe2fd78 100644 --- a/python/requirements_ml_docker.txt +++ b/python/requirements_ml_docker.txt @@ -5,8 +5,8 @@ tblib # If you make changes to the torch versions, please also make the corresponding changes to `requirements_dl.txt`! -f https://download.pytorch.org/whl/torch_stable.html -torch==1.11.0+cu111 -torchvision==0.12.0+cu111 +torch==1.11.0 +torchvision==0.12.0 -f https://data.pyg.org/whl/torch-1.9.0+cu111.html torch-scatter==2.0.9 From 2c8387839b6060c60abbf233306e0f24c403cc2c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 28 Apr 2022 23:48:31 +0200 Subject: [PATCH 63/75] Update requirements_dl.txt --- python/requirements/ml/requirements_dl.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/requirements/ml/requirements_dl.txt b/python/requirements/ml/requirements_dl.txt index 7ac6c4b2462d..31c4d03b7967 100644 --- a/python/requirements/ml/requirements_dl.txt +++ b/python/requirements/ml/requirements_dl.txt @@ -10,6 +10,6 @@ torch==1.11.0;sys_platform=="darwin" torchvision==0.12.0;sys_platform=="darwin" # On non-OSX machines only install CPU version of torch and torchvision --f https://download.pytorch.org/whl/torch_stable.html -torch==1.11.0+cpu;sys_platform!="darwin" -torchvision==0.12.0+cpu;sys_platform!="darwin" +-f https://download.pytorch.org/whl/cpu +torch==1.11.0;sys_platform!="darwin" +torchvision==0.12.0;sys_platform!="darwin" From 58d81367c14616d87e34b748c15b66745294d1b0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 28 Apr 2022 23:59:54 +0200 Subject: [PATCH 64/75] Revert --- python/requirements_ml_docker.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/requirements_ml_docker.txt b/python/requirements_ml_docker.txt index 236e0fe2fd78..c38150ad56c2 100644 --- a/python/requirements_ml_docker.txt +++ b/python/requirements_ml_docker.txt @@ -5,8 +5,8 @@ tblib # If you make changes to the torch versions, please also make the corresponding changes to `requirements_dl.txt`! -f https://download.pytorch.org/whl/torch_stable.html -torch==1.11.0 -torchvision==0.12.0 +torch==1.9.0+cu111 +torchvision==0.10.0+cu111 -f https://data.pyg.org/whl/torch-1.9.0+cu111.html torch-scatter==2.0.9 From 790a31e490a5dac7db76052cd3b31f1fe27fc1d2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 00:00:13 +0200 Subject: [PATCH 65/75] Revert --- python/requirements/ml/requirements_dl.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/requirements/ml/requirements_dl.txt b/python/requirements/ml/requirements_dl.txt index 31c4d03b7967..04d41897a8be 100644 --- a/python/requirements/ml/requirements_dl.txt +++ b/python/requirements/ml/requirements_dl.txt @@ -6,10 +6,10 @@ tensorflow-probability==0.14.1 # If you make changes to the torch versions below, please also make the corresponding changes to `requirements_ml_docker.txt`! -torch==1.11.0;sys_platform=="darwin" -torchvision==0.12.0;sys_platform=="darwin" +torch==1.9.0;sys_platform=="darwin" +torchvision==0.10.0;sys_platform=="darwin" # On non-OSX machines only install CPU version of torch and torchvision --f https://download.pytorch.org/whl/cpu -torch==1.11.0;sys_platform!="darwin" -torchvision==0.12.0;sys_platform!="darwin" +-f https://download.pytorch.org/whl/torch_stable.html +torch==1.9.0+cpu;sys_platform!="darwin" +torchvision==0.10.0+cpu;sys_platform!="darwin" From 3ed8551f51804984bfbbf45747e350a1990d4f70 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 00:03:22 +0200 Subject: [PATCH 66/75] Update huggingface_basic_language_modeling_example.py --- .../huggingface_basic_language_modeling_example.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index 9d8615f6d226..9e7f35fb843f 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -186,10 +186,12 @@ def train_function(train_dataset, eval_dataset=None, **config): args = parser.parse_args() + # Requires at least torch 1.11 to pass + runtime_env = {"pip": ["torch==1.11.0"]} if args.address: - ray.init(args.address) + ray.init(args.address, runtime_env=runtime_env) else: - ray.init() + ray.init(runtime_env=runtime_env) main( model_checkpoint=args.model_checkpoint, From 3bc93e4f2d46bbd1cc99375d47dec6ebc35e20dd Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 00:23:16 +0200 Subject: [PATCH 67/75] Update huggingface_basic_language_modeling_example.py --- .../huggingface_basic_language_modeling_example.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index 9e7f35fb843f..a89d9cca2b1e 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -19,6 +19,7 @@ import ray.data from ray.ml.train.integrations.huggingface import HuggingFaceTrainer from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor +from ray.ml.batch_predictor import BatchPredictor def main( @@ -116,11 +117,15 @@ def train_function(train_dataset, eval_dataset=None, **config): tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) prompt = ["My text: Complete me..."] - predictor = HuggingFacePredictor.from_checkpoint( - results.checkpoint, task="text-generation", tokenizer=tokenizer + predictor = BatchPredictor.from_checkpoint( + results.checkpoint, + HuggingFacePredictor, + task="text-generation", + tokenizer=tokenizer ) - prediction = predictor.predict(pd.DataFrame(prompt, columns=["prompt"])) - prediction = prediction.iloc[0]["generated_text"] + data = ray.data.from_pandas(pd.DataFrame(prompt, columns=["prompt"])) + prediction = predictor.predict(data, num_gpus_per_worker=int(use_gpu)) + prediction = prediction.to_pandas().iloc[0]["generated_text"] print(f"Generated text for prompt '{prompt}': '{prediction}'") From e48794c43d7810a341e0d918cb725d23c3a7a16c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 01:40:23 +0200 Subject: [PATCH 68/75] Update huggingface_basic_language_modeling_example.py --- .../huggingface/huggingface_basic_language_modeling_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index a89d9cca2b1e..7a5438faa573 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -121,7 +121,7 @@ def train_function(train_dataset, eval_dataset=None, **config): results.checkpoint, HuggingFacePredictor, task="text-generation", - tokenizer=tokenizer + tokenizer=tokenizer, ) data = ray.data.from_pandas(pd.DataFrame(prompt, columns=["prompt"])) prediction = predictor.predict(data, num_gpus_per_worker=int(use_gpu)) From 42d99dd0ba8cbddac726faaa47fd0cd515102bba Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 03:21:48 +0200 Subject: [PATCH 69/75] Update custom_directives.py --- doc/source/custom_directives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index 8ef5f4b4bac4..dcaedf604c87 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -93,6 +93,7 @@ def update_context(app, pagename, templatename, context, doctree): "dask.distributed", "datasets", "datasets.iterable_dataset", + "transformers.pipelines", "gym", "gym.spaces", "horovod", From cf322b212d445249e95e1ecb48f6423beb8775fb Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 09:55:19 +0200 Subject: [PATCH 70/75] Update custom_directives.py --- doc/source/custom_directives.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index dcaedf604c87..d5686e8c5bf3 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -93,7 +93,6 @@ def update_context(app, pagename, templatename, context, doctree): "dask.distributed", "datasets", "datasets.iterable_dataset", - "transformers.pipelines", "gym", "gym.spaces", "horovod", @@ -135,6 +134,8 @@ def update_context(app, pagename, templatename, context, doctree): "transformers.modeling_utils", "transformers.models", "transformers.models.auto", + "transformers.pipelines", + "transformers.pipelines.table_question_answering", "transformers.trainer", "transformers.training_args", "transformers.trainer_callback", From 71d2f1b61a5fc599f1be3b4c8934e0687f08c0a8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 18:31:10 +0000 Subject: [PATCH 71/75] Better checkpoint detection --- .../huggingface/huggingface_trainer.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 7cf1380532ac..203f8f1b4839 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -311,7 +311,7 @@ def _validate_attributes(self): super()._validate_attributes() - def _convert_directory_checkpoint_to_sync( + def _convert_directory_checkpoint_to_sync_if_needed( self, checkpoint: Checkpoint ) -> Checkpoint: """Replace the directory checkpoint with a node ip & path dict checkpoint @@ -320,8 +320,19 @@ def _convert_directory_checkpoint_to_sync( with checkpoint.as_directory() as checkpoint_path: # Load checkpoint from path. checkpoint_path = Path(checkpoint_path).expanduser().absolute() - if not checkpoint_path.exists(): - raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.") + if not checkpoint_path.joinpath(TUNE_CHECKPOINT_ID).exists(): + # If the ID file is missing, we assume that this is already + # a sync checkpoint + dict_checkpoint = checkpoint.to_dict() + if ( + NODE_IP_KEY not in dict_checkpoint + or CHECKPOINT_PATH_ON_NODE_KEY not in dict_checkpoint + ): + raise ValueError( + "Wrong checkpoint format. Ensure the checkpoint is a " + "result of `HuggingFaceTrainer`." + ) + return checkpoint with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: tune_checkpoint_id = int(f.read()) @@ -334,13 +345,11 @@ def _convert_directory_checkpoint_to_sync( ) def setup(self) -> None: - if ( - self.resume_from_checkpoint - and self.resume_from_checkpoint.get_internal_representation() - == "local_path" - ): - self.resume_from_checkpoint = self._convert_directory_checkpoint_to_sync( - self.resume_from_checkpoint + if self.resume_from_checkpoint: + self.resume_from_checkpoint = ( + self._convert_directory_checkpoint_to_sync_if_needed( + self.resume_from_checkpoint + ) ) def as_trainable(self) -> Type[Trainable]: @@ -351,7 +360,9 @@ def as_trainable(self) -> Type[Trainable]: if resume_from_checkpoint: self._param_dict[ "resume_from_checkpoint" - ] = self._convert_directory_checkpoint_to_sync(resume_from_checkpoint) + ] = self._convert_directory_checkpoint_to_sync_if_needed( + resume_from_checkpoint + ) try: ret = super().as_trainable() finally: From 3b4593924e1f4e0cb4752ec3fba6fd236b0b2c28 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 20:34:31 +0200 Subject: [PATCH 72/75] Apply suggestions from code review Co-authored-by: Amog Kamsetty --- .../integrations/huggingface/huggingface_predictor.py | 2 +- python/ray/ml/tests/test_huggingface_predictor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py index a2f4700162ac..38933cf22a5e 100644 --- a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py +++ b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -60,7 +60,7 @@ def from_checkpoint( """ if not pipeline and "task" not in pipeline_kwargs: raise ValueError( - "If `pipeline` is not specified, 'task' must be passed as a " "kwarg." + "If `pipeline` is not specified, 'task' must be passed as a kwarg." ) pipeline = pipeline or pipeline_factory with checkpoint.as_directory() as checkpoint_path: diff --git a/python/ray/ml/tests/test_huggingface_predictor.py b/python/ray/ml/tests/test_huggingface_predictor.py index 2b760bbc9a57..8818ff808c6c 100644 --- a/python/ray/ml/tests/test_huggingface_predictor.py +++ b/python/ray/ml/tests/test_huggingface_predictor.py @@ -15,7 +15,7 @@ ["Complete me", "And me", "Please complete"], columns=["sentences"] ) -# We are only testing Casual Language Modelling here +# We are only testing Casual Language Modeling here model_checkpoint = "sshleifer/tiny-gpt2" tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" From b26f67fb7cbfd8d927803804be7eb3e68ca7c383 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 18:55:22 +0000 Subject: [PATCH 73/75] Add context --- .../huggingface_basic_language_modeling_example.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index 7a5438faa573..1b8c49175721 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -32,9 +32,12 @@ def main( use_gpu=False, smoke_test=False, ): - # block_size = tokenizer.model_max_length block_size = 128 + # Uncomment the following if the maximum length thr model was + # pretrained with can fit in your memory. + # block_size = tokenizer.model_max_length + # Run this as a remote function to avoid downloading on the driver @ray.remote def get_dataset(): From 2c8f48c28fd8e594f33032c91a9b6b45b3422e2e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 29 Apr 2022 18:55:51 +0000 Subject: [PATCH 74/75] Load huggingface checkpoint to staticmethod --- .../integrations/huggingface/__init__.py | 3 +- .../huggingface/huggingface_trainer.py | 87 +++++++++++++++++-- .../ml/utils/huggingface_checkpoint_utils.py | 54 ------------ 3 files changed, 82 insertions(+), 62 deletions(-) delete mode 100644 python/ray/ml/utils/huggingface_checkpoint_utils.py diff --git a/python/ray/ml/train/integrations/huggingface/__init__.py b/python/ray/ml/train/integrations/huggingface/__init__.py index 37d23738ba3e..b2344f98dcb3 100644 --- a/python/ray/ml/train/integrations/huggingface/__init__.py +++ b/python/ray/ml/train/integrations/huggingface/__init__.py @@ -1,6 +1,5 @@ from ray.ml.train.integrations.huggingface.huggingface_trainer import ( HuggingFaceTrainer, ) -from ray.ml.utils.huggingface_checkpoint_utils import load_huggingface_checkpoint -__all__ = ["HuggingFaceTrainer", "load_huggingface_checkpoint"] +__all__ = ["HuggingFaceTrainer"] diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 203f8f1b4839..7376cc7b7ff9 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -4,10 +4,15 @@ import tempfile from distutils.version import LooseVersion from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from ray.ml.utils.torch_utils import load_torch_model +import torch import transformers +import transformers.modeling_utils import transformers.trainer +from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME +import transformers.training_args import ray.cloudpickle as cpickle from torch.utils.data import Dataset as TorchDataset @@ -314,9 +319,11 @@ def _validate_attributes(self): def _convert_directory_checkpoint_to_sync_if_needed( self, checkpoint: Checkpoint ) -> Checkpoint: - """Replace the directory checkpoint with a node ip & path dict checkpoint - used to sync the directory. If we use a directory checkpoint directly, - it will get deepcopied & serialized unnecessarily.""" + """Replace the directory checkpoint with a node ip & path dict checkpoint. + + This dict checkpoint will be used used to sync the directory. + If we were to use a directory checkpoint directly, it would get deepcopied & + serialized unnecessarily.""" with checkpoint.as_directory() as checkpoint_path: # Load checkpoint from path. checkpoint_path = Path(checkpoint_path).expanduser().absolute() @@ -369,6 +376,74 @@ def as_trainable(self) -> Type[Trainable]: self._param_dict = original_param_dict return ret + @staticmethod + def load_huggingface_checkpoint( + checkpoint: Checkpoint, + model: Union[ + Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module + ], + tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None, + *, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **pretrained_model_kwargs, + ) -> Tuple[ + Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module], + transformers.training_args.TrainingArguments, + Optional[transformers.PreTrainedTokenizer], + Optional[Preprocessor], + ]: + """Load a Checkpoint from ``HuggingFaceTrainer``. + + Return the model, ``TrainingArguments``, tokenizer and AIR preprocessor + contained within. Those can be used to initialize a ``transformers.Trainer`` + object locally. + + Args: + checkpoint: The checkpoint to load the model and + preprocessor from. It is expected to be from the result of a + ``HuggingFaceTrainer`` run. + model: Either a ``transformers.PreTrainedModel`` class + (eg. ``AutoModelForCausalLM``), or a PyTorch model to load the + weights to. This should be the same model used for training. + tokenizer: A ``transformers.PreTrainedTokenizer`` class to load + the model tokenizer to. If not specified, the tokenizer will + not be loaded. Will throw an exception if specified, but no + tokenizer was found in the checkpoint. + tokenizer_kwargs: Dict of kwargs to pass to ``tokenizer.from_pretrained`` + call. Ignored if ``tokenizer`` is None. + **pretrained_model_kwargs: Kwargs to pass to ``mode.from_pretrained`` + call. Ignored if ``model`` is not a ``transformers.PreTrainedModel`` + class. + """ + tokenizer_kwargs = tokenizer_kwargs or {} + with checkpoint.as_directory() as checkpoint_path: + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + if isinstance(model, torch.nn.Module): + state_dict = torch.load( + os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" + ) + model = load_torch_model(saved_model=state_dict, model_definition=model) + else: + model = model.from_pretrained( + checkpoint_path, **pretrained_model_kwargs + ) + if tokenizer: + tokenizer = tokenizer.from_pretrained( + checkpoint_path, **tokenizer_kwargs + ) + training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME) + if os.path.exists(training_args_path): + with open(training_args_path, "rb") as f: + training_args = torch.load(f, map_location="cpu") + else: + training_args = None + return model, training_args, tokenizer, preprocessor + def _huggingface_train_loop_per_worker(config): """Per-worker training loop for HuggingFace Transformers.""" @@ -395,8 +470,8 @@ def _huggingface_train_loop_per_worker(config): raise ValueError( "`push_to_hub` parameter in `TrainingArgs` is not supported by " "`HuggingFaceTrainer`. If you would like to push your model to hub " - "after training, use the `load_huggingface_checkpoint` function " - "to obtain the model from a returned checkpoint, and use it to " + "after training, use the `HuggingFaceTrainer.load_huggingface_checkpoint`" + " method to obtain the model from a returned checkpoint, and use it to " "instantiate the `transformers.Trainer` class." ) diff --git a/python/ray/ml/utils/huggingface_checkpoint_utils.py b/python/ray/ml/utils/huggingface_checkpoint_utils.py deleted file mode 100644 index 46cd2102ef54..000000000000 --- a/python/ray/ml/utils/huggingface_checkpoint_utils.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -from typing import Tuple, Type, Union - -import torch -from transformers.modeling_utils import PreTrainedModel -from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME -from transformers import TrainingArguments - -import ray.cloudpickle as cpickle -from ray.ml.preprocessor import Preprocessor -from ray.ml.checkpoint import Checkpoint -from ray.ml.utils.torch_utils import load_torch_model -from ray.ml.constants import PREPROCESSOR_KEY -from ray.util.annotations import PublicAPI - - -@PublicAPI(stability="alpha") -def load_huggingface_checkpoint( - checkpoint: Checkpoint, - model: Union[Type[PreTrainedModel], torch.nn.Module], - **pretrained_model_kwargs -) -> Tuple[Union[PreTrainedModel, torch.nn.Module], Preprocessor, TrainingArguments]: - """Load a Checkpoint from ``HuggingFaceTrainer`` and return the - model, preprocessor and ``TrainingArguments`` contained within. - - Args: - checkpoint: The checkpoint to load the model and - preprocessor from. It is expected to be from the result of a - ``HuggingFaceTrainer`` run. - model: Either a ``transformers.PreTrainedModel`` class - (eg. ``AutoModelForCausalLM``), or a PyTorch model to load the - weights to. This should be the same model used for training. - """ - with checkpoint.as_directory() as checkpoint_path: - preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) - if os.path.exists(preprocessor_path): - with open(preprocessor_path, "rb") as f: - preprocessor = cpickle.load(f) - else: - preprocessor = None - if isinstance(model, torch.nn.Module): - state_dict = torch.load( - os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" - ) - model = load_torch_model(saved_model=state_dict, model_definition=model) - else: - model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs) - training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME) - if os.path.exists(training_args_path): - with open(training_args_path, "rb") as f: - training_args = torch.load(f, map_location="cpu") - else: - training_args = None - return model, preprocessor, training_args From b549dbe63b02b217ee4697be481ae89455acb899 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Fri, 29 Apr 2022 14:31:33 -0700 Subject: [PATCH 75/75] Update python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py --- .../huggingface/huggingface_basic_language_modeling_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py index 1b8c49175721..178b7b202a4e 100644 --- a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -34,7 +34,7 @@ def main( ): block_size = 128 - # Uncomment the following if the maximum length thr model was + # Uncomment the following if the maximum length the model was # pretrained with can fit in your memory. # block_size = tokenizer.model_max_length